Developer documentation
Version 3.0.3-105-gd3941f44
quadratic_line_search.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __math_quadratic_line_search_h__
18#define __math_quadratic_line_search_h__
19
20
21#include "progressbar.h"
22
23
24namespace MR
25{
26 namespace Math
27 {
32
71 template <typename ValueType>
74
75 public:
76
// TODO Return error code if converging toward a maxima instead of a minima
// TODO Separate return codes for above & below domain
//! State of the most recent search, as reported by get_status().
enum return_t {
  SUCCESS,        //!< search converged within tolerance (also the initial state)
  EXECUTING,      //!< assigned on entry to a search
  OUTSIDE_BOUNDS, //!< parabola vertex left the bracket while exit-on-bounds was enabled
  NONCONVEX,      //!< middle sample lay above the chord joining the outer samples
  NONCONVERGING   //!< search stopped without converging (e.g. iteration limit reached)
};
80
81 QuadraticLineSearch (const ValueType lower_bound, const ValueType upper_bound) :
82 init_lower (lower_bound),
83 init_mid (0.5 * (lower_bound + upper_bound)),
84 init_upper (upper_bound),
85 value_tolerance (0.001 * (upper_bound - lower_bound)),
86 function_tolerance (0.0),
87 exit_outside_bounds (true),
88 max_iters (50),
89 status (SUCCESS) { }
90
91
92 void set_lower_bound (const ValueType i) { init_lower = i; }
93 void set_init_estimate (const ValueType i) { init_mid = i; }
94 void set_upper_bound (const ValueType i) { init_upper = i; }
95 void set_value_tolerance (const ValueType i) { value_tolerance = i; }
96 void set_function_tolerance (const ValueType i) { function_tolerance = i; }
97 void set_exit_if_outside_bounds (const bool i) { exit_outside_bounds = i; }
98 void set_max_iterations (const size_t i) { max_iters = i; }
99 void set_message (const std::string& i) { message = i; }
100
101 return_t get_status() const { return status; }
102
103
104 template <class Functor>
105 ValueType operator() (Functor& functor) const
106 {
107
108 status = EXECUTING;
109
110 std::unique_ptr<ProgressBar> progress (message.size() ? new ProgressBar (message) : nullptr);
111
112 ValueType l = init_lower, m = init_mid, u = init_upper;
113 ValueType fl = functor (l), fm = functor (m), fu = functor (u);
114 // TODO Need to test if these bounds are producing a NaN CF
115 size_t iters = 0;
116
117 while (iters++ < max_iters) {
118
119 // TODO When testing for nonconvexity, the problem may also arise due to quantisation
120 // in the cost function
121 // Would like to have a fractional threshold on the cost function i.e. if it's really flat,
122 // return successfully
123 // Difficult to do this without knowledge of the cost function
124 if (fm > (fl + ((fu-fl)*(m-l)/(u-l)))) {
125 if ((std::min(m-l, u-m) < value_tolerance) || (abs((fu-fl)/(0.5*(fu+fl))) < function_tolerance)) {
126 status = SUCCESS;
127 return m;
128 }
129 status = NONCONVEX;
130 return NaN;
131 }
132
133 const ValueType sl = (fm-fl) / (m-l);
134 const ValueType su = (fu-fm) / (u-m);
135
136 const ValueType n = (0.5*(l+m)) - ((sl*(u-l)) / (2.0*(su-sl)));
137
138 const ValueType fn = functor (n);
139 if (!std::isfinite(fn))
140 return m;
141
142 if (n < l) {
143 if (exit_outside_bounds) {
144 status = OUTSIDE_BOUNDS;
145 return NaN;
146 }
147 u = m; fu = fm;
148 m = l; fm = fl;
149 l = n; fl = fn;
150 } else if (n < m) {
151 if (fn > fm) {
152 l = n; fl = fn;
153 } else {
154 u = m; fu = fm;
155 m = n; fm = fn;
156 }
157 } else if (n == m) {
158 return n;
159 } else if (n < u) {
160 if (fn > fm) {
161 u = n; fu = fn;
162 } else {
163 l = m; fl = fm;
164 m = n; fm = fn;
165 }
166 } else {
167 if (exit_outside_bounds) {
168 status = OUTSIDE_BOUNDS;
169 return NaN;
170 }
171 l = m; fl = fm;
172 m = u; fm = fu;
173 u = n; fu = fn;
174 }
175
176 if (progress)
177 ++(*progress);
178
179 if ((u-l) < value_tolerance) {
180 status = SUCCESS;
181 return m;
182 }
183
184 }
185
186 status = NONCONVERGING;
187 return NaN;
188
189 }
190
191
192
193 template <class Functor>
194 ValueType verbose (Functor& functor) const
195 {
196
197 status = EXECUTING;
198
199 ValueType l = init_lower, m = init_mid, u = init_upper;
200 ValueType fl = functor (l), fm = functor (m), fu = functor (u);
201 std::cerr << "Initialising quadratic line search\n";
202 std::cerr << " Lower Mid Upper\n";
203 std::cerr << "Pos " << str (l) << " " << str(m) << " " << str(u) << "\n";
204 std::cerr << "Value " << str (fl) << " " << str(fm) << " " << str(fu) << "\n";
205 size_t iters = 0;
206
207 while (iters++ < max_iters) {
208
209 if (fm > (fl + ((fu-fl)*(m-l)/(u-l)))) {
210 if (std::min(m-l, u-m) < value_tolerance) {
211 std::cerr << "Returning due to nonconvexity, through successfully\n";
212 status = SUCCESS;
213 return m;
214 }
215 status = NONCONVEX;
216 std::cerr << "Returning due to nonconvexity, unsuccessfully\n";
217 return NaN;
218 }
219
220 const ValueType sl = (fm-fl) / (m-l);
221 const ValueType su = (fu-fm) / (u-m);
222
223 const ValueType n = (0.5*(l+m)) - ((sl*(u-l)) / (2.0*(su-sl)));
224
225 const ValueType fn = functor (n);
226
227 std::cerr << " New point " << str(n) << ", value " << str(fn) << "\n";
228
229 if (n < l) {
230 if (exit_outside_bounds) {
231 status = OUTSIDE_BOUNDS;
232 return NaN;
233 }
234 u = m; fu = fm;
235 m = l; fm = fl;
236 l = n; fl = fn;
237 } else if (n < m) {
238 if (fn > fm) {
239 l = n; fl = fn;
240 } else {
241 u = m; fu = fm;
242 m = n; fm = fn;
243 }
244 } else if (n == m) {
245 return n;
246 } else if (n < u) {
247 if (fn > fm) {
248 u = n; fu = fn;
249 } else {
250 l = m; fl = fm;
251 m = n; fm = fn;
252 }
253 } else {
254 if (exit_outside_bounds) {
255 status = OUTSIDE_BOUNDS;
256 return NaN;
257 }
258 l = m; fl = fm;
259 m = u; fm = fu;
260 u = n; fu = fn;
261 }
262
263 std::cerr << "\n";
264 std::cerr << "Pos " << str (l) << " " << str(m) << " " << str(u) << "\n";
265 std::cerr << "Value " << str (fl) << " " << str(fm) << " " << str(fu) << "\n";
266
267 if ((u-l) < value_tolerance) {
268 status = SUCCESS;
269 std::cerr << "Returning successfully\n";
270 return m;
271 }
272
273 }
274
275 status = NONCONVERGING;
276 std::cerr << "Returning due to too many iterations\n";
277 return NaN;
278 }
279
280
281
282 private:
283 ValueType init_lower, init_mid, init_upper, value_tolerance, function_tolerance;
284 bool exit_outside_bounds;
285 size_t max_iters;
286 std::string message;
287
288 mutable return_t status;
289 };
290
291
292
293
294 }
295}
296
297#endif
298
Computes the minimum of a 1D function using a quadratic line search.
implements a progress meter to provide feedback to the user
Definition: progressbar.h:58
Definition: base.h:24
constexpr std::enable_if< std::is_arithmetic< X >::value &&std::is_unsigned< X >::value, X >::type abs(X x)
Definition: types.h:297
std::string str(const T &value, int precision=0)
Definition: mrtrix.h:247
constexpr default_type NaN
Definition: types.h:230
#define MEMALIGN(...)
Definition: types.h:185
std::remove_reference< Functor >::type & functor
Definition: thread.h:215