Developer documentation
Version 3.0.3-105-gd3941f44
gradient_descent_bb.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __math_gradient_descent_bb_h__
18#define __math_gradient_descent_bb_h__
19
20#include <limits>
21#include <iostream>
22#include <fstream>
23#include <deque>
24#include <limits>
25#include <fstream>
26#include "math/check_gradient.h"
27
28namespace MR
29{
30 namespace Math
31 {
32
34 // @{
35
37 public:
38 template <typename ValueType>
39 inline bool operator() (Eigen::Matrix<ValueType, Eigen::Dynamic, 1>& newx, const Eigen::Matrix<ValueType, Eigen::Dynamic, 1>& x,
40 const Eigen::Matrix<ValueType, Eigen::Dynamic, 1>& g, ValueType step_size) {
41 assert (newx.size() == x.size());
42 assert (g.size() == x.size());
43 newx = x - step_size * g;
44 return !newx.isApprox(x);
45 }
46 };
47
49 template <class Function, class UpdateFunctor=LinearUpdateBB>
52 public:
53 using value_type = typename Function::value_type;
54
// Construct an optimiser for `function`, using `update_functor` to turn
// (point, gradient, step) into the next point.
// The work vectors x1..x3 and g1..g3 are all allocated with func.size() entries.
// Their contents are left uninitialised until init() runs.
// The evaluation and iteration counters start at 0, and the log delimiter is ",".
// NOTE(review): original line 66 is missing from this listing. It presumably
// initialises the `verbose` member from the `verbose` argument; confirm against the source.
55 GradientDescentBB (Function& function, UpdateFunctor update_functor = LinearUpdateBB(), bool verbose = false) :
56 func (function),
57 update_func (update_functor),
// x1..g3 can use func.size() here because `func` is declared before them
// among the protected members, so it is already initialised.
58 x1 (func.size()),
59 x2 (func.size()),
60 x3 (func.size()),
61 g1 (func.size()),
62 g2 (func.size()),
63 g3 (func.size()),
64 nfeval (0),
65 niter (0),
67 delim (",") { }
68
69 value_type value () const { return f; }
70 const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& state () const { return x2; }
71 const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& gradient () const { return g2; }
72 value_type step_size () const { return dt; }
73 value_type gradient_norm () const { return normg; }
74 int function_evaluations () const { return nfeval; }
75
76 void be_verbose (bool v) { verbose = v; }
77 void precondition (const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& weights) {
78 preconditioner_weights = weights;
79 }
80
81 void run (const size_t max_iterations = 1000,
82 const value_type grad_tolerance = 1e-6,
83 std::streambuf* log_stream = nullptr)
84 {
85 std::ostream log_os(log_stream? log_stream : nullptr);
86 if (log_os){
87 log_os << "#iteration" << delim << "feval" << delim << "cost" << delim << "stepsize";
88 for ( ssize_t a = 0 ; a < x1.size() ; a++ )
89 log_os << delim + "x_" + str(a+1) ;
90 for ( ssize_t a = 0 ; a < x1.size() ; a++ )
91 log_os << delim + "g_" + str(a+1) ;
92 log_os << "\n" << std::flush;
93 }
94 init (log_os);
95
96 const value_type gradient_tolerance (grad_tolerance * normg);
97
98 DEBUG ("Gradient descent iteration: init; cost: " + str(f));
99
100 while (niter < max_iterations) {
101 bool retval = iterate (log_os);
102 DEBUG ("Gradient descent iteration: " + str(niter) + "; cost: " + str(f));
103 if (verbose){
104 CONSOLE ("iteration " + str (niter) + ": f = " + str (f) + ", |g| = " + str (normg) + ":");
105 CONSOLE (" x = [ " + str(x2.transpose()) + "]");
106 }
107
108 if (normg < gradient_tolerance) {
109 if (verbose)
110 CONSOLE ("normg (" + str(normg) + ") < gradient tolerance (" + str(gradient_tolerance) + ")");
111 return;
112 }
113
114 if (!retval){
115 if (verbose)
116 CONSOLE ("unchanged parameters");
117 return;
118 }
119 }
120 }
121
122 void init () {
123 std::ostream dummy (nullptr);
124 init (dummy);
125 }
126
// Initialise the optimiser, take the first descent step x1 -> x2, and log both points.
// func.init(x1) receives the start vector, presumably to fill it in (TODO confirm),
// and its return value becomes the initial step length dt.
127 void init (std::ostream& log_os) {
128 dt = func.init (x1);
// NOTE(review): original line 129 is missing from this listing. f and g1 are read
// below, so it presumably evaluates the cost and gradient at x1; confirm.
130 normg = g1.norm();
// Zero gradient at the start point: mirror x1/g1 into x2/g2 and stop.
// dt is left as returned by func.init().
131 if (normg == 0.0) {
132 x2 = x1;
133 g2 = g1;
134 return;
135 }
136 assert(std::isfinite(normg)); assert(!std::isnan(normg));
// Divide by |g1|. With LinearUpdateBB (x - dt*g) the first move then has
// Euclidean length equal to the value func.init() returned.
137 dt /= normg;
138 if (verbose) {
139 CONSOLE ("initialise: f = " + str (f) + ", |g| = " + str (normg) + ", step = " + str(dt) + ":");
140 CONSOLE (" x = [ " + str(x1.transpose()) + "]");
141 }
// Log the starting point with the current niter (0 for a freshly constructed optimiser).
142 if (log_os) {
143 log_os << niter << delim << nfeval << delim << str(f) << delim << str(dt);
144 for (ssize_t i=0; i< x2.size(); ++i){ log_os << delim << str(x1(i)); }
145 for (ssize_t i=0; i< x2.size(); ++i){ log_os << delim << str(g1(i)); }
146 log_os << std::endl;
147 }
148
149 assert (std::isfinite(f)); assert (!std::isnan(f));
150 assert (std::isfinite (normg)); assert (!std::isnan(normg));
151
// First step: x2 = update(x1, g1, dt).
// NOTE(review): original line 153 (the success branch) is missing from this listing.
// It presumably evaluates f/g2 at x2 and refreshes normg/dt; confirm.
152 if (update_func (x2, x1, g1, dt)){
154 } else {
// The update left the point unchanged: zero the step and stop.
// NOTE(review): g2 is not assigned on this path, at least not by any code shown here.
// gradient() and the next iterate() would then read the uninitialised g2 allocated
// by the constructor. Consider copying g1 into g2 here.
155 dt = 0.0;
156 return;
157 }
159 assert (std::isfinite (f)); assert (!std::isnan(f));
160 assert (std::isfinite(normg)); assert (!std::isnan(normg));
// Log the first step. The niter value is unchanged, so both rows share the same iteration number.
161 if (log_os) {
162 log_os << niter << delim << nfeval << delim << str(f) << delim << str(dt);
163 for (ssize_t i=0; i< x2.size(); ++i){ log_os << delim << str(x2(i)); }
164 for (ssize_t i=0; i< x2.size(); ++i){ log_os << delim << str(g2(i)); }
165 log_os << std::endl;
166 }
167 if (verbose) {
168 CONSOLE (" f = " + str (f) + ", |g| = " + str (normg) + ", step = " + str(dt) + ":");
169 CONSOLE (" x = [ " + str(x2.transpose()) + "]");
170 }
171 }
172
173 bool iterate () {
174 std::ostream dummy (nullptr);
175 return iterate (dummy);
176 }
177
// Perform one descent iteration.
// Returns false, without advancing x1/x2, g1/g2 or niter, in two cases:
//  - the gradient norm is exactly zero;
//  - the update functor reports that the trial point x3 would not move.
// The update functor may still have written into the scratch vector x3.
178 bool iterate (std::ostream& log_os) {
179 assert (std::isfinite (normg));
180 if ((normg == 0.0) or !update_func (x3, x2, g2, dt))
181 return false;
182
// NOTE(review): original line 183 is missing from this listing. g3 and f are consumed
// below, so it presumably evaluates the cost and gradient at x3; confirm.
// Rotate buffers without copying:
//  - x1 <- old x2;
//  - x2 <- new point (from x3);
//  - x3 <- old x1, reused as scratch.
// The g vectors rotate the same way.
184 x2.swap(x3);
185 x1.swap(x3);
186 g2.swap(g3);
187 g1.swap(g3);
188 ++niter;
// Log this iteration using the incremented niter.
189 if (log_os) {
190 log_os << niter << delim << nfeval << delim << str(f) << delim << str(dt);
191 for (ssize_t i=0; i< x2.size(); ++i){ log_os << delim << str(x2(i)); }
192 for (ssize_t i=0; i< x2.size(); ++i){ log_os << delim << str(g2(i)); }
193 log_os << std::endl;
194 }
// NOTE(review): original line 195 is missing from this listing; its contents are unknown.
196 return true;
197 }
198
199 protected:
// The objective being minimised. It is held by reference, so it must outlive this optimiser.
200 Function& func;
// Update rule, called as update_func(newx, x, g, step). The default LinearUpdateBB
// returns false when the step leaves x (approximately) unchanged.
201 UpdateFunctor update_func;
// Rolling state:
//  - x1/g1: previous point and gradient;
//  - x2/g2: current point and gradient (returned by state()/gradient());
//  - x3/g3: scratch for the trial point.
// preconditioner_weights holds optional element-wise gradient scaling; empty means none.
202 Eigen::Matrix<value_type, Eigen::Dynamic, 1> x1, x2, x3, g1, g2, g3, preconditioner_weights;
// NOTE(review): original line 203 is missing. f, dt and normg are used throughout
// and are presumably declared there; confirm.
// Number of evaluate_func() calls.
204 size_t nfeval;
// Number of successful iterate() calls.
205 size_t niter;
// NOTE(review): original line 206 is missing; it presumably declares `verbose`.
// Column separator used in the log output.
207 std::string delim;
208
209 value_type evaluate_func (const Eigen::Matrix<value_type, Eigen::Dynamic, 1>& newx, Eigen::Matrix<value_type, Eigen::Dynamic, 1>& newg, bool verbose = false) {
210 ++nfeval;
211 value_type cost = func (newx, newg);
212 if (!std::isfinite (cost))
213 throw Exception ("cost function is NaN or Inf!");
214 if (verbose){
215 CONSOLE (" << eval " + str(nfeval) + ", f = " + str (cost) + " >>");
216 CONSOLE (" << newx = [ " + str(newx.transpose()) + "]");
217 CONSOLE (" << newg = [ " + str(newg.transpose()) + "]");
218 }
219 return cost;
220 }
221
// NOTE(review): the signature line of this member (original line 222) is missing from
// this listing, so the function name here is inferred; confirm.
// Apply the optional preconditioner to the current gradient g2 (in place), refresh
// normg, then set the next step size to |s| / |y|, where s = x2 - x1 and y = g2 - g1.
// When s.y > 0, |s| / |y| is the geometric mean of the two classic Barzilai-Borwein
// step sizes (s.s / s.y and s.y / y.y).
223 if (preconditioner_weights.size()) {
224 g2.array() *= preconditioner_weights.array();
225 }
226 normg = g2.norm();
// NOTE(review): this assert is compiled out under NDEBUG. If g2 == g1, the division
// below then yields inf (or NaN when x2 == x1 too) instead of failing loudly.
227 assert((g2-g1).norm()>0.0);
228 dt = (x2-x1).norm()/(g2-g1).norm();
229 }
230 };
232 }
233}
234
235#endif
Computes the minimum of a function using a Barzilai-Borwein gradient descent approach.
Eigen::Matrix< value_type, Eigen::Dynamic, 1 > x2
value_type evaluate_func(const Eigen::Matrix< value_type, Eigen::Dynamic, 1 > &newx, Eigen::Matrix< value_type, Eigen::Dynamic, 1 > &newg, bool verbose=false)
Eigen::Matrix< value_type, Eigen::Dynamic, 1 > preconditioner_weights
Eigen::Matrix< value_type, Eigen::Dynamic, 1 > g1
Eigen::Matrix< value_type, Eigen::Dynamic, 1 > g3
Eigen::Matrix< value_type, Eigen::Dynamic, 1 > g2
Eigen::Matrix< value_type, Eigen::Dynamic, 1 > x3
Eigen::Matrix< value_type, Eigen::Dynamic, 1 > x1
bool operator()(Eigen::Matrix< ValueType, Eigen::Dynamic, 1 > &newx, const Eigen::Matrix< ValueType, Eigen::Dynamic, 1 > &x, const Eigen::Matrix< ValueType, Eigen::Dynamic, 1 > &g, ValueType step_size)
#define DEBUG(msg)
Definition: exception.h:75
#define CONSOLE(msg)
Definition: exception.h:71
constexpr double e
Definition: math.h:39
#define NOMEMALIGN
Definition: memory.h:22
MR::default_type value_type
Definition: typedefs.h:33
Definition: base.h:24
std::string str(const T &value, int precision=0)
Definition: mrtrix.h:247
#define MEMALIGN(...)
Definition: types.h:185