Developer documentation
Version 3.0.3-105-gd3941f44
search.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __registration_transform_search_h__
18#define __registration_transform_search_h__
19
20#include <iostream>
21#include <Eigen/Geometry>
22#include <Eigen/Eigen>
23
24#include "debug.h"
25#include "image.h"
26#include "progressbar.h"
27#include "types.h"
28
29#include "math/math.h"
30#include "math/median.h"
31#include "math/rng.h"
33#include "math/average_space.h"
34#include "filter/resize.h"
35#include "filter/reslice.h"
36#include "adapter/reslice.h"
37#include "interp/linear.h"
38#include "interp/cubic.h"
39#include "interp/nearest.h"
42// #include "registration/metric/local_cross_correlation.h"
46#include "file/config.h"
47
48namespace MR
49{
50 namespace Registration
51 {
52 namespace RotationSearch
53 {
54
56 using MatType = Eigen::Matrix<default_type, 3, 3>;
57 using VecType = Eigen::Matrix<default_type, 3, 1>;
58 using QuatType = Eigen::Quaternion<default_type>;
59
60 template <class MetricType = Registration::Metric::MeanSquaredNoGradient>
62 public:
// NOTE(review): the opening line of this constructor's signature and some of
// its parameter lines are not visible in this extract; `mask1`, `mask2` and
// `init` are used below but presumably declared on those missing lines.
64 Image<default_type>& image1,
65 Image<default_type>& image2,
68 MetricType& metric_,
69 Registration::Transform::Base& linear_transform,
71 im1 (image1),
72 im2 (image2),
// member initialised from the constructor parameter of the same name
73 mask1 (mask1),
74 mask2 (mask2),
// `metric` is declared as a value member (MetricType metric;), so the
// caller's metric object is copied here rather than referenced.
75 metric (metric_),
76 input_trafo (linear_transform),
77 init_options (init),
// centre/offset read input_trafo, so they rely on input_trafo being declared
// before them in the class: members are initialised in declaration order, not
// in the order written here. Presumably true in the member lines not visible
// in this extract — TODO confirm.
78 centre (input_trafo.get_centre()),
79 offset (input_trafo.get_translation()),
// search settings are copied out of init.init_rotation.search
80 global_search_iterations (init.init_rotation.search.global.iterations),
81 rot_angles (init.init_rotation.search.angles),
82 local_search_directions (init.init_rotation.search.directions),
83 image_scale_factor (init.init_rotation.search.scale),
84 global_search (init.init_rotation.search.run_global),
85 translation_extent (init.init_rotation.search.translation_extent),
// cursors used by gen_local_quaternion() to walk angles x directions
86 idx_angle (0),
87 idx_dir (0) {
// Seed local_trafo (the transform modified during the search) with the centre,
// translation and linear part of the input transform. The chosen result is
// written back to input_trafo by run().
88 local_trafo.set_centre_without_transform_update (centre);
89 local_trafo.set_translation (offset);
90 Eigen::Matrix<default_type, 3, 3> lin = input_trafo.get_transform().linear();
// presumably sets the linear part while keeping the translation set above — confirm
91 local_trafo.set_matrix_const_translation(lin);
92 INFO ("before search:");
93 INFO (local_trafo.info());
94 };
95
96
100 Header,
103 // use Interp::LinearInterpProcessingType::ValueAndDerivative for metric that calculates gradients
112
113 void write_images (const std::string& im1_path, const std::string& im2_path) {
114 Image<default_type> image1_midway;
115 Image<default_type> image2_midway;
116
117 Header image1_midway_header (midway_image_header);
118 image1_midway_header.datatype() = DataType::Float64;
119 image1_midway_header.ndim() = im1.ndim();
120 for (size_t dim = 3; dim < im1.ndim(); ++dim){
121 image1_midway_header.spacing(dim) = im1.spacing(dim);
122 image1_midway_header.size(dim) = im1.size(dim);
123 }
124 image1_midway = Image<default_type>::create (im1_path, image1_midway_header);
125 Header image2_midway_header (midway_image_header);
126 image2_midway_header.datatype() = DataType::Float64;
127 image2_midway_header.ndim() = im2.ndim();
128 for (size_t dim = 3; dim < im2.ndim(); ++dim){
129 image2_midway_header.spacing(dim) = im2.spacing(dim);
130 image2_midway_header.size(dim) = im2.size(dim);
131 }
132 image2_midway = Image<default_type>::create (im2_path, image2_midway_header);
133
134 Filter::reslice<Interp::Cubic> (im1, image1_midway, local_trafo.get_transform_half(), Adapter::AutoOverSample, 0.0);
135 Filter::reslice<Interp::Cubic> (im2, image2_midway, local_trafo.get_transform_half_inverse(), Adapter::AutoOverSample, 0.0);
136 }
137
// Evaluates the registration metric for a sequence of candidate transforms and
// stores the best one in input_trafo.
//
// Candidates: if global_search, random rotations from gen_random_quaternion();
// otherwise every combination of rot_angles with local_search_directions
// golden-spiral axes over the full sphere (gen_local_quaternion()). Each
// candidate is T = Tc2 * To * R0 * Tc2^-1, with R0 the candidate rotation plus,
// when translation_extent != 0, a random translation scaled by the midway
// image extent.
// Selection: lowest cost per overlapping voxel among candidates whose overlap
// count exceeds the mean overlap over all iterations.
//
// debug: when true, prints one line per iteration to std::cout (iteration,
// cost, overlap count, and the three rows of T).
138 void run ( bool debug = false ) {
139 std::string what = global_search? "global" : "local";
140 size_t iterations = global_search? global_search_iterations : (rot_angles.size() * local_search_directions);
141 ProgressBar progress ("performing " + what + " search for best rotation", iterations);
// NOTE(review): Eigen's resize() does not initialise coefficients, so every
// entry of overlap_it / cost_it must be written by the loop below.
142 overlap_it.resize (iterations);
143 cost_it.resize (iterations);
144 trafo_it.reserve (iterations);
145
// local search needs the fixed set of rotation axes (rows of xyz)
146 if (!global_search) {
147 gen_uniform_rotation_axes (local_search_directions, 180.0); // full sphere
148 az_el_to_cartesian();
149 }
150
151 size_t iteration (0);
152 ssize_t cnt (0);
153 Eigen::Matrix<default_type, Eigen::Dynamic, 1> gradient (local_trafo.size());
154 Eigen::VectorXd cost = Eigen::VectorXd::Zero(1,1);
156 const Eigen::Translation<default_type, 3> Tc2 (centre - 0.5 * offset), To (offset);
158 R0.translation().fill(0);
159
160 Eigen::Vector3d extent(0,0,0);
161 if (translation_extent != 0) {
// called only for its side effect of recomputing midway_image_header, whose
// spacing/size give the translation extent below; `parameters` is unused.
162 ParamType parameters = get_parameters ();
163 extent << midway_image_header.spacing(0) * translation_extent * (midway_image_header.size(0) - 0.5),
164 midway_image_header.spacing(1) * translation_extent * (midway_image_header.size(1) - 0.5),
165 midway_image_header.spacing(2) * translation_extent * (midway_image_header.size(2) - 0.5);
166 }
167
// NOTE(review): `++iteration` is evaluated before the first test, so the body
// never runs with iteration == 0. As a result:
//  - the `iteration > 0` guard is always true;
//  - the zero-overlap exception below is unreachable;
//  - overlap_it(0) and cost_it(0) are never written;
//  - trafo_it ends up with iterations-1 entries, so trafo_it[i] is offset by
//    one from cost_it(i), and reads past the end if i == iterations-1 is chosen.
// `while (iteration < iterations)` with `++iteration;` at the end of the body
// looks intended — TODO confirm against how T is initialised before the loop.
168 while ( ++iteration < iterations ) {
169 ++progress;
170 if (iteration > 0) {
171 if (global_search)
172 gen_random_quaternion ();
173 else
174 gen_local_quaternion ();
175
176 R0.linear() = quat.normalized().toRotationMatrix();
177 if (translation_extent != 0) {
178 gen_random_quaternion (); // overwrites quat
// random direction (rotated extent) times a random scalar from rndn()
// (presumably normally distributed; rndn is declared outside this view)
179 R0.translation() = rndn () * (quat * extent);
180 DEBUG("translation: " + str(R0.translation().transpose()));
181 }
182
183 T = Tc2 * To * R0 * Tc2.inverse();
184 local_trafo.set_transform<transform_type>(T);
185 }
186
// rebuild midway grid / metric parameters for the current local_trafo and
// accumulate cost and overlap count (cnt) over parameters.midway_image
187 ParamType parameters = get_parameters ();
188 // parameters.make_diagnostics_image ("/tmp/debugme"+str(iteration)+".mif", true); // REMOVEME
189 cost.fill(0);
190 cnt = 0;
191 Metric::ThreadKernel<MetricType, ParamType> kernel (metric, parameters, cost, gradient, &cnt);
192 ThreadedLoop (parameters.midway_image, 0, 3).run (kernel);
193 DEBUG ("rotation search: iteration " + str(iteration) + " cost: " + str(cost) + " cnt: " + str(cnt));
194 if (debug)
195 std::cout << str(iteration) + " " + str(cost) + " " + str(cnt) << " " << T.matrix().row(0) << " " << T.matrix().row(1) << " " << T.matrix().row(2) << std::endl;
196 // write_images ( "im1_" + str(iteration) + ".mif", "im2_" + str(iteration) + ".mif");
197 if (cnt == 0) {
198 if (iteration == 0)
199 throw Exception ("zero voxel overlap at initialisation. input matrix wrong?");
200 WARN ("rotation search: overlap count is zero");
201 }
202 overlap_it[iteration] = cnt;
// NOTE(review): when cnt == 0 this divides by zero (inf, or NaN if cost is 0);
// such entries are only discarded by the overlap filter if mean_overlap > 0.
203 cost_it[iteration] = cost(0) / static_cast<default_type>(cnt);
204 trafo_it.push_back (T);
205 }
206 // if (debug) {
207 // save_matrix(cost_it, "/tmp/cost_before.txt");
208 // save_matrix(overlap_it, "/tmp/overlap.txt");
209 // }
210 // best trafo := lowest cost per voxel with at least mean overlap
211 {
212 auto max_ = Eigen::MatrixXd::Constant(cost_it.rows(), 1, std::numeric_limits<default_type>::max());
213 // default_type max_overlap = overlap_it.maxCoeff();
214 default_type mean_overlap = static_cast<default_type>(overlap_it.sum()) / static_cast<default_type>(iterations);
215 // reject solutions with less than mean overlap by setting cost to max
// NOTE(review): the comment above says "at least mean overlap" but the test is
// strict (>), so candidates exactly at the mean are rejected. If all overlaps
// are equal, every cost becomes max() and the choice degenerates to whatever
// index minCoeff reports. `>=` would match the stated intent.
216 cost_it = (overlap_it.array() > mean_overlap).select(cost_it, max_);
217 std::ptrdiff_t i;
218 min_cost = cost_it.minCoeff(&i);
219 T = trafo_it[i];
220 best_trafo = T;
221 }
222 // if (debug) {
223 // save_matrix(cost_it, "/tmp/cost_after.txt");
224 // Eigen::VectorXd t(2);
225 // t(0) = cost_it(0);
226 // t(1) = min_cost;
227 // save_matrix(t, "/tmp/cost_mass_chosen.txt");
228 // save_matrix(centre, "/tmp/centre.txt");
229 // parameters.transformation.set_transform (best_trafo);
230 // write_images ( "/tmp/im1_best.mif", "/tmp/im2_best.mif");
231 // }
// publish the selected candidate to the caller's transform
232 input_trafo.set_transform<transform_type> (best_trafo);
233 };
234
235 private:
236 FORCE_INLINE ParamType get_parameters () {
237 // create resized midway image
238 // vector<Eigen::Transform<default_type, 3, Eigen::Projective>> init_transforms;
239 // {
240 // Eigen::Transform<default_type, 3, Eigen::Projective> init_trafo_1 = ;
241 // Eigen::Transform<default_type, 3, Eigen::Projective> init_trafo_2 = local_trafo.get_transform_half();
242 // init_transforms.push_back (init_trafo_1);
243 // init_transforms.push_back (init_trafo_2);
244 // }
245 // auto padding = Eigen::Matrix<default_type, 4, 1>(1.0, 1.0, 1.0, 1.0);
246 // int subsample = 1;
247 // vector<Header> headers;
248 // headers.push_back (Header (im1));
249 // headers.push_back (Header (im2));
250 midway_image_header = compute_minimum_average_header (im1, im2, local_trafo.get_transform_half_inverse(), local_trafo.get_transform_half());
251
252 Filter::Resize midway_resize_filter (midway_image_header);
253 midway_resize_filter.set_scale_factor (image_scale_factor);
254 midway_resized_header = Header (midway_resize_filter);
255
256 ParamType parameters (local_trafo, im1, im2, midway_resized_header, mask1, mask2);
257 parameters.loop_density = 1.0;
258 return parameters;
259 }
260
261 // gen_random_quaternion generates random element of SO(3)
262 FORCE_INLINE void gen_random_quaternion () {
263 // Eigen 3.3.0: quat = Eigen::Quaternion<default_type,Eigen::autoalign>::UnitRandom ();
264 // http://planning.cs.uiuc.edu/node198.html
265 const default_type u1 = rnd ();
266 const default_type u2 = rnd () * 2.0 * Math::pi;
267 const default_type u3 = rnd () * 2.0 * Math::pi;
268 assert (u1 < 1.0 && u1 >= 0.0);
269 assert (u2 < 2.0 * Math::pi && u2 >= 0.0);
270 assert (u3 < 2.0 * Math::pi && u3 >= 0.0);
271 const default_type a = std::sqrt(1.0 - u1);
272 const default_type b = std::sqrt(u1);
273 quat = Eigen::Quaternion<default_type> (a * std::sin(u2), a * std::cos(u2), b * std::sin(u3), b * std::cos(u3));
274 }
275
276 // gen_uniform_rotation_axes generates roughly uniformly distributed points on sphere
277 // starting on z-axis up to -z-axis (max_cone_angle_deg=180). points are stored as matrix
278 // of azimuth and elevation
279 // ENH: less brute-force approach (for instance: fixed set of electrostatic repulsion directions, rotate all to gap centre)
280 FORCE_INLINE void gen_uniform_rotation_axes ( const size_t& n_dir, const default_type& max_cone_angle_deg ) {
281 assert (n_dir > 1);
282 assert (max_cone_angle_deg > 0.0);
283 assert (max_cone_angle_deg <= 180.0);
284
285 const default_type golden_ratio ((1.0 + std::sqrt (5.0)) / 2.0);
286 const default_type golden_angle (2.0 * Math::pi * (1.0 - 1.0 / golden_ratio));
287
288 az_el.resize (n_dir,2);
289 Eigen::Matrix<default_type, Eigen::Dynamic, 1> idx (n_dir);
290 for (size_t i = 0; i < n_dir; ++i)
291 idx(i) = i;
292 az_el.col(0) = idx * golden_angle;
293
294 // el(i) = acos (1-(1-cosd(max_cone_angle_deg))*i/(n_dir-1) )
295 default_type a = (1.0 - std::cos(Math::pi / 180.0 * default_type (max_cone_angle_deg))) / (default_type (n_dir - 1));
296 az_el.col(1).array() = - a * idx.array() + 1.0;
297 for (size_t i = 0; i < n_dir; ++i)
298 az_el(i, 1) = std::acos (az_el(i, 1));
299 }
300
301 // convert spherical coordinates (az_el) to cartesian coordinates (xyz)
302 FORCE_INLINE void az_el_to_cartesian () {
303 xyz.resize (az_el.rows(), 3);
304 Eigen::VectorXd el_sin = az_el.col(1).array().sin();
305 xyz.col(0).array() = el_sin.array() * az_el.col(0).array().cos();
306 xyz.col(1).array() = el_sin.array() * az_el.col(0).array().sin();
307 xyz.col(2).array() = az_el.col(1).array().cos();
308 }
309
310 FORCE_INLINE void gen_local_quaternion () {
311 if (idx_dir == local_search_directions) {
312 idx_dir = 0;
313 ++idx_angle;
314 assert (idx_angle < rot_angles.size());
315 }
316 quat = Eigen::Quaternion<default_type> ( Eigen::AngleAxis<default_type> (rot_angles[idx_angle], xyz.row(idx_dir)) );
317 ++idx_dir;
318 }
319
320 Image<default_type> im1, im2, mask1, mask2, midway_image, midway_resized;
321 Header midway_resized_header;
322 MetricType metric;
325 const Eigen::Vector3d centre;
326 const Eigen::Vector3d offset;
329 Eigen::Quaternion<default_type> quat;
330 transform_type best_trafo;
331 Header midway_image_header;
332 default_type min_cost;
333 vector<default_type> vec_cost;
334 vector<size_t> vec_overlap;
335 size_t global_search_iterations;
336 vector<default_type> rot_angles;
337 size_t local_search_directions;
338 default_type image_scale_factor;
339 bool global_search;
340 double translation_extent;
341 size_t idx_angle, idx_dir;
343 Eigen::Matrix<default_type, Eigen::Dynamic, 2> az_el;
344 Eigen::Matrix<default_type, Eigen::Dynamic, 3> xyz;
345 Eigen::Matrix<default_type, Eigen::Dynamic, 1> overlap_it, cost_it;
346 vector<transform_type> trafo_it;
347 };
348 } // namespace RotationSearch
349 }
350}
351
352#endif
static constexpr uint8_t Float64
Definition: datatype.h:148
static Image create(const std::string &image_name, const Header &template_header, bool add_to_command_history=true)
Definition: image.h:192
size_t ndim() const
Definition: image.h:65
default_type spacing(size_t axis) const
Definition: image.h:67
ssize_t size(size_t axis) const
Definition: image.h:66
This class provides access to the voxel intensities of an Image, using nearest-neighbour interpolation.
Definition: nearest.h:68
implements a progress meter to provide feedback to the user
Definition: progressbar.h:58
#define WARN(msg)
Definition: exception.h:73
#define INFO(msg)
Definition: exception.h:74
#define DEBUG(msg)
Definition: exception.h:75
constexpr double pi
Definition: math.h:40
const vector< uint32_t > AutoOverSample
void init(int argc, const char *const *argv)
initialise MRtrix and parse command-line arguments
@ Value
Definition: linear.h:87
Eigen::Quaternion< default_type > QuatType
Definition: search.h:58
transform_type TrafoType
Definition: search.h:55
Eigen::Matrix< default_type, 3, 1 > VecType
Definition: search.h:57
Eigen::Matrix< default_type, 3, 3 > MatType
Definition: search.h:56
Definition: base.h:24
double default_type
the default type used throughout MRtrix
Definition: types.h:228
std::string str(const T &value, int precision=0)
Definition: mrtrix.h:247
Eigen::Transform< default_type, 3, Eigen::AffineCompact > transform_type
the type for the affine transform of an image:
Definition: types.h:234
ThreadedLoopRunOuter< decltype(Loop(vector< size_t >()))> ThreadedLoop(const HeaderType &source, const vector< size_t > &outer_axes, const vector< size_t > &inner_axes)
Multi-threaded loop object.
Header compute_minimum_average_header(const vector< Header > &input_headers, const vector< Eigen::Transform< default_type, 3, Eigen::Projective > > &transform_header_with, int voxel_subsampling=1, Eigen::Matrix< default_type, 4, 1 > padding=Eigen::Matrix< default_type, 4, 1 >(1.0, 1.0, 1.0, 1.0))
#define FORCE_INLINE
Definition: types.h:156