Developer documentation
Version 3.0.3-105-gd3941f44
zclean.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17
18#ifndef __filter_zclean_h__
19#define __filter_zclean_h__
20
21#include "progressbar.h"
22#include "memory.h"
23#include "image.h"
24#include "algo/copy.h"
25#include "algo/loop.h"
26#include "filter/base.h"
27#include "filter/erode.h"
28#include "filter/dilate.h"
30#include "math/median.h"
31
32
33namespace MR
34{
35 namespace Filter
36 {
37
38
39
40 class ZClean : public Base { MEMALIGN(ZClean)
41
42 public:
43 template <class HeaderType>
44 ZClean (const HeaderType& in) :
45 Base (in),
46 zupper (2.5),
47 zlower (2.5),
48 fov_max (0.3),
49 fov_min (0.15),
50 bridge (0),
51 dont_maskupper (false),
52 keep_lower (false),
53 keep_upper (true)
54 {
56 ndim() = 3;
57 }
58
59 template <class HeaderType>
60 ZClean (const HeaderType& in, const std::string& message) :
61 Base (in, message),
62 zupper (2.5),
63 zlower (2.5),
64 fov_max (0.3),
65 fov_min (0.15),
66 bridge (0),
67 dont_maskupper (false),
68 keep_lower (false),
69 keep_upper (true)
70 {
72 ndim() = 3;
73 }
74
75 template <class InputImageType, class MaskType, class OutputImageType>
76 void operator() (InputImageType& input, MaskType& spatial_prior, OutputImageType& output)
77 {
78 if (output.ndim() > 3)
79 throw Exception ("3D output expected");
80
81 std::unique_ptr<ProgressBar> progress (message.size() ? new ProgressBar (message) : nullptr);
82
83 Image<bool> int_roi = Image<bool>::scratch (Header(spatial_prior), "temporary initial mask");
84 INFO ("creating intensity mask from input mask");
85 Dilate dilation_filter (spatial_prior);
86 dilation_filter.set_npass(2);
87 dilation_filter (spatial_prior, int_roi);
88 ssize_t cnt = 0;
89 for (auto l = Loop (0,3) (int_roi); l; ++l) {
90 cnt += int_roi.value();
91 }
92 ssize_t cnt_lower = std::max<size_t>(10000, std::floor(fov_min * input.size(0) * input.size(1) * input.size(2)));
93 ssize_t cnt_upper = std::floor(fov_max * input.size(0) * input.size(1) * input.size(2));
94 float mad, median, previous_mad, previous_median;
95 calculate_median_mad<Image<float>, Image<bool>> (input, int_roi, cnt, median, mad);
96 INFO ("median: "+str(median));
97 INFO ("mad: "+str(mad));
98 INFO ("lower: " + str(median - zlower * mad) + " upper: " + str(median + zupper * mad));
99
100 INFO ("eroding intensity mask");
101 while (cnt >= cnt_lower) {
102 if (progress)
103 ++(*progress);
104 Erode erosion_filter (int_roi);
105 erosion_filter.set_npass(1);
106 erosion_filter (int_roi, int_roi);
107 cnt = 0;
108 for (auto l = Loop (0,3) (int_roi); l; ++l)
109 cnt += int_roi.value();
110 if (cnt == 0)
111 throw Exception ("mask empty after erosion");
112 previous_median = median;
113 previous_mad = mad;
114 calculate_median_mad<Image<float>, Image<bool>> (input, int_roi, cnt, median, mad);
115 upper = median + zupper * mad;
116 lower = median - zlower * mad;
117 INFO ("median: " + str(median) + ", changed: "+str((median - previous_median) / previous_median));
118 INFO ("mad: " + str(mad) + ", changed: "+str((mad - previous_mad) / previous_mad));
119 INFO ("FOV: " + str(float(cnt) / (input.size(0) * input.size(1) * input.size(2))));
120 INFO ("lower: " + str(lower) + " upper: " + str(upper));
121 INFO ("cnt_upper - cnt: " + str(cnt_upper - cnt));
122 if (lower > 0.0 && ((median + 2.5 * mad) - (previous_median + 2.5 * previous_mad)) < 0.0 && (cnt < cnt_upper))
123 break;
124 }
125
126 if (App::log_level >= 3) {
127 auto masked_image = Image<float>::scratch (input, "robust z score");
128 // output;
129 for (auto l = Loop (0,3) (masked_image, input, int_roi); l; ++l) {
130 if (int_roi.value())
131 masked_image.value() = input.value();
132 }
133 display (masked_image);
134 }
135
136 {
137 INFO ("intensity sample mask");
138 if (progress)
139 ++(*progress);
140
141 Image<float> eroded_zscore_image;
142 if (App::log_level >= 3) {
143 eroded_zscore_image = Image<float>::scratch (input, "robust z score");
144 }
145
146 int maxiter = 5;
147 while (maxiter--) {
148 // refine image mask based on robust Z score
149 cnt = 0;
150 for (auto l = Loop (0,3) (input, int_roi); l; ++l) {
151 if (int_roi.value()) {
152 float z = (input.value() - median) / mad;
153 bool good = (z > -zlower) && (z < zupper);
154 if (App::log_level >= 3) {
155 assign_pos_of(input, 0, 3).to(eroded_zscore_image);
156 eroded_zscore_image.value() = (z > -zlower) && (dont_maskupper || z < zupper) ? z : NaN;
157 }
158 if (good) cnt++;
159 int_roi.value() = good;
160 } else if (App::log_level >= 3) {
161 assign_pos_of(input, 0, 3).to(eroded_zscore_image);
162 eroded_zscore_image.value() = NaN;
163 }
164 }
165 previous_mad = mad;
166 previous_median = median;
167 calculate_median_mad<Image<float>, Image<bool>> (input, int_roi, cnt, median, mad);
168 upper = median + zupper * mad;
169 lower = median - zlower * mad;
170 INFO("median: " + str(median) + ", changed: " + str((median - previous_median)));
171 INFO("mad: " + str(mad) + ", changed: " + str((mad - previous_mad)));
172 INFO("lower: " + str(lower) + " upper: " + str(upper));
173 float change = MR::abs(median - previous_median) / previous_mad;
174 INFO("convergence: "+str(change));
175 if (change < 1e-2)
176 break;
177 }
178 if (App::log_level >= 3)
179 display (eroded_zscore_image);
180 }
181 upper = median + zupper * mad;
182 lower = median - zlower * mad;
183 if (lower < 0.0) {
184 WARN ("likely not converged, setting lower to 0.0");
185 lower = 0.0;
186 }
187
188
189 INFO ("lower: "+str(lower));
190 INFO ("upper: "+str(upper));
191 INFO ("bridge: "+str(bridge));
192
193 mask = Image<bool>::scratch (Header(spatial_prior), "temporary mask");
194 if (progress)
195 ++(*progress);
196
197 for (auto l = Loop (0,3) (input, mask, spatial_prior); l; ++l)
198 mask.value() = spatial_prior.value() && input.value() >= lower && (dont_maskupper || input.value() <= upper);
199
200 if (App::log_level >= 3)
201 display (mask);
202 if (progress)
203 ++(*progress);
204
205 {
206 INFO ("selecting largest ROI");
207 ConnectedComponents connected_filter (mask);
208 connected_filter.set_largest_only (true);
209 connected_filter (mask, mask);
210 if (progress)
211 ++(*progress);
212 }
213
214 for (auto l = Loop (0,3) (mask); l; ++l)
215 mask.value() = !mask.value();
216
217 {
218 INFO ("removing masked out islands");
219 ConnectedComponents connected_filter (mask);
220 connected_filter.set_largest_only (true);
221 connected_filter (mask, mask);
222 if (progress)
223 ++(*progress);
224 }
225
226 if (bridge) {
227 INFO ("bridging");
228 for (auto l = Loop (0,3) (mask); l; ++l)
229 mask.value() = !mask.value();
230 if (progress)
231 ++(*progress);
232 Dilate dilation_filter (mask);
233 dilation_filter.set_npass(bridge);
234 dilation_filter (mask, mask);
235 if (progress)
236 ++(*progress);
237 for (auto l = Loop (0,3) (mask); l; ++l)
238 mask.value() = !mask.value();
239 if (progress)
240 ++(*progress);
241 ConnectedComponents connected_filter (mask);
242 connected_filter.set_largest_only (true);
243 connected_filter (mask, mask);
244 if (progress)
245 ++(*progress);
246 Dilate dilation_filter2 (mask);
247 dilation_filter2.set_npass(bridge);
248 dilation_filter2 (mask, mask);
249 if (progress)
250 ++(*progress);
251 if (App::log_level >= 3)
252 display (mask);
253 }
254
255 for (auto l = Loop (0,3) (mask, spatial_prior); l; ++l)
256 mask.value() = !mask.value() && spatial_prior.value();
257 if (progress)
258 ++(*progress);
259
260 float lo = std::max<float>(median - 2.5 * mad, lower);
261 float hi = std::min<float>(median + 2.5 * mad, upper);
262 for (auto l = Loop (0,3) (input, spatial_prior, mask, output); l; ++l) {
263 if (!spatial_prior.value())
264 continue;
265 float val = input.value();
266 if (mask.value()) {
267 if (val < lo)
268 output.value() = val; // hack
269 else if (val > hi)
270 output.value() = hi;
271 else
272 output.value() = val;
273 continue;
274 } else { // outside refined mask but inside initial mask
275 if (keep_lower && val < lo)
276 output.value() = lo;
277 else if (keep_upper && val > hi)
278 output.value() = hi;
279 }
280 }
281
282 }
283
284 void set_zlim (float upper, float lower)
285 {
286 zupper = upper;
287 zlower = lower;
288 }
289
290 void set_voxels_to_bridge (size_t nvoxels)
291 {
292 bridge = nvoxels;
293 }
294
295 Image<bool> mask;
296
297 protected:
300 size_t bridge;
302 float upper, lower;
303
304 template <typename ImageType, typename MaskType>
305 void calculate_median_mad (ImageType& image, MaskType& mask, size_t nvoxels, float& median, float& mad) {
306 MR::vector<float> vals (nvoxels);
307 size_t idx = 0;
308 for (auto l = Loop (0,3) (mask, image); l; ++l) {
309 if (mask.value())
310 vals[idx++] = image.value();
311 assert (idx <= nvoxels);
312 }
313 median = Math::median(vals);
314 for (auto & v : vals)
315 v = MR::abs(v - median);
316 mad = Math::median(vals);
317 }
318
319 };
321 }
322}
323
324
325
326
327#endif
static constexpr uint8_t Float32
Definition: datatype.h:147
std::string message
Definition: base.h:66
a filter to dilate a mask
Definition: dilate.h:49
a filter to erode a mask
Definition: erode.h:49
bool dont_maskupper
Definition: zclean.h:301
void calculate_median_mad(ImageType &image, MaskType &mask, size_t nvoxels, float &median, float &mad)
Definition: zclean.h:305
DataType datatype_
the type of the data as stored on file
Definition: header.h:370
static Image scratch(const Header &template_header, const std::string &label="scratch image")
Definition: image.h:195
implements a progress meter to provide feedback to the user
Definition: progressbar.h:58
#define WARN(msg)
Definition: exception.h:73
#define INFO(msg)
Definition: exception.h:74
constexpr I floor(const T x)
template function with cast to different type
Definition: math.h:75
FORCE_INLINE LoopAlongAxes Loop()
Definition: loop.h:419
constexpr double e
Definition: math.h:39
int log_level
Definition: exception.h:34
Container::value_type median(Container &list)
Definition: median.h:45
Definition: base.h:24
enable_if_image_type< ImageType, void >::type display(ImageType &x)
display the contents of an image in MRView (for debugging only)
Definition: image.h:541
constexpr std::enable_if< std::is_arithmetic< X >::value &&std::is_unsigned< X >::value, X >::type abs(X x)
Definition: types.h:297
std::string str(const T &value, int precision=0)
Definition: mrtrix.h:247
constexpr default_type NaN
Definition: types.h:230