Developer documentation
Version 3.0.3-105-gd3941f44
optimal_threshold.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __filter_optimal_threshold_h__
18#define __filter_optimal_threshold_h__
19
20#include "memory.h"
21#include "image.h"
22#include "algo/threaded_loop.h"
23#include "algo/min_max.h"
24#include "adapter/replicate.h"
25#include "filter/base.h"
27
28
29namespace MR
30{
31 namespace Filter
32 {
33
35 namespace {
36
37 class MeanStdFunctor { NOMEMALIGN
38 public:
39 MeanStdFunctor (double& overall_sum, double& overall_sum_sqr, size_t& overall_count) :
40 overall_sum (overall_sum), overall_sum_sqr (overall_sum_sqr), overall_count (overall_count),
41 sum (0.0), sum_sqr (0.0), count (0) { }
42
43 ~MeanStdFunctor () {
44 std::lock_guard<std::mutex> lock (mutex);
45 overall_sum += sum;
46 overall_sum_sqr += sum_sqr;
47 overall_count += count;
48 }
49
50 template <class ImageType, class MaskType>
51 void operator() (ImageType& vox, MaskType& mask) {
52 if (mask.value()) {
53 double in = vox.value();
54 if (std::isfinite(in)) {
55 sum += in;
56 sum_sqr += Math::pow2 (in);
57 ++count;
58 }
59 }
60 }
61
62 template <class ImageType>
63 void operator() (ImageType& vox) {
64 double in = vox.value();
65 if (std::isfinite(in)) {
66 sum += in;
67 sum_sqr += Math::pow2 (in);
68 ++count;
69 }
70 }
71
72 double& overall_sum;
73 double& overall_sum_sqr;
74 size_t& overall_count;
75 double sum, sum_sqr;
76 size_t count;
77
78 static std::mutex mutex;
79 };
80 std::mutex MeanStdFunctor::mutex;
81
82 class CorrelationFunctor { NOMEMALIGN
83 public:
84 CorrelationFunctor (double threshold, double& overall_sum, double& overall_mean_xy) :
85 threshold (threshold), overall_sum (overall_sum), overall_mean_xy (overall_mean_xy),
86 sum (0), mean_xy (0.0) { }
87
88 ~CorrelationFunctor () {
89 std::lock_guard<std::mutex> lock (mutex);
90 overall_sum += sum;
91 overall_mean_xy += mean_xy;
92 }
93
94 template <class ImageType>
95 void operator() (ImageType& vox) {
96 double in = vox.value();
97 if (std::isfinite(in)) {
98 if (in > threshold) {
99 sum += 1;
100 mean_xy += in;
101 }
102 }
103 }
104
105 template <class ImageType, class MaskType>
106 void operator() (ImageType& vox, MaskType& mask) {
107 if (mask.value()) {
108 double in = vox.value();
109 if (std::isfinite(in)) {
110 if (in > threshold) {
111 sum += 1;
112 mean_xy += in;
113 }
114 }
115 }
116 }
117
118 const double threshold;
119 double& overall_sum;
120 double& overall_mean_xy;
121 double sum;
122 double mean_xy;
123
124 static std::mutex mutex;
125 };
126 std::mutex CorrelationFunctor::mutex;
127
128 }
130
131
132 template <class ImageType, class MaskType>
134
135 public:
138
139 ImageCorrelationCostFunction (ImageType& input, MaskType& mask) :
140 input (input),
141 mask (mask)
142 {
143 double sum_sqr = 0.0, sum = 0.0;
144 count = 0;
145
146 if (mask.valid()) {
147 Adapter::Replicate<MaskType> replicated_mask (mask, input);
148 ThreadedLoop (input).run (MeanStdFunctor (sum, sum_sqr, count), input, replicated_mask);
149 }
150 else {
151 ThreadedLoop (input).run (MeanStdFunctor (sum, sum_sqr, count), input);
152 }
153
154 input_image_mean = sum / count;
155 input_image_stdev = sqrt ((sum_sqr - sum * input_image_mean) / count);
156 }
157
159 double sum = 0;
160 double mean_xy = 0.0;
161
162 if (mask.valid()) {
163 Adapter::Replicate<MaskType> replicated_mask (mask, input);
164 ThreadedLoop (input).run (CorrelationFunctor (threshold, sum, mean_xy), input, replicated_mask);
165 }
166 else
167 ThreadedLoop (input).run (CorrelationFunctor (threshold, sum, mean_xy), input);
168
169 mean_xy /= count;
170 double covariance = mean_xy - (sum / count) * input_image_mean;
171 double mask_stdev = sqrt ((sum - double (sum * sum) / count) / count);
172
173 return -covariance / (input_image_stdev * mask_stdev);
174 }
175
176 private:
177 ImageType& input;
178 MaskType& mask;
179 size_t count;
180 double input_image_mean;
181 double input_image_stdev;
182 };
183
184
185 template <class ImageType, class MaskType>
186 typename ImageType::value_type estimate_optimal_threshold (ImageType& input, MaskType& mask)
187 {
188 using input_value_type = typename ImageType::value_type;
189
190 input_value_type min, max;
191 if (mask.valid())
192 min_max (input, mask, min, max);
193 else
194 min_max (input, min, max);
195
196 input_value_type optimal_threshold = 0.0;
197 {
198 ImageCorrelationCostFunction<ImageType, MaskType> cost_function (input, mask);
199 optimal_threshold = Math::golden_section_search (cost_function, "optimising threshold",
200 min + 0.001*(max-min), 0.5*(min+max), max-0.001*(max-min));
201 }
202
203 return optimal_threshold;
204 }
205
206
207
208
209 template <class ImageType>
210 inline typename ImageType::value_type estimate_optimal_threshold (ImageType& input)
211 {
212 Image<bool> mask;
213 return estimate_optimal_threshold (input, mask);
214 }
215
220
239 class OptimalThreshold : public Base { MEMALIGN(OptimalThreshold)
240 public:
241 OptimalThreshold (const Header& H) :
242 Base (H)
243 {
245 }
246
247 template <class InputImageType, class OutputImageType>
248 void operator() (InputImageType& input, OutputImageType& output)
249 {
250 Image<bool> mask;
251 operator() (input, output, mask);
252 }
253
254
255 template <class InputImageType, class OutputImageType, class MaskType>
256 void operator() (InputImageType& input, OutputImageType& output, MaskType& mask)
257 {
258 using input_value_type = typename InputImageType::value_type;
259
260 input_value_type optimal_threshold = estimate_optimal_threshold (input, mask);
261
262 auto f = [&](decltype(input) in, decltype(output) out) {
263 input_value_type val = in.value();
264 out.value() = ( std::isfinite (val) && val > optimal_threshold ) ? 1 : 0;
265 };
266 ThreadedLoop ("thresholding", input) .run (f, input, output);
267 }
268 };
270 }
271}
272
273
274
275
276#endif
static constexpr uint8_t Bit
Definition: datatype.h:142
typename ImageType::value_type value_type
ImageCorrelationCostFunction(ImageType &input, MaskType &mask)
typename MaskType::value_type mask_value_type
value_type operator()(value_type threshold) const
a filter to compute the optimal threshold to mask a DataSet.
DataType datatype_
the type of the data as stored on file
Definition: header.h:370
ValueType golden_section_search(FunctionType &function, const std::string &message, ValueType min_bound, ValueType init_estimate, ValueType max_bound, ValueType tolerance=0.01)
Computes the minimum of a 1D function using a golden section search.
constexpr T pow2(const T &v)
Definition: math.h:53
#define NOMEMALIGN
Definition: memory.h:22
ImageType::value_type estimate_optimal_threshold(ImageType &input, MaskType &mask)
MR::default_type value_type
Definition: typedefs.h:33
Definition: base.h:24
ThreadedLoopRunOuter< decltype(Loop(vector< size_t >()))> ThreadedLoop(const HeaderType &source, const vector< size_t > &outer_axes, const vector< size_t > &inner_axes)
Multi-threaded loop object.
void min_max(ImageType &in, typename ImageType::value_type &min, typename ImageType::value_type &max, size_t from_axis=0, size_t to_axis=std::numeric_limits< size_t >::max())
Definition: min_max.h:75
Eigen::MatrixXd H