Developer documentation
Version 3.0.3-105-gd3941f44
evaluate.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __registration_metric_evaluate_h__
18#define __registration_metric_evaluate_h__
19
21#include "algo/threaded_loop.h"
22#include "algo/loop.h"
24#include "image.h"
25
26namespace MR
27{
28 namespace Registration
29 {
30 namespace Metric
31 {
33 namespace {
34 template<class T>
35 struct Void2 { NOMEMALIGN
36 using type = void;
37 };
38
39 template <class MetricType, typename U = void>
40 struct metric_requires_precompute { NOMEMALIGN
41 using no = int;
42 };
43
44 template <class MetricType>
45 struct metric_requires_precompute<MetricType, typename Void2<typename MetricType::requires_precompute>::type> { NOMEMALIGN
46 using yes = int;
47 };
48
49 template <class MetricType, typename U = void>
50 struct metric_requires_initialisation { NOMEMALIGN
51 using no = int;
52 };
53
54 template <class MetricType>
55 struct metric_requires_initialisation<MetricType, typename Void2<typename MetricType::requires_initialisation>::type> { NOMEMALIGN
56 using yes = int;
57 };
58 }
60
61 template <class MetricType, class ParamType>
63 public:
64
65 using TransformParamType = typename ParamType::TransformParamType;
66 using value_type = default_type;
67
68 template <class U = MetricType>
69 Evaluate () = delete;
70
71 template <class U = MetricType>
72 Evaluate (const MetricType& metric_, ParamType& parameters, typename metric_requires_initialisation<U>::yes = 0) :
73 metric (metric_),
74 params (parameters),
75 iteration (1) {
76 // update number of volumes
77 metric.init (parameters.im1_image, parameters.im2_image);
78 metric.set_weights(params.get_weights());
79 }
80
81 template <class U = MetricType>
82 Evaluate (const MetricType& metric_, ParamType& parameters, typename metric_requires_initialisation<U>::no = 0) :
83 metric (metric_),
84 params (parameters),
85 iteration (1) { metric.set_weights(params.get_weights()); }
86
87 // metric_requires_precompute<U>::yes: operator() loops over processed_image instead of midway_image
88 template <class U = MetricType>
89 default_type operator() (const Eigen::Matrix<default_type, Eigen::Dynamic, 1>& x, Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient, typename metric_requires_precompute<U>::yes = 0) {
90 Eigen::VectorXd overall_cost_function = Eigen::VectorXd::Zero(1,1);
91
92 gradient.setZero();
93 params.transformation.set_parameter_vector(x);
94
95 if (directions.cols()) {
96 DEBUG ("Reorienting FODs...");
97 std::shared_ptr<Image<default_type> > im1_image_reoriented;
98 std::shared_ptr<Image<default_type> > im2_image_reoriented;
99 im1_image_reoriented = make_shared<Image<default_type>>(Image<default_type>::scratch (params.im1_image));
100 im2_image_reoriented = make_shared<Image<default_type>>(Image<default_type>::scratch (params.im2_image));
101
102 {
103 if (params.mc_settings.size()) {
104 DEBUG ("Tissue contrast specific FOD reorientation");
105 Registration::Transform::reorient (params.im1_image, *im1_image_reoriented, params.transformation.get_transform_half(), directions, false, params.mc_settings);
106 Registration::Transform::reorient (params.im2_image, *im2_image_reoriented, params.transformation.get_transform_half_inverse(), directions, false, params.mc_settings);
107 } else {
108 DEBUG ("FOD reorientation");
109 Registration::Transform::reorient (params.im1_image, *im1_image_reoriented, params.transformation.get_transform_half(), directions);
110 Registration::Transform::reorient (params.im2_image, *im2_image_reoriented, params.transformation.get_transform_half_inverse(), directions);
111 }
112 }
113
114 params.set_im1_iterpolator (*im1_image_reoriented);
115 params.set_im2_iterpolator (*im2_image_reoriented);
116 }
117
118 metric.precompute (params);
119 {
120 overlap_count = 0;
121 ThreadKernel<MetricType, ParamType> kernel (metric, params, overall_cost_function, gradient, &overlap_count);
122 {
124 ThreadedLoop (params.processed_image, 0, 3).run (kernel);
125 }
126 }
127 DEBUG ("Metric evaluate iteration: " + str(iteration++) + ", cost: " + str(overall_cost_function.transpose()));
128 DEBUG (" x: " + str(x.transpose()));
129 DEBUG (" gradient: " + str(gradient.transpose()));
130 DEBUG (" norm(gradient): " + str(gradient.norm()));
131 DEBUG (" overlapping voxels: " + str(overlap_count));
132 return overall_cost_function(0);
133 }
134
135 // template <class TransformType_>
136 // void estimate (TransformType_&& trafo,
137 // const MetricType& metric,
138 // const ParamType& params,
139 // Eigen::VectorXd& cost,
140 // Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient,
141 // const Eigen::Matrix<default_type, Eigen::Dynamic, 1>& x,
142 // ssize_t* overlap_count = nullptr) {
143
144 // if (params.loop_density < 1.0) {
145 // DEBUG ("stochastic gradient descent, density: " + str(params.loop_density));
146 // Math::RNG rng;
147 // gradient.setZero();
148 // auto loop = ThreadedLoop (params.midway_image, 0, 3, 2);
149 // if (overlap_count)
150 // *overlap_count = 0;
151 // StochasticThreadKernel <MetricType, ParamType> functor (loop.inner_axes, params.loop_density, metric, params, cost, gradient, rng, overlap_count);
152 // loop.run_outer (functor);
153 // }
154 // else {
155 // if (overlap_count)
156 // *overlap_count = 0;
157 // ThreadKernel <MetricType, ParamType> kernel (metric, params, cost, gradient, overlap_count);
158
159 // if (params.robust_estimate_subset) {
160 // assert(params.robust_estimate_subset_from.size());
161 // assert(params.robust_estimate_subset_size.size());
162 // Adapter::Subset<Image<default_type>> midway_subset (params.midway_image, params.robust_estimate_subset_from, params.robust_estimate_subset_size);
163 // ThreadedLoop (midway_subset, 0, 3).run (kernel);
164 // } else {
165 // ThreadedLoop (params.midway_image, 0, 3).run (kernel);
166 // }
167 // }
168 // }
169
170 template <class U = MetricType>
171 default_type operator() (const Eigen::Matrix<default_type, Eigen::Dynamic, 1>& x, Eigen::Matrix<default_type, Eigen::Dynamic, 1>& gradient, typename metric_requires_precompute<U>::no = 0) {
172 Eigen::VectorXd overall_cost_function = Eigen::VectorXd::Zero(1,1);
173 gradient.setZero();
174 params.transformation.set_parameter_vector(x);
175
176 if (directions.cols()) {
177 DEBUG ("Reorienting FODs...");
178 std::shared_ptr<Image<default_type> > im1_image_reoriented;
179 std::shared_ptr<Image<default_type> > im2_image_reoriented;
180 im1_image_reoriented = make_shared<Image<default_type>>(Image<default_type>::scratch (params.im1_image));
181 im2_image_reoriented = make_shared<Image<default_type>>(Image<default_type>::scratch (params.im2_image));
182
183 {
184 if (params.mc_settings.size()) {
185 DEBUG ("Tissue contrast specific FOD reorientation");
186 Registration::Transform::reorient (params.im1_image, *im1_image_reoriented, params.transformation.get_transform_half(), directions, false, params.mc_settings);
187 Registration::Transform::reorient (params.im2_image, *im2_image_reoriented, params.transformation.get_transform_half_inverse(), directions, false, params.mc_settings);
188 } else {
189 DEBUG ("FOD reorientation");
190 Registration::Transform::reorient (params.im1_image, *im1_image_reoriented, params.transformation.get_transform_half(), directions);
191 Registration::Transform::reorient (params.im2_image, *im2_image_reoriented, params.transformation.get_transform_half_inverse(), directions);
192 }
193 }
194
195 params.set_im1_iterpolator (*im1_image_reoriented);
196 params.set_im2_iterpolator (*im2_image_reoriented);
197 }
198
199 // estimate (params.transformation, metric, params, overall_cost_function, gradient, x, &overlap_count);
200 if (params.loop_density < 1.0) {
201 DEBUG ("stochastic gradient descent, density: " + str(params.loop_density));
203 gradient.setZero();
204 auto loop = ThreadedLoop (params.midway_image, 0, 3, 2);
205 overlap_count = 0;
206 StochasticThreadKernel <MetricType, ParamType> functor (loop.inner_axes, params.loop_density, metric, params, overall_cost_function, gradient, rng, &overlap_count);
207 {
209 loop.run_outer (functor);
210 }
211 } else {
212 overlap_count = 0;
213 ThreadKernel <MetricType, ParamType> kernel (metric, params, overall_cost_function, gradient, &overlap_count);
214 if (params.robust_estimate_subset) {
215 assert(params.robust_estimate_subset_from.size() == 3);
216 assert(params.robust_estimate_subset_size.size() == 3);
217 Adapter::Subset<decltype(params.processed_mask)> subset (params.processed_mask, params.robust_estimate_subset_from, params.robust_estimate_subset_size);
219 // single threaded as we loop over small VOIs. multi-threading of small VOIs is VERY slow compared to single threading!
220 for (auto i = Loop(0,3) (subset); i; ++i) {
221 kernel(subset);
222 }
223 } else {
225 ThreadedLoop (params.midway_image, 0, 3).run (kernel);
226 }
227 }
228
229 DEBUG ("Metric evaluate iteration: " + str(iteration++) + ", cost: " + str(overall_cost_function.transpose()));
230 DEBUG (" x: " + str(x.transpose()));
231 DEBUG (" gradient: " + str(gradient.transpose()));
232 DEBUG (" norm(gradient): " + str(gradient.norm()));
233 DEBUG (" overlapping voxels: " + str(overlap_count));
234 return overall_cost_function(0);
235 }
236
237 size_t size() {
238 return params.transformation.size();
239 }
240
241 ssize_t overlap() {
242 return overlap_count;
243 }
244
245 default_type init (Eigen::VectorXd& x) {
246 params.transformation.get_parameter_vector(x);
247 return 1.0;
248 }
249
250 void set_directions (const Eigen::MatrixXd& dir) {
251 directions = dir;
252 }
253
254 protected:
255 MetricType metric; // copy of the metric given to the constructor; precompute()/set_weights() act on this copy
256 ParamType params; // copy of the parameters given to the constructor; its transformation is updated on every evaluation
258 size_t iteration; // evaluation counter shown in DEBUG output; starts at 1
259 Eigen::MatrixXd directions; // set via set_directions(); a non-zero column count enables FOD reorientation
261
262 };
263 }
264 }
265}
266
267#endif
static Image scratch(const Header &template_header, const std::string &label="scratch image")
Definition: image.h:195
random number generator
Definition: rng.h:45
#define DEBUG(msg)
Definition: exception.h:75
FORCE_INLINE LoopAlongAxes Loop()
Definition: loop.h:419
#define NOMEMALIGN
Definition: memory.h:22
int log_level
Definition: exception.h:34
thread_local Math::RNG rng
thread-local, but globally accessible RNG to vastly simplify multi-threading
void reorient(FODImageType &input_fod_image, FODImageType &output_fod_image, const transform_type &transform, const Eigen::MatrixXd &directions, bool modulate=false, vector< MultiContrastSetting > multi_contrast_settings=vector< MultiContrastSetting >())
Definition: reorient.h:153
Definition: base.h:24
double default_type
the default type used throughout MRtrix
Definition: types.h:228
std::string str(const T &value, int precision=0)
Definition: mrtrix.h:247
ThreadedLoopRunOuter< decltype(Loop(vector< size_t >()))> ThreadedLoop(const HeaderType &source, const vector< size_t > &outer_axes, const vector< size_t > &inner_axes)
Multi-threaded loop object.
#define MEMALIGN(...)
Definition: types.h:185
std::remove_reference< Functor >::type & functor
Definition: thread.h:215