Developer documentation
Version 3.0.3-105-gd3941f44
cross_correlation.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __registration_metric_cross_correlation_h__
18#define __registration_metric_cross_correlation_h__
19
20#include "transform.h"
21#include "interp/linear.h"
23
24namespace MR
25{
26 namespace Registration
27 {
28 namespace Metric
29 {
31
32 public:
// Tag type: presumably detected by the registration framework to decide that
// precompute() must be called before operator() — confirm against the caller.
35 using requires_precompute = int;
36
37 template <class Params>
38 default_type operator() (Params& params,
39 const Iterator& iter,
40 Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient) {
41
42 // const Eigen::Vector3d pos = Eigen::Vector3d (iter.index(0), iter.index(1), iter.index(2));
43
44 assert (params.processed_mask.valid());
45 assert (params.processed_image.valid());
46 assert (!this->weighted && "FIXME: set_weights not implemented for CrossCorrelationNoGradient metric");
47
48 assign_pos_of (iter, 0, 3). to (params.processed_mask);
49 if (!params.processed_mask.value())
50 return 0.0;
51
52 default_type val1 = params.processed_image.value();
53 ++params.processed_image.index(3);
54 default_type val2 = params.processed_image.value();
55 --params.processed_image.index(3);
56
57 return (mean1 - val1) * (val2 - mean2); // negative cross correlation
58 }
59
60 template <class ParamType>
61 default_type precompute (ParamType& parameters) {
62 DEBUG ("precomputing cross correlation data...");
63
64 using Im1Type = decltype(parameters.im1_image);
65 using Im2Type = decltype(parameters.im2_image);
66 using MidwayImageType = decltype(parameters.midway_image);
67 using Im1MaskType = decltype(parameters.im1_mask);
68 using Im2MaskType = decltype(parameters.im2_mask);
69 using Im1ImageInterpolatorType = typename ParamType::Im1InterpType;
70 using Im2ImageInterpolatorType = typename ParamType::Im2InterpType;
71 using PImageType = typename ParamType::ProcessedImageType;
72 // using Im1MaskInterpolatorType = Interp::LinearInterp<Im1MaskType, Interp::LinearInterpProcessingType::Value>;
73 // using Im2MaskInterpolatorType = Interp::LinearInterp<Im2MaskType, Interp::LinearInterpProcessingType::Value>;
74 using Im1MaskInterpolatorType = typename ParamType::Mask1InterpolatorType;
75 using Im2MaskInterpolatorType = typename ParamType::Mask2InterpolatorType;
76
77 assert (parameters.midway_image.ndim() == 3);
78 mean1 = 0.0;
79 mean2 = 0.0;
80 size_t overlap (0);
81
82 Header midway_header (parameters.midway_image);
83 MR::Transform transform (midway_header);
84
85 parameters.processed_mask = Header::scratch (midway_header).template get_image<bool>();
86 // processed_image: 2 volumes: interpolated image1 value, interpolated image2 value if both masks' values are >= 0.5
87 auto cc_header = Header::scratch (parameters.midway_image);
88 cc_header.ndim() = 4;
89 cc_header.size(3) = 2;
90 // PImageType cc_image = PImageType::scratch (cc_header);
91 parameters.processed_image = PImageType::scratch (cc_header);
92
93 auto loop = ThreadedLoop ("precomputing cross correlation data...", parameters.processed_image, 0, 3);
94 loop.run (CCNoGradientPrecomputeFunctor<decltype(parameters.transformation),
95 Im1Type,
96 Im2Type,
97 MidwayImageType,
98 Im1MaskType,
99 Im2MaskType,
100 Im1ImageInterpolatorType,
101 Im2ImageInterpolatorType,
102 Im1MaskInterpolatorType,
103 Im2MaskInterpolatorType> (
104 parameters.transformation,
105 parameters.im1_image,
106 parameters.im2_image,
107 parameters.midway_image,
108 parameters.im1_mask,
109 parameters.im2_mask,
110 mean1,
111 mean2,
112 overlap), parameters.processed_image, parameters.processed_mask);
113 // display<Im1Type>(parameters.im1_image);
114 // display<Im2Type>(parameters.im2_image);
115 // display<Image<bool>>(parameters.processed_mask);
116 // VAR(overlap);
117 // VAR(mean1);
118 // VAR(mean2);
119 if (overlap > 0 ) {
120 mean1 /= static_cast<default_type>(overlap);
121 mean2 /= static_cast<default_type>(overlap);
122 } else {
123 DEBUG ("Cross Correlation metric: zero overlap");
124 }
125
126 return 0;
127 }
128
129 private:
130 default_type mean1;
131 default_type mean2;
132 // default_type denom; // TODO: denominator for normalisation
133
134 template <
135 typename LinearTrafoType,
136 typename ImageType1,
137 typename ImageType2,
138 typename MidwayImageType,
139 typename MaskType1,
140 typename MaskType2,
141 typename Im1ImageInterpolatorType,
142 typename Im2ImageInterpolatorType,
143 typename Im1MaskInterpolatorType,
144 typename Im2MaskInterpolatorType
145 >
146 struct CCNoGradientPrecomputeFunctor { MEMALIGN(CCNoGradientPrecomputeFunctor)
147 CCNoGradientPrecomputeFunctor (
148 const LinearTrafoType& transformation,
149 ImageType1& im1,
150 ImageType2& im2,
151 const MidwayImageType& midway,
152 MaskType1& mask1,
153 MaskType2& mask2,
154 default_type& sum_im1,
155 default_type& sum_im2,
156 size_t& overlap):
157 trafo_half (transformation.get_transform_half()),
158 trafo_half_inverse (transformation.get_transform_half_inverse()),
159 v2s (MR::Transform(midway).voxel2scanner),
160 in1 (im1),
161 in2 (im2),
162 msk1 (mask1),
163 msk2 (mask2),
164 global_s1 (sum_im1),
165 global_s2 (sum_im2),
166 global_cnt (overlap),
167 s1 (0.0),
168 s2 (0.0),
169 cnt (0) {
170 assert (in1.valid());
171 assert (in2.valid());
172 im1_image_interp.reset (new Im1ImageInterpolatorType (in1));
173 im2_image_interp.reset (new Im2ImageInterpolatorType (in2));
174 if (msk1.valid())
175 im1_mask_interp.reset (new Im1MaskInterpolatorType (msk1));
176 if (msk2.valid())
177 im2_mask_interp.reset (new Im2MaskInterpolatorType (msk2));
178 }
179
180 ~CCNoGradientPrecomputeFunctor () {
181 global_s1 += s1;
182 global_s2 += s2;
183 global_cnt += cnt;
184 }
185
186
187 template <typename ProcessedImageType, typename MaskImageType>
188 void operator() (ProcessedImageType& pimage, MaskImageType& mask) {
189 assert(mask.index(0) == pimage.index(0));
190 assert(mask.index(1) == pimage.index(1));
191 assert(mask.index(2) == pimage.index(2));
192 assert(pimage.index(3) == 0);
193 vox = Eigen::Vector3d (default_type(pimage.index(0)), default_type(pimage.index(1)), default_type(pimage.index(2)));
194 pos = v2s * vox;
195
196 pos1 = trafo_half * pos;
197 if (msk1.valid()) {
198 im1_mask_interp->scanner(pos1);
199 if (!(*im1_mask_interp))
200 return;
201 if (im1_mask_interp->value() < 0.5)
202 return;
203 }
204
205 pos2 = trafo_half_inverse * pos;
206 if (msk2.valid()) {
207 im2_mask_interp->scanner(pos2);
208 if (!(*im2_mask_interp))
209 return;
210 if (im2_mask_interp->value() < 0.5)
211 return;
212 }
213
214 im1_image_interp->scanner(pos1);
215 if (!(*im1_image_interp))
216 return;
217 v1 = im1_image_interp->value();
218 if (v1 != v1)
219 return;
220
221 im2_image_interp->scanner(pos2);
222 if (!(*im2_image_interp))
223 return;
224 v2 = im2_image_interp->value();
225 if (v2 != v2)
226 return;
227
228 mask.value() = 1;
229 s1 += v1;
230 s2 += v2;
231
232 pimage.value() = v1;
233 ++pimage.index(3);
234 pimage.value() = v2;
235 --pimage.index(3);
236 ++cnt;
237 }
238
239 private:
// Maps midway scanner positions to image-1 (trafo_half) and image-2 (trafo_half_inverse) positions.
240 const Eigen::Transform<default_type, 3, Eigen::AffineCompact> trafo_half, trafo_half_inverse;
// Voxel-to-scanner transform of the midway image.
241 const transform_type v2s;
// Input images and their optional masks (a mask may be invalid, meaning "no mask").
242 ImageType1 in1;
243 ImageType2 in2;
244 MaskType1 msk1;
245 MaskType2 msk2;
// References to the caller's running totals; written only in the destructor.
246 default_type &global_s1, &global_s2;
247 size_t& global_cnt;
// Per-copy partial sums of image-1 / image-2 values and count of accepted voxels.
248 default_type s1, s2;
249 size_t cnt;
// Scratch state reused across operator() calls: voxel, midway scanner and per-image positions.
250 Eigen::Vector3d vox, pos, pos1, pos2;
// Most recent image-1 / image-2 samples.
251 default_type v1, v2;
256 };
257 };
258 }
259 }
260}
261#endif
a dummy image to iterate over, useful for multi-threaded looping.
Definition: iterator.h:29
#define DEBUG(msg)
Definition: exception.h:75
Definition: base.h:24
double default_type
the default type used throughout MRtrix
Definition: types.h:228
Eigen::Transform< default_type, 3, Eigen::AffineCompact > transform_type
the type for the affine transform of an image:
Definition: types.h:234
ThreadedLoopRunOuter< decltype(Loop(vector< size_t >()))> ThreadedLoop(const HeaderType &source, const vector< size_t > &outer_axes, const vector< size_t > &inner_axes)
Multi-threaded loop object.
T to(const std::string &string)
Definition: mrtrix.h:260
#define MEMALIGN(...)
Definition: types.h:185