Developer documentation
Version 3.0.3-105-gd3941f44
gradient_descent.h
/* Copyright (c) 2008-2022 the MRtrix3 contributors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Covered Software is provided under this License on an "as is"
 * basis, without warranty of any kind, either expressed, implied, or
 * statutory, including, without limitation, warranties that the
 * Covered Software is free of defects, merchantable, fit for a
 * particular purpose or non-infringing.
 * See the Mozilla Public License v. 2.0 for more details.
 *
 * For more details, see http://www.mrtrix.org/.
 */

#ifndef __math_gradient_descent_h__
#define __math_gradient_descent_h__

#include <limits>

namespace MR
{
  namespace Math
  {

    // @{


    namespace {

      class LinearUpdate { NOMEMALIGN
        public:
          template <typename ValueType>
            inline bool operator() (Eigen::Matrix<ValueType, Eigen::Dynamic, 1>& newx, const Eigen::Matrix<ValueType, Eigen::Dynamic, 1>& x,
                const Eigen::Matrix<ValueType, Eigen::Dynamic, 1>& g, ValueType step_size) {
              bool changed = false;
              for (ssize_t n = 0; n < x.size(); ++n) {
                newx[n] = x[n] - step_size * g[n];
                if (newx[n] != x[n])
                  changed = true;
              }
              return changed;
            }
      };

    }
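    // GradientDescent calls its UpdateFunctor as update_func (newx, x, g, step_size)
    // and expects a bool indicating whether any parameter actually changed
    // (a return value of false terminates the optimisation). LinearUpdate is the
    // default and simply steps newx = x - step_size * g; a functor with the same
    // call signature (for instance one projecting the update onto box constraints)
    // can be supplied as the second template argument instead.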

    //! Computes the minimum of a function using a gradient descent approach.
    template <class Function, class UpdateFunctor=LinearUpdate>
      class GradientDescent { MEMALIGN(GradientDescent<Function,UpdateFunctor>)
        public:
          using value_type = typename Function::value_type;

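          // step_size_upfactor / step_size_downfactor bound how much the step size
          // may grow after an accepted step and shrink within a rejected one (see
          // iterate() below); verbose enables per-iteration console output.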
          GradientDescent (Function& function, UpdateFunctor update_functor = LinearUpdate(), value_type step_size_upfactor = 3.0, value_type step_size_downfactor = 0.1, bool verbose = false) :
            func (function),
            update_func (update_functor),
            step_up (step_size_upfactor),
            step_down (step_size_downfactor),
            verbose (verbose),
            delim (","),
            niter (0),
            x (func.size()),
            x2 (func.size()),
            g (func.size()),
            g2 (func.size()) { }


          value_type value () const throw () { return f; }
          const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& state () const throw () { return x; }
          const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& gradient () const throw () { return g; }
          value_type step_size () const { return dt; }
          value_type gradient_norm () const throw () { return normg; }
          int function_evaluations () const throw () { return nfeval; }

          void be_verbose(bool v) { verbose = v; }
          void precondition (const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& weights) {
            preconditioner_weights = weights;
          }

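          // Run the optimisation: iterates until max_iterations is reached, until
          // the gradient norm falls below grad_tolerance times its initial value,
          // or until an iteration leaves the parameters unchanged. If log_stream
          // is supplied, a comma-separated row per iteration (cost, step size,
          // parameters and gradient) is written to it.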
          void run (const size_t max_iterations = 1000,
                    const value_type grad_tolerance = 1e-6,
                    std::streambuf* log_stream = nullptr) {
            std::ostream log_os(log_stream? log_stream : nullptr);
            if (log_os){
              log_os << "#iteration" << delim << "feval" << delim << "cost" << delim << "stepsize";
              for ( ssize_t a = 0 ; a < x.size() ; a++ )
                log_os << delim + "x_" + str(a+1) ;
              for ( ssize_t a = 0 ; a < x.size() ; a++ )
                log_os << delim + "g_" + str(a+1) ;
              log_os << "\n" << std::flush;
            }
            init (log_os);

            const value_type gradient_tolerance (grad_tolerance * normg);

            DEBUG ("Gradient descent iteration: init; cost: " + str(f));

            while (niter < max_iterations) {
              bool retval = iterate (log_os);
              DEBUG ("Gradient descent iteration: " + str(niter) + "; cost: " + str(f));
              if (verbose) {
                CONSOLE ("iteration " + str (niter) + ": f = " + str (f) + ", |g| = " + str (normg) + ":");
                CONSOLE (" x = [ " + str(x.transpose()) + "]");
              }

              if (normg < gradient_tolerance) {
                if (verbose)
                  CONSOLE ("normg (" + str(normg) + ") < gradient tolerance (" + str(gradient_tolerance) + ")");
                return;
              }

              if (!retval){
                if (verbose)
                  CONSOLE ("unchanged parameters");
                return;
              }
            }
          }
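          // Initialise the optimisation: ask the cost function for the starting
          // parameter vector and initial step size via func.init(), evaluate the
          // cost and gradient at that point, and normalise the step size by the
          // gradient norm.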

          void init () {
            std::ostream dummy (nullptr);
            init (dummy);
          }

          void init (std::ostream& log_os) {
            dt = func.init (x);
            nfeval = 0;
            f = evaluate_func (x, g, verbose);
            normg = g.norm();
            assert(std::isfinite(normg));
            assert(!std::isnan(normg));
            dt /= normg;
            if (verbose) {
              CONSOLE ("initialise: f = " + str (f) + ", |g| = " + str (normg) + ":");
              CONSOLE (" x = [ " + str(x.transpose()) + "]");
            }
            if (normg == 0.0)
              return;
            assert (std::isfinite (f));
            assert (!std::isnan(f));
            assert (std::isfinite (normg));
            assert (!std::isnan(normg));
            if (log_os) {
              log_os << niter << delim << nfeval << delim << str(f) << delim << str(dt);
              for (ssize_t i=0; i< x.size(); ++i){ log_os << delim << str(x(i)); }
              for (ssize_t i=0; i< x.size(); ++i){ log_os << delim << str(g(i)); }
              log_os << std::endl;
            }
          }

          bool iterate () {
            std::ostream dummy (nullptr);
            return iterate (dummy);
          }

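          // Perform one iteration: propose x2 from the current step size dt, then
          // rescale dt by the minimiser of the quadratic through f, f2 and the
          // first-order decrease predicted by the gradient (normg * step_length),
          // clamped to [step_down, step_up]. A proposal that lowers the cost is
          // accepted (returns true); otherwise dt is reduced and the proposal is
          // retried, returning false once the parameters can no longer change.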
          bool iterate (std::ostream& log_os) {
            // assert (normg != 0.0);
            assert (std::isfinite (normg));


            while (normg != 0.0) {
              if (!update_func (x2, x, g, dt))
                return false;

              value_type f2 = evaluate_func (x2, g2, verbose);

              // quadratic minimum:
              value_type step_length = step_unscaled*dt;
              value_type denom = 2.0 * (normg*step_length + f2 - f);
              value_type quadratic_minimum = denom > 0.0 ? normg * step_length / denom : step_up;

              if (quadratic_minimum < step_down) quadratic_minimum = step_down;
              if (quadratic_minimum > step_up) quadratic_minimum = step_up;

              if (f2 < f) {
                ++niter;
                dt *= quadratic_minimum;
                f = f2;
                x.swap (x2);
                g.swap (g2);
                if (log_os) {
                  log_os << niter << delim << nfeval << delim << str(f) << delim << str(dt);
                  for (ssize_t i=0; i< x.size(); ++i){ log_os << delim << str(x(i)); }
                  for (ssize_t i=0; i< x.size(); ++i){ log_os << delim << str(g(i)); }
                  log_os << std::endl;
                }
                compute_normg_and_step_unscaled ();
                return true;
              }

              if (quadratic_minimum >= 1.0)
                quadratic_minimum = 0.5;
              dt *= quadratic_minimum;

              if (dt <= 0.0)
                return false;
            }
            return false;
          }

        protected:
          Function& func;
          UpdateFunctor update_func;
          const value_type step_up, step_down;
          bool verbose;
          std::string delim;
          size_t niter;
          Eigen::Matrix<value_type, Eigen::Dynamic, 1> x, x2, g, g2, preconditioner_weights;
          value_type f, dt, normg, step_unscaled;
          size_t nfeval;

          value_type evaluate_func (const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& newx,
                                    Eigen::Matrix<value_type, Eigen::Dynamic, 1>& newg,
                                    bool verbose = false) {
            nfeval++;
            value_type cost = func (newx, newg);
            if (!std::isfinite (cost))
              throw Exception ("cost function is NaN or Inf!");
            if (verbose)
              CONSOLE (" << eval " + str(nfeval) + ", f = " + str (cost) + " >>");
            return cost;
          }

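          // Recompute the gradient norm and the unscaled step length for the
          // current gradient. If preconditioner weights were supplied via
          // precondition(), the gradient is scaled elementwise by the weights and
          // normg becomes the weighted quantity sum(w_n * g_n^2) / |g|.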
          void compute_normg_and_step_unscaled () {
            normg = step_unscaled = g.norm();
            assert(std::isfinite(normg));
            if (normg > 0.0){
              if (preconditioner_weights.size()) {
                value_type g_projected = 0.0;
                for (ssize_t n = 0; n < g.size(); ++n) {
                  g_projected += preconditioner_weights[n] * Math::pow2(g[n]);
                }
                g.array() *= preconditioner_weights.array();
                normg = g_projected / normg;
                assert(std::isfinite(normg));
              }
            }
          }

      };
  }
}

#endif
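To illustrate how this class is driven from client code, the following is a minimal usage sketch, not part of the MRtrix3 sources. The cost-function interface it implements (a nested value_type, size(), init() filling the starting estimate and returning the initial step size, and operator() returning the cost and filling the gradient) is inferred from the calls GradientDescent makes above; the class name QuadraticProblem, the example() function and the include paths are illustrative assumptions.

// Hypothetical usage sketch: minimise f(x) = 0.5 * ||x - b||^2 with GradientDescent.
#include <Eigen/Dense>

#include "math/gradient_descent.h"   // assumed include path

class QuadraticProblem {
  public:
    using value_type = double;

    QuadraticProblem (const Eigen::VectorXd& target) : b (target) { }

    // number of parameters to optimise
    size_t size () const { return b.size(); }

    // set the starting estimate and return the initial step size
    value_type init (Eigen::VectorXd& x) const {
      x = Eigen::VectorXd::Zero (b.size());
      return 1.0;
    }

    // return the cost at x and store its gradient in g
    value_type operator() (const Eigen::VectorXd& x, Eigen::VectorXd& g) const {
      g = x - b;
      return 0.5 * g.squaredNorm();
    }

  private:
    Eigen::VectorXd b;
};

void example ()
{
  Eigen::VectorXd b (3);
  b << 1.0, -2.0, 0.5;

  QuadraticProblem problem (b);
  MR::Math::GradientDescent<QuadraticProblem> optimiser (problem);
  optimiser.run();                                 // defaults: 1000 iterations, 1e-6 gradient tolerance
  Eigen::VectorXd solution = optimiser.state();    // should converge towards b
  double final_cost = optimiser.value();
}

run() can also be passed a std::streambuf (for example from a std::ofstream) to capture the per-iteration log written by run() above, and precondition() can be called beforehand to supply elementwise gradient weights.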