Developer documentation
Version 3.0.3-105-gd3941f44
local_cross_correlation.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __image_registration_metric_local_cross_correlation_h__
18#define __image_registration_metric_local_cross_correlation_h__
19
20#include "transform.h"
21#include "algo/loop.h"
22#include "algo/threaded_loop.h"
23#include "adapter/reslice.h"
24#include "filter/reslice.h"
26
27namespace MR
28{
29 namespace Registration
30 {
31 namespace Metric
32 {
33 template <typename ImageType1, typename ImageType2>
35 template <typename MaskType, typename ImageType3>
36 void operator() (MaskType& mask, ImageType3& out) {
37 if (!mask.value())
38 return;
39 Eigen::Vector3d pos (mask.index(0), mask.index(1), mask.index(2));
40 out.index(0) = pos[0];
41 out.index(1) = pos[1];
42 out.index(2) = pos[2];
43 out.index(3) = 0;
44
45 int nmax = extent[0] * extent[1] * extent[2];
46 Eigen::VectorXd n1 = Eigen::VectorXd(nmax);
47 Eigen::VectorXd n2 = Eigen::VectorXd(nmax);
48
49 using value_type = typename ImageType3::value_type;
50 in1.index(0) = pos[0];
51 in1.index(1) = pos[1];
52 in1.index(2) = pos[2];
53 value_type value_in1 = in1.value();
54 if (value_in1 != value_in1){ // nan in image 1, update mask
55 mask.value() = false;
56 out.row(3) = 0.0;
57 return;
58 }
59 in2.index(0) = pos[0];
60 in2.index(1) = pos[1];
61 in2.index(2) = pos[2];
62 value_type value_in2 = in2.value();
63 if (value_in2 != value_in2){ // nan in image 2, update mask
64 mask.value() = false;
65 out.row(3) = 0.0;
66 return;
67 }
68
69 n1.setZero();
70 n2.setZero();
71 auto niter = NeighbourhoodIterator(mask, extent);
72 size_t cnt = 0;
73 while(niter.loop()) {
74 mask.index(0) = niter.index(0);
75 mask.index(1) = niter.index(1);
76 mask.index(2) = niter.index(2);
77 if (!mask.value())
78 continue;
79 in1.index(0) = niter.index(0);
80 in1.index(1) = niter.index(1);
81 in1.index(2) = niter.index(2);
82 value_type val1 = in1.value();
83 if (val1 != val1){
84 continue;
85 }
86 in2.index(0) = niter.index(0);
87 in2.index(1) = niter.index(1);
88 in2.index(2) = niter.index(2);
89 value_type val2 = in2.value();
90 if (val2 != val2){
91 continue;
92 }
93
94 n1[cnt] = in1.value();
95 n2[cnt] = in2.value();
96
97 cnt++;
98 }
99 // reset the mask index
100 mask.index(0) = out.index(0);
101 mask.index(1) = out.index(1);
102 mask.index(2) = out.index(2);
103
104 if (cnt <= 0)
105 throw Exception ("FIXME: neighbourhood does not contain centre");
106
107 // local mean subtracted
108 default_type m1 = n1.sum() / cnt;
109 default_type m2 = n2.sum() / cnt;
110 n1.array() -= m1;
111 n2.array() -= m2;
112 out.row(3) = ( Eigen::Matrix<default_type,5,1>() << value_in1 - m1, value_in2 - m2, n1.adjoint() * n2, n1.adjoint() * n1, n2.adjoint() * n2 ).finished();
113 }
114
115 LCCPrecomputeFunctorMasked_Naive (const vector<size_t>& ext, ImageType1& adapter1, ImageType2& adapter2) :
116 extent(ext),
117 in1(adapter1),
118 in2(adapter2) { /* TODO check dimensions and extent */ }
119
120 protected:
122 ImageType1 in1; // store reslice adapter in functor to avoid iterating over it when mask is false
123 ImageType2 in2; // TODO: cache interpolated values for neighbourhood iteration
124 };
125
127 private:
128 transform_type midway_v2s;
129
130 public:
132 using is_neighbourhood = int;
135
136 void set_weights (Eigen::Matrix<default_type, Eigen::Dynamic, 1> weights) {
137 assert ("FIXME: set_weights not implemented");
138 }
139
140 template <class ParamType>
141 default_type precompute(ParamType& parameters) {
142 INFO ("precomputing cross correlation data...");
143
144 using Im1Type = decltype(parameters.im1_image);
145 using Im2Type = decltype(parameters.im2_image);
146 using Im1MaskType = decltype(parameters.im1_mask);
147 using Im2MaskType = decltype(parameters.im2_mask);
148 using ProcessedImageValueType = typename ParamType::ProcessedValueType;
149 using ProcessedMaskType = typename ParamType::ProcessedMaskType;
150 using ProcessedMaskInterpolatorType = typename ParamType::ProcessedMaskInterpType;
151 using CCInterpType = typename ParamType::ProcessedImageInterpType;
152
153 Header midway_header (parameters.midway_image);
154 midway_v2s = MR::Transform (midway_header).voxel2scanner;
155
156 // store precomputed values in cc_image:
157 // volumes 0 and 1: normalised intensities of both images (Im1 and Im2)
158 // volumes 2 to 4: neighbourhood dot products Im1.dot(Im2), Im1.dot(Im1), Im2.dot(Im2)
159 auto cc_image_header = Header::scratch (midway_header);
160 cc_image_header.ndim() = 4;
161 cc_image_header.size(3) = 5;
162 ProcessedMaskType cc_mask;
163 auto cc_mask_header = Header::scratch (parameters.midway_image);
164
165 auto cc_image = cc_image_header.template get_image <ProcessedImageValueType>().with_direct_io(Stride::contiguous_along_axis(3));
166 vector<uint32_t> NoOversample;
167 {
169 if (parameters.im1_mask.valid() or parameters.im2_mask.valid())
170 cc_mask = cc_mask_header.template get_image<bool>();
171 if (parameters.im1_mask.valid() and !parameters.im2_mask.valid())
172 Filter::reslice <Interp::Nearest> (parameters.im1_mask, cc_mask, parameters.transformation.get_transform_half());
173 else if (!parameters.im1_mask.valid() and parameters.im2_mask.valid())
174 Filter::reslice <Interp::Nearest> (parameters.im2_mask, cc_mask, parameters.transformation.get_transform_half_inverse(), Adapter::AutoOverSample);
175 else if (parameters.im1_mask.valid() and parameters.im2_mask.valid()){
176 Adapter::Reslice<Interp::Nearest, Im1MaskType> mask_reslicer1 (parameters.im1_mask, cc_mask_header, parameters.transformation.get_transform_half(), NoOversample);
177 Adapter::Reslice<Interp::Nearest, Im2MaskType> mask_reslicer2 (parameters.im2_mask, cc_mask_header, parameters.transformation.get_transform_half_inverse(), NoOversample);
178 // TODO should be faster to just loop over m1:
179 // if (m1.value())
180 // assign_pos_of(m1).to(m2); cc_mask.value() = m2.value()
181 auto both = [](decltype(cc_mask)& cc_mask, decltype(mask_reslicer1)& m1, decltype(mask_reslicer2)& m2) {
182 cc_mask.value() = ((m1.value() + m2.value()) / 2.0) > 0.5 ? true : false;
183 };
184 ThreadedLoop (cc_mask).run (both, cc_mask, mask_reslicer1, mask_reslicer2);
185 }
186 }
187 Adapter::Reslice<Interp::Linear, Im1Type> interp1 (parameters.im1_image, midway_header, parameters.transformation.get_transform_half(), NoOversample, std::numeric_limits<typename Im1Type::value_type>::quiet_NaN());
188
189 Adapter::Reslice<Interp::Linear, Im2Type> interp2 (parameters.im2_image, midway_header, parameters.transformation.get_transform_half_inverse(), NoOversample, std::numeric_limits<typename Im2Type::value_type>::quiet_NaN());
190
191 const auto extent = parameters.get_extent();
192
193 // TODO unmasked CCPrecomputeFunctor.
194 // ThreadedLoop (cc_image, 0, 3).run (CCPrecomputeFunctor_Bogus(), interp1, interp2, cc_image);
195 // create a mask (all voxels true) if none given.
196 if (!cc_mask.valid()){
197 cc_mask = cc_mask_header.template get_image<bool>();
198 ThreadedLoop (cc_mask).run([](decltype(cc_mask)& m) {m.value()=true;}, cc_mask);
199 }
200 parameters.processed_mask = cc_mask;
201 parameters.processed_mask_interp.reset (new ProcessedMaskInterpolatorType (parameters.processed_mask));
202 auto loop = ThreadedLoop ("precomputing cross correlation data...", parameters.processed_mask);
203 loop.run (LCCPrecomputeFunctorMasked_Naive<decltype(interp1), decltype(interp2)>(extent, interp1, interp2), parameters.processed_mask, cc_image);
204 parameters.processed_image = cc_image;
205 parameters.processed_image_interp.reset (new CCInterpType (parameters.processed_image));
206 // display<Image<default_type>>(parameters.processed_image);
207 return 0.0;
208 }
209
210 template <class Params>
212 const Iterator& iter,
213 Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient) {
214 // iterates over processed image rather than midway image
215
216 if (params.processed_mask.valid()) {
217 assign_pos_of(iter, 0, 3).to(params.processed_mask);
218 if (!params.processed_mask.value())
219 return 0.0;
220 }
221
222 assign_pos_of(iter).to(params.processed_image); // TODO why do we need this?
223 assert (params.processed_image.index(0) == iter.index(0));
224 assert (params.processed_image.index(1) == iter.index(1));
225 assert (params.processed_image.index(2) == iter.index(2));
226
227 params.processed_image.index(3) = 2;
228 default_type A = params.processed_image.value();
229 params.processed_image.index(3) = 3;
230 default_type B = params.processed_image.value();
231 params.processed_image.index(3) = 4;
232 default_type C = params.processed_image.value();
233 default_type A_BC = A / (B * C);
234 params.processed_image.index(3) = 0;
235
236 if (A_BC != A_BC || A_BC == 0.0) {
237 return 0.0;
238 }
239
240 const Eigen::Vector3d pos = Eigen::Vector3d(default_type(iter.index(0)), default_type(iter.index(0)), default_type(iter.index(0)));
241 params.processed_image_interp->voxel(pos);
242 typename Params::Im1ValueType val1;
243 typename Params::Im2ValueType val2;
244 Eigen::Matrix<typename Params::Im1ValueType, 1, 3> grad1;
245 Eigen::Matrix<typename Params::Im2ValueType, 1, 3> grad2;
246
247 // gradient calculation
248 params.processed_image_interp->index(3) = 0;
249 params.processed_image_interp->value_and_gradient_wrt_scanner(val1, grad1);
250
251 if (val1 != val1){
252 // this should not happen as the precompute should have changed the mask
253 WARN ("FIXME: val1 is nan");
254 return 0.0;
255 }
256
257 params.processed_image_interp->index(3) = 1;
258 params.processed_image_interp->value_and_gradient_wrt_scanner(val2, grad2);
259 if (val2 != val2){
260 // this should not happen as the precompute should have changed the mask
261 WARN ("FIXME: val2 is nan");
262 return 0.0;
263 }
264
265 // 2.0 * sfm / (sff * smm) * ((i2 - sfm / smm * i1 ) * im1_gradient.value() - (i1 - sfm / sff * i2 ) * im2_gradient.value());
266
267 // ITK:
268 // derivWRTImage[dim] = 2.0 * sFixedMoving / (sFixedFixed_sMovingMoving) * (fixedI - sFixedMoving / sMovingMoving * movingI) * movingImageGradient[dim];
269 Eigen::Vector3d derivWRTImage = - A_BC * ((val2 - A/B * val1) * grad1 - 0.0 * (val1 - A/C * val2) * grad2);
270
271 const Eigen::Vector3d midway_point = midway_v2s * pos;
272 const auto jacobian_vec = params.transformation.get_jacobian_vector_wrt_params (midway_point);
273 gradient.segment<4>(0) += derivWRTImage(0) * jacobian_vec;
274 gradient.segment<4>(4) += derivWRTImage(1) * jacobian_vec;
275 gradient.segment<4>(8) += derivWRTImage(2) * jacobian_vec;
276
277 return A * A_BC;
278 }
279 };
280 }
281 }
282}
283#endif
an Image providing interpolated values from another Image
Definition: reslice.h:112
a dummy image to iterate over, useful for multi-threaded looping.
Definition: iterator.h:29
const ssize_t & index(size_t axis) const
Definition: iterator.h:43
a dummy image to iterate over a certain neighbourhood, useful for multi-threaded looping.
default_type operator()(Params &params, const Iterator &iter, Eigen::Matrix< default_type, Eigen::Dynamic, 1 > &gradient)
void set_weights(Eigen::Matrix< default_type, Eigen::Dynamic, 1 > weights)
const transform_type voxel2scanner
Definition: transform.h:43
#define WARN(msg)
Definition: exception.h:73
#define INFO(msg)
Definition: exception.h:74
const vector< uint32_t > AutoOverSample
int log_level
Definition: exception.h:34
MR::default_type value_type
Definition: typedefs.h:33
List contiguous_along_axis(size_t axis)
convenience function to get volume-contiguous strides
Definition: stride.h:386
Definition: base.h:24
double default_type
the default type used throughout MRtrix
Definition: types.h:228
Eigen::Transform< default_type, 3, Eigen::AffineCompact > transform_type
the type for the affine transform of an image:
Definition: types.h:234
ThreadedLoopRunOuter< decltype(Loop(vector< size_t >()))> ThreadedLoop(const HeaderType &source, const vector< size_t > &outer_axes, const vector< size_t > &inner_axes)
Multi-threaded loop object.
#define MEMALIGN(...)
Definition: types.h:185
LCCPrecomputeFunctorMasked_Naive(const vector< size_t > &ext, ImageType1 &adapter1, ImageType2 &adapter2)
MEMALIGN(LCCPrecomputeFunctorMasked_Naive< ImageType1, ImageType2 >) template< typename MaskType
ImageType3 void operator()(MaskType &mask, ImageType3 &out)