Developer documentation
Version 3.0.3-105-gd3941f44
iFOD2.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __dwi_tractography_algorithms_iFOD2_h__
18#define __dwi_tractography_algorithms_iFOD2_h__
19
20#include <algorithm>
21
22#include "types.h"
23#include "math/SH.h"
30
31
32namespace MR
33{
34 namespace DWI
35 {
36 namespace Tractography
37 {
38 namespace Algorithms
39 {
40
// Option group for iFOD2-specific settings; declared extern here, so it is
// defined in another translation unit (presumably the command-line options).
extern const App::OptionGroup iFOD2Options;
// NOTE(review): the listing this was extracted from skips original source line 42
// here; this page's symbol index lists `void load_iFOD2_options(Tractography::Properties &)`,
// which was presumably declared on that line - confirm against the real header.

// NOTE(review): this using-directive sits in a header (see the include guard), so it
// injects all of Tracking into MR::DWI::Tractography::Algorithms for every includer.
using namespace MR::DWI::Tractography::Tracking;
45
46 class iFOD2 : public MethodBase { MEMALIGN(iFOD2)
47 public:
48
49 class Shared : public SharedBase { MEMALIGN(Shared)
50 public:
// Shared, per-run configuration for iFOD2.
//
// Validates that the input image holds SH coefficients and that 4th-order Runge-Kutta
// integration was not requested (both throw Exception otherwise). Records the
// iFOD2-specific entries in `properties`, and derives the per-vertex geometry used for
// output downsampling and for set_num_points().
Shared (const std::string& diff_path, DWI::Tractography::Properties& property_set) :
  SharedBase (diff_path, property_set),
  // SH order implied by the number of volumes (axis 3) of the input image
  lmax (Math::SH::LforN (source.size(3))),
  num_samples (Defaults::ifod2_nsamples),
  // NOTE(review): the documentation listing this was extracted from skips original
  // source line 55 here. max_trials has no visible initialiser, yet it is recorded via
  // properties.set() below and bounds the trial loop in next(); its initialiser
  // presumably sat on the missing line - confirm against the real header.
  sin_max_angle_ho (NaN),
  mean_samples (0.0),
  mean_truncations (0.0),
  max_max_truncation (0.0),
  num_proc (0)
{
  try {
    Math::SH::check (source);
  } catch (Exception& e) {
    // Show the underlying reason, then replace it with an iFOD2-specific error
    e.display();
    throw Exception ("Algorithm iFOD2 expects as input a spherical harmonic (SH) image");
  }

  if (rk4)
    throw Exception ("4th-order Runge-Kutta integration not valid for iFOD2 algorithm");

  // NOTE(review): the listing also skips original source line 72 here, immediately
  // before max_angle_ho is first read; presumably that line configured the step size
  // and angle limits used below - confirm.
  sin_max_angle_ho = std::sin (max_angle_ho);
  // FOD amplitude cutoff; the default is scaled by a multiplier when ACT is in use
  set_cutoff (Defaults::cutoff_fod * (is_act() ? Defaults::cutoff_act_multiplier : 1.0));

  properties["method"] = "iFOD2";
  properties.set (lmax, "lmax");
  properties.set (num_samples, "samples_per_step");
  properties.set (max_trials, "max_trials");
  // Exponent applied to the summed log amplitudes in path_prob(). It is computed before
  // the decrement below, so it equals 1 / (samples per step including the first point).
  fod_power = 1.0/num_samples;
  properties.set (fod_power, "fod_power");
  // NOTE(review): `precomputed` is tested right after properties.set(), which suggests
  // set() can overwrite the variable from an existing property as well as record it;
  // confirm in Properties.
  bool precomputed = true;
  properties.set (precomputed, "sh_precomputed");
  if (precomputed)
    precomputer.init (lmax);

  // num_samples is number of samples excluding first point
  --num_samples;
  INFO ("iFOD2 generating " + str(num_samples) + " vertices per " + str (step_size) + " mm step");

  // iFOD2 by default downsamples after track propagation back to the desired 'step size'
  // i.e. the sub-step detail is removed from the output
  size_t downsample_factor = num_samples;
  properties.set (downsample_factor, "downsample_factor");
  downsampler.set_ratio (downsample_factor);

  // For iFOD2, "step_size" represents the length of the chord represented
  // using "num_samples" vertices rather than just one; the following two
  // variables need to be calculated accordingly:
  // - The arc angle subtended by two sequential vertices on a circle of minimal radius
  //   (prior to downsampling)
  // NOTE(review): this formula treats step_size as a chord of a circle of radius
  // min_radius, whereas get_path() builds arcs with R = step_size / theta, i.e. it
  // treats step_size as the arc length; the two agree only approximately.
  const float angle_minradius_preds = 2.0 * std::asin (step_size / (2.0 * min_radius)) / float(num_samples);
  // - The maximal possible distance between vertices after downsampling
  //   (numerically ~step_size, since downsample_factor == num_samples)
  const float max_step_postds = downsample_factor * step_size / float(num_samples);
  set_num_points (angle_minradius_preds, max_step_postds);
}
107
108 ~Shared ()
109 {
110 mean_samples /= double(num_proc);
111 mean_truncations /= double(num_proc);
112 INFO ("mean number of samples per step = " + str (mean_samples));
113 if (mean_truncations) {
114 INFO ("mean number of steps between rejection sampling truncations = " + str (1.0/mean_truncations));
115 INFO ("maximum truncation error = " + str (max_max_truncation));
116 } else {
117 INFO ("no rejection sampling truncations occurred");
118 }
119 }
120
121 void update_stats (double mean_samples_per_run, double mean_truncations_per_run, double max_truncation) const
122 {
123 mean_samples += mean_samples_per_run;
124 mean_truncations += mean_truncations_per_run;
125 if (max_truncation > max_max_truncation)
126 max_max_truncation = max_truncation;
127 ++num_proc;
128 }
129
130 float internal_step_size() const override { return step_size / float(num_samples); }
131
// lmax: SH order implied by the input image's volume count (set in the constructor).
// num_samples: vertices generated per step, excluding the step's first point (the
//   constructor decrements the configured samples-per-step value).
// max_trials: upper bound on candidate arcs drawn per step in iFOD2::next().
size_t lmax, num_samples, max_trials;
// sin_max_angle_ho: std::sin (max_angle_ho), cached for rand_dir().
// fod_power: exponent applied to the summed log amplitudes of an arc in path_prob().
float sin_max_angle_ho, fod_power;
// NOTE(review): the listing this was extracted from skips original source line 134
// here. `precomputer` is used (precomputer.init() in the constructor and
// S.precomputer.value() in iFOD2::FOD()) but has no visible declaration; presumably
// it was declared on that line - confirm against the real header.

private:
// Statistics accumulated by update_stats() and reported by ~Shared(); mutable so that
// update_stats() can be const.
mutable double mean_samples, mean_truncations, max_max_truncation;
mutable int num_proc;
139 };
140
141
142
143
144
145
146
147
148
// Construct a tracker instance from the shared state and calibrate it.
//
// The arc buffers (positions/tangents and their calib_ counterparts) are sized to
// S.num_samples vertices, and the sampling statistics start at zero. sample_idx
// starts at S.num_samples so the next() logic treats the current arc as exhausted.
// calibrate_ratio and calibrate_list are not initialised here; presumably
// calibrate() (the befriended template) fills them - confirm in calibrator.h.
iFOD2 (const Shared& shared) :
  MethodBase (shared),
  S (shared),
  source (S.source),
  mean_sample_num (0),
  num_sample_runs (0),
  num_truncations (0),
  max_truncation (0.0),
  positions (S.num_samples),
  calib_positions (S.num_samples),
  tangents (S.num_samples),
  calib_tangents (S.num_samples),
  sample_idx (S.num_samples)
{
  calibrate (*this);
}
165
// Copy constructor.
//
// Binds to the same shared state and reuses the calibration result of `that`
// (calibrate_ratio, calibrate_list) rather than calling calibrate() again. The
// statistics counters start from zero and the arc buffers are freshly sized rather
// than copied. half_log_prob0 / half_log_prob0_seed are not copied; init() assigns them.
iFOD2 (const iFOD2& that) :
  MethodBase (that.S),
  S (that.S),
  source (S.source),
  calibrate_ratio (that.calibrate_ratio),
  mean_sample_num (0),
  num_sample_runs (0),
  num_truncations (0),
  max_truncation (0.0),
  calibrate_list (that.calibrate_list),
  positions (S.num_samples),
  calib_positions (S.num_samples),
  tangents (S.num_samples),
  calib_tangents (S.num_samples),
  sample_idx (S.num_samples)
{
}
183
184
185
186 ~iFOD2 ()
187 {
188 if (num_sample_runs)
189 S.update_stats (calibrate_list.size() + float(mean_sample_num)/float(num_sample_runs),
190 float(num_truncations) / float(num_sample_runs),
191 max_truncation);
192 }
193
194
195
196
197 bool init() override
198 {
199 if (!get_data (source))
200 return false;
201
202 if (!S.init_dir.allFinite()) {
203
204 const Eigen::Vector3f init_dir (dir);
205
206 for (size_t n = 0; n < S.max_seed_attempts; n++) {
207 dir = init_dir.allFinite() ? rand_dir (init_dir) : random_direction();
208 half_log_prob0 = FOD (dir);
209 if (std::isfinite (half_log_prob0) && (half_log_prob0 > S.init_threshold))
210 goto end_init;
211 }
212
213 } else {
214
215 dir = S.init_dir;
216 half_log_prob0 = FOD (dir);
217 if (std::isfinite (half_log_prob0) && (half_log_prob0 > S.init_threshold))
218 goto end_init;
219
220 }
221
222 return false;
223
224end_init:
225 half_log_prob0_seed = half_log_prob0 = 0.5 * std::log (half_log_prob0);
226 sample_idx = S.num_samples; // Force arc calculation on first iteration
227 return true;
228 }
229
230
231
232 term_t next () override
233 {
234
235 if (++sample_idx < S.num_samples) {
236 pos = positions[sample_idx];
237 dir = tangents [sample_idx];
238 return CONTINUE;
239 }
240
241 Eigen::Vector3f next_pos, next_dir;
242
243 float max_val = 0.0;
244 for (size_t i = 0; i < calibrate_list.size(); ++i) {
245 get_path (calib_positions, calib_tangents, rotate_direction (dir, calibrate_list[i]));
246 float val = path_prob (calib_positions, calib_tangents);
247 if (std::isnan (val))
248 return EXIT_IMAGE;
249 else if (val > max_val)
250 max_val = val;
251 }
252
253 if (max_val <= 0.0)
254 return CALIBRATOR;
255
256 max_val *= calibrate_ratio;
257
258 num_sample_runs++;
259
260 for (size_t n = 0; n < S.max_trials; n++) {
261 float val = rand_path_prob ();
262
263 if (val > max_val) {
264 DEBUG ("max_val exceeded!!! (val = " + str(val) + ", max_val = " + str (max_val) + ")");
265 ++num_truncations;
266 if (val/max_val > max_truncation)
267 max_truncation = val/max_val;
268 }
269
270 if (uniform(rng) < val/max_val) {
271 mean_sample_num += n;
272 half_log_prob0 = last_half_log_probN;
273 pos = positions[0];
274 dir = tangents [0];
275 sample_idx = 0;
276 return CONTINUE;
277 }
278 }
279
280 return MODEL;
281 }
282
283
284 float get_metric (const Eigen::Vector3f& position, const Eigen::Vector3f& direction) override
285 {
286 if (!get_data (source, position))
287 return 0.0;
288 return FOD (direction);
289 }
290
291
292 // Restore proper probability from the FOD at the track seed point
293 void reverse_track() override
294 {
295 half_log_prob0 = half_log_prob0_seed;
296 sample_idx = S.num_samples;
297 MethodBase::reverse_track();
298 }
299
300
// Shorten the track so that tracking can resume from an earlier vertex.
//
// Removes the vertices of the arc in progress at length_to_revert_from, plus
// (revert_step - 1) further whole arcs of S.num_samples vertices each. If that would
// remove everything beyond the seed, the track is cleared and pos/dir are set to NaN.
// Otherwise:
//  - dir is re-estimated from the retained vertices;
//  - pos moves to the new last vertex and half_log_prob0 is re-evaluated there;
//  - the next call to next() is forced to compute a new arc;
//  - with ACT, sgm_depth is reduced by the number of removed points.
void truncate_track (GeneratedTrack& tck, const size_t length_to_revert_from, const size_t revert_step) override
{
  // OK, if we know length_to_revert_from, we can reconstruct what sample_idx was at that point
  size_t sample_idx_at_full_length = (length_to_revert_from - tck.get_seed_index()) % S.num_samples;
  // Unfortunately can't distinguish between sample_idx = 0 and sample_idx = S.num_samples
  // However the former would result in zero truncation with revert_step = 1...
  if (!sample_idx_at_full_length)
    sample_idx_at_full_length = S.num_samples;
  const size_t points_to_remove = sample_idx_at_full_length + ((revert_step - 1) * S.num_samples);
  if (tck.get_seed_index() + points_to_remove >= tck.size()) {
    tck.clear();
    pos = { NaN, NaN, NaN };
    dir = { NaN, NaN, NaN };
    return;
  }
  // NOTE(review): new_size is unsigned. The guard above compares against tck.size(),
  // so this relies on length_to_revert_from > points_to_remove - confirm callers ensure it.
  const size_t new_size = length_to_revert_from - points_to_remove;
  // Re-estimate the direction at the new end point:
  //  - from the first two vertices when the track is (or becomes) that short;
  //  - otherwise as a central difference across the new last vertex tck[new_size - 1]
  //    (tck[new_size] is still present because the resize happens afterwards);
  //  - dir is left unchanged when nothing is removed.
  if (tck.size() == 2 || new_size == 1)
    dir = (tck[1] - tck[0]).normalized();
  else if (new_size != tck.size())
    dir = (tck[new_size] - tck[new_size - 2]).normalized();
  tck.resize (new_size);

  // Need to get the path probability contribution from the FOD at this point
  pos = tck.back();
  // NOTE(review): the return value of get_data() is ignored here; if it fails, FOD()
  // below evaluates whatever coefficients were last loaded - confirm that cannot happen.
  get_data (source);
  half_log_prob0 = 0.5 * std::log (FOD (dir));

  // Make sure that arc is re-calculated when next() is called
  sample_idx = S.num_samples;

  // Need to update sgm_depth appropriately, remembering that it is tracked by exec
  if (S.is_act())
    act().sgm_depth = (act().sgm_depth > points_to_remove) ? act().sgm_depth - points_to_remove : 0;
}
335
336
337
338 private:
// Shared algorithm state (referenced, not owned).
const Shared& S;
// Interpolator over the SH input image; presumably get_data() loads the interpolated
// coefficients into `values`, which FOD() then evaluates - confirm in MethodBase.
Interpolator<Image<float>>::type source;
// calibrate_ratio: scale applied in next() to the calibration maximum to form the
//   rejection-sampling envelope.
// half_log_prob0: half log FOD amplitude at the start of the current arc (initial
//   term of path_prob()).
// last_half_log_probN: half log amplitude at the final vertex of the last arc
//   evaluated by path_prob(); adopted as half_log_prob0 when next() accepts an arc.
// half_log_prob0_seed: half_log_prob0 at the seed, restored by reverse_track().
float calibrate_ratio, half_log_prob0, last_half_log_probN, half_log_prob0_seed;
// Per-instance rejection-sampling statistics, handed to S by the destructor.
size_t mean_sample_num, num_sample_runs, num_truncations;
float max_truncation;
// Directions probed by next() (after rotate_direction() relative to the current dir)
// to estimate the sampling envelope.
vector<Eigen::Vector3f> calibrate_list;

// Store list of points in the currently-calculated arc
// (calib_* hold the arcs evaluated during envelope estimation in next())
vector<Eigen::Vector3f> positions, calib_positions;
vector<Eigen::Vector3f> tangents, calib_tangents;

// Generate an arc only when required, and on the majority of next() calls, simply return the next point
// in the arc - more dense structural image sampling
size_t sample_idx;
353
354
355
356 FORCE_INLINE float FOD (const Eigen::Vector3f& direction) const
357 {
358 return (S.precomputer ?
359 S.precomputer.value (values, direction) :
360 Math::SH::value (values, direction, S.lmax)
361 );
362 }
363
364 FORCE_INLINE float FOD (const Eigen::Vector3f& position, const Eigen::Vector3f& direction)
365 {
366 if (!get_data (source, position))
367 return NaN;
368 return FOD (direction);
369 }
370
371
372
373
374 FORCE_INLINE float rand_path_prob ()
375 {
376 get_path (positions, tangents, rand_dir (dir));
377 return path_prob (positions, tangents);
378 }
379
380
381
382 float path_prob (vector<Eigen::Vector3f>& positions, vector<Eigen::Vector3f>& tangents)
383 {
384
385 // Early exit for ACT when path is not sensible
386 if (S.is_act()) {
387 if (!act().fetch_tissue_data (positions[S.num_samples - 1]))
388 return (NaN);
389 if (act().tissues().get_csf() >= 0.5)
390 return 0.0;
391 }
392
393 float log_prob = half_log_prob0;
394 for (size_t i = 0; i < S.num_samples; ++i) {
395
396 float fod_amp = FOD (positions[i], tangents[i]);
397 if (std::isnan (fod_amp))
398 return NaN;
399 if (fod_amp < S.threshold)
400 return 0.0;
401 fod_amp = std::log (fod_amp);
402 if (i < S.num_samples-1) {
403 log_prob += fod_amp;
404 } else {
405 last_half_log_probN = 0.5*fod_amp;
406 log_prob += last_half_log_probN;
407 }
408 }
409
410 return std::exp (S.fod_power * log_prob);
411 }
412
413
414 protected:
415 void get_path (vector<Eigen::Vector3f>& positions, vector<Eigen::Vector3f>& tangents, const Eigen::Vector3f& end_dir) const
416 {
417 float cos_theta = end_dir.dot (dir);
418 cos_theta = std::min (cos_theta, float(1.0));
419 float theta = std::acos (cos_theta);
420
421 if (theta) {
422
423 Eigen::Vector3f curv = end_dir - cos_theta * dir;
424 curv.normalize();
425 float R = S.step_size / theta;
426
427 for (size_t i = 0; i < S.num_samples-1; ++i) {
428 float a = (theta * (i+1)) / S.num_samples;
429 float cos_a = std::cos (a);
430 float sin_a = std::sin (a);
431 positions[i] = pos + R * (sin_a * dir + (float(1.0) - cos_a) * curv);
432 tangents[i] = cos_a * dir + sin_a * curv;
433 }
434 positions[S.num_samples-1] = pos + R * (std::sin (theta) * dir + (float(1.0)-cos_theta) * curv);
435 tangents[S.num_samples-1] = end_dir;
436
437 } else { // straight on:
438
439 for (size_t i = 0; i < S.num_samples; ++i) {
440 float f = (i+1) * (S.step_size / S.num_samples);
441 positions[i] = pos + f * dir;
442 tangents[i] = dir;
443 }
444
445 }
446 }
447
448
449
450 FORCE_INLINE Eigen::Vector3f rand_dir (const Eigen::Vector3f& d) { return (random_direction (d, S.max_angle_ho, S.sin_max_angle_ho)); }
451
452
453
454 private:
455 class Calibrate
456 { MEMALIGN(Calibrate)
457 public:
458 Calibrate (iFOD2& method) :
459 P (method),
460 fod (P.values),
461 vox (P.S.vox()),
462 positions (P.S.num_samples),
463 tangents (P.S.num_samples) {
464 Math::SH::delta (fod, Eigen::Vector3f (0.0, 0.0, 1.0), P.S.lmax);
465 init_log_prob = 0.5 * std::log (Math::SH::value (P.values, Eigen::Vector3f (0.0, 0.0, 1.0), P.S.lmax));
466 }
467
468 float operator() (float el)
469 {
470 P.pos = { 0.0f, 0.0f, 0.0f };
471 P.get_path (positions, tangents, Eigen::Vector3f (std::sin (el), 0.0, std::cos(el)));
472
473 float log_prob = init_log_prob;
474 for (size_t i = 0; i < P.S.num_samples; ++i) {
475 float prob = Math::SH::value (P.values, tangents[i], P.S.lmax) * (1.0 - (positions[i][0] / vox));
476 if (prob <= 0.0)
477 return 0.0;
478 prob = std::log (prob);
479 if (i < P.S.num_samples-1)
480 log_prob += prob;
481 else
482 log_prob += 0.5*prob;
483 }
484
485 return std::exp (P.S.fod_power * log_prob);
486 }
487
488 private:
489 iFOD2& P;
490 Eigen::VectorXf& fod;
491 const float vox;
492 float init_log_prob;
493 vector<Eigen::Vector3f> positions, tangents;
494 };
495
496 friend void calibrate<iFOD2> (iFOD2& method);
497
498 };
499
500
501
502 }
503 }
504 }
505}
506
507#endif
508
float el
Definition: calibrator.h:61
friend void calibrate(iFOD2 &method)
FORCE_INLINE Eigen::Vector3f rand_dir(const Eigen::Vector3f &d)
Definition: iFOD2.h:450
void get_path(vector< Eigen::Vector3f > &positions, vector< Eigen::Vector3f > &tangents, const Eigen::Vector3f &end_dir) const
Definition: iFOD2.h:415
std::uniform_real_distribution< float > uniform
Definition: method.h:95
Eigen::Vector3f rotate_direction(const Eigen::Vector3f &reference, const Eigen::Vector3f &direction)
Precomputed Associated Legendre Polynomials - used to speed up SH calculation.
Definition: SH.h:400
#define INFO(msg)
Definition: exception.h:74
#define DEBUG(msg)
Definition: exception.h:75
constexpr double e
Definition: math.h:39
void check(const ImageType &H)
convenience function to check if an input image can contain SH coefficients
Definition: SH.h:745
size_t LforN(int N)
returns the largest lmax given N parameters
Definition: SH.h:70
VectorType1 & delta(VectorType1 &delta_vec, const VectorType2 &unit_dir, int lmax)
Definition: SH.h:279
VectorType::Scalar value(const VectorType &coefs, typename VectorType::Scalar cos_elevation, typename VectorType::Scalar cos_azimuth, typename VectorType::Scalar sin_azimuth, int lmax)
Definition: SH.h:233
void load_iFOD2_options(Tractography::Properties &)
const App::OptionGroup iFOD2Options
thread_local Math::RNG rng
thread-local, but globally accessible RNG to vastly simplify multi-threading
Definition: base.h:24
std::string str(const T &value, int precision=0)
Definition: mrtrix.h:247
constexpr default_type NaN
Definition: types.h:230
Eigen::MatrixXd S
Eigen::MatrixXd R
#define MEMALIGN(...)
Definition: types.h:185
#define FORCE_INLINE
Definition: types.h:156