Developer documentation
Version 3.0.3-105-gd3941f44
gradient.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __image_filter_gradient_h__
18#define __image_filter_gradient_h__
19
20#include "memory.h"
21#include "image.h"
22#include "transform.h"
23#include "algo/loop.h"
24#include "algo/threaded_copy.h"
25#include "adapter/gradient1D.h"
26#include "filter/base.h"
27#include "filter/smooth.h"
28
29namespace MR
30{
31 namespace Filter
32 {
46 class Gradient : public Base { MEMALIGN(Gradient)
47 public:
48 template <class HeaderType>
49 Gradient (const HeaderType& in, const bool magnitude = false) :
50 Base (in),
51 smoother (in),
52 wrt_scanner (true),
54 stdev (1, 0)
55 {
56 if (in.ndim() == 4) {
57 if (!magnitude) {
58 axes_.resize (5);
59 axes_[3].size = 3;
60 axes_[4].size = in.size(3);
61 axes_[0].stride = 2;
62 axes_[1].stride = 3;
63 axes_[2].stride = 4;
64 axes_[3].stride = 1;
65 axes_[4].stride = 5;
66 }
67 } else if (in.ndim() == 3) {
68 if (!magnitude) {
69 axes_.resize (4);
70 axes_[3].size = 3;
71 axes_[0].stride = 2;
72 axes_[1].stride = 3;
73 axes_[2].stride = 4;
74 axes_[3].stride = 1;
75 }
76 } else {
77 throw Exception("input image must be 3D or 4D");
78 }
79 datatype() = DataType::Float32;
80 DEBUG ("creating gradient filter");
81 }
82
83 void compute_wrt_scanner (bool do_wrt_scanner) {
84 wrt_scanner = do_wrt_scanner;
85 }
86
87 void set_stdev (const vector<default_type>& stdevs) {
88 stdev = stdevs;
89 }
90
91
92 template <class InputImageType, class OutputImageType>
93 void operator() (InputImageType& in, OutputImageType& out)
94 {
95 if (magnitude) {
96 Gradient full_gradient (in, false);
97 full_gradient.set_stdev (stdev);
98 full_gradient.compute_wrt_scanner (wrt_scanner);
99 full_gradient.set_message (message);
100 auto temp = Image<float>::scratch (full_gradient, "full 3D gradient image");
101 full_gradient (in, temp);
102 for (auto l = Loop (out)(out, temp); l; ++l) {
103 if (out.ndim() == 4) {
104 ssize_t tmp = out.index(3);
105 temp.index(4) = tmp;
106 }
107 float grad_sq = 0.0;
108 for (temp.index(3) = 0; temp.index(3) != 3; ++temp.index(3))
109 grad_sq += Math::pow2<float> (temp.value());
110 out.value() = std::sqrt (grad_sq);
111 }
112 return;
113 }
114 smoother.set_stdev (stdev);
115 auto smoothed = Image<float>::scratch (smoother);
116 if (message.size())
117 smoother.set_message ("applying smoothing prior to calculating gradient");
118 threaded_copy (in, smoothed);
119 smoother (smoothed);
120
121 const size_t num_volumes = (in.ndim() == 3) ? 1 : in.size(3);
122
123 std::unique_ptr<ProgressBar> progress (message.size() ? new ProgressBar (message, 3 * num_volumes) : nullptr);
124
125 for (size_t vol = 0; vol < num_volumes; ++vol) {
126 if (in.ndim() == 4) {
127 smoothed.index(3) = vol;
128 out.index(4) = vol;
129 }
130
131 Adapter::Gradient1D<decltype(smoothed)> gradient1D (smoothed, 0, wrt_scanner);
132 out.index(3) = 0;
133 threaded_copy (gradient1D, out, 0, 3, 2);
134 if (progress) ++(*progress);
135 out.index(3) = 1;
136 gradient1D.set_axis (1);
137 threaded_copy (gradient1D, out, 0, 3, 2);
138 if (progress) ++(*progress);
139 out.index(3) = 2;
140 gradient1D.set_axis (2);
141 threaded_copy (gradient1D, out, 0, 3, 2);
142 if (progress) ++(*progress);
143
144 if (wrt_scanner) {
145 Transform transform (in);
146 for (auto l = Loop(0,3) (out); l; ++l)
147 out.row(3) = transform.image2scanner.linear() * Eigen::Vector3d (out.row(3));
148 }
149 }
150 }
151
152 protected:
155 const bool magnitude;
157 };
159 }
160}
161
162#endif
static constexpr uint8_t Float32
Definition: datatype.h:147
std::string message
Definition: base.h:66
Filter::Smooth smoother
Definition: gradient.h:153
const bool magnitude
Definition: gradient.h:155
vector< default_type > stdev
Definition: gradient.h:156
vector< Axis > axes_
Definition: header.h:361
static Image scratch(const Header &template_header, const std::string &label="scratch image")
Definition: image.h:195
implements a progress meter to provide feedback to the user
Definition: progressbar.h:58
const transform_type image2scanner
Definition: transform.h:43
#define DEBUG(msg)
Definition: exception.h:75
FORCE_INLINE LoopAlongAxes Loop()
Definition: loop.h:419
Definition: base.h:24
void threaded_copy(InputImageType &source, OutputImageType &destination, const vector< size_t > &axes, size_t num_axes_in_thread=1)
Definition: threaded_copy.h:43