Developer documentation
Version 3.0.3-105-gd3941f44
check_gradient.h
Go to the documentation of this file.
1/* Copyright (c) 2008-2022 the MRtrix3 contributors.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
6 *
7 * Covered Software is provided under this License on an "as is"
8 * basis, without warranty of any kind, either expressed, implied, or
9 * statutory, including, without limitation, warranties that the
10 * Covered Software is free of defects, merchantable, fit for a
11 * particular purpose or non-infringing.
12 * See the Mozilla Public License v. 2.0 for more details.
13 *
14 * For more details, see http://www.mrtrix.org/.
15 */
16
17#ifndef __math_check_gradient_h__
18#define __math_check_gradient_h__
19
20#include "debug.h"
21#include "datatype.h"
22
23namespace MR {
24 namespace Math {
25
26 template <class Function>
27 Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, Eigen::Dynamic> check_function_gradient (
28 Function& function,
29 Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1> x,
30 typename Function::value_type increment,
31 bool show_hessian = false,
32 Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1> conditioner = Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1>())
33 {
34 using value_type = typename Function::value_type;
35 const size_t N = function.size();
36 Eigen::Matrix<value_type, Eigen::Dynamic, 1> g (N);
37
38 CONSOLE ("checking gradient for cost function over " + str(N) +
39 " parameters of type " + DataType::from<value_type>().specifier());
40 value_type step_size = function.init (g);
41 CONSOLE ("cost function suggests initial step size = " + str(step_size));
42 CONSOLE ("cost function suggests initial position at [ " + str(g.transpose()) + "]");
43
44 CONSOLE ("checking gradient at position [ " + str(x.transpose()) + "]:");
45 Eigen::Matrix<value_type, Eigen::Dynamic, 1> g0 (N);
46 value_type f0 = function (x, g0);
47 CONSOLE (" cost function = " + str(f0));
48 CONSOLE (" gradient from cost function = [ " + str(g0.transpose()) + "]");
49
50 Eigen::Matrix<value_type, Eigen::Dynamic, 1> g_fd (N);
51 Eigen::Matrix<value_type, Eigen::Dynamic, Eigen::Dynamic> hessian;
52 if (show_hessian) {
53 hessian.resize(N, N);
54 if (conditioner.size()){
55 assert (conditioner.size() == (ssize_t) N && "conditioner size must equal number of parameters");
56 for (size_t n = 0; n < N; ++n)
57 conditioner[n] = std::sqrt(conditioner[n]);
58 }
59 }
60
61 for (size_t n = 0; n < N; ++n) {
62 value_type old_x = x[n];
63 value_type inc = increment;
64 if (conditioner.size()){
65 assert (conditioner.size() == (ssize_t) N && "conditioner size must equal number of parameters");
66 inc *= conditioner[n];
67 }
68
69 x[n] += inc;
70 value_type f1 = function (x, g);
71 if (show_hessian) {
72 if (conditioner.size())
73 g.cwiseProduct(conditioner);
74 hessian.col(n) = g;
75 }
76
77 x[n] = old_x - inc;
78 value_type f2 = function (x, g);
79 g_fd[n] = (f1-f2) / (2.0*inc);
80 x[n] = old_x;
81 if (show_hessian) {
82 if (conditioner.size())
83 g.cwiseProduct(conditioner);
84 hessian.col(n) -= g;
85 }
86
87 }
88
89 CONSOLE ("gradient by central finite difference = [ " + str(g_fd.transpose()) + "]");
90 CONSOLE ("normalised dot product = " + str(g_fd.dot(g0) / g_fd.squaredNorm()));
91
92 if (show_hessian) {
93 hessian /= 4.0*increment;
94 for (size_t j = 0; j < N; ++j) {
95 size_t i;
96 for (i = 0; i < j; ++i)
97 hessian(i,j) = hessian(j,i);
98 for (; i < N; ++i)
99 hessian(i,j) += hessian(j,i);
100 }
101 // CONSOLE ("hessian = [ " + str(hessian) + "]");
102 MAT(hessian);
103 CONSOLE("\033[00;34mcondition number: " + str(condition_number (hessian))+"\033[0m");
104 }
105 return hessian;
106 }
107 }
108}
109
110#endif
#define CONSOLE(msg)
Definition: exception.h:71
#define MAT(variable)
Prints a matrix name and in the following line its formatted value.
Definition: debug.h:45
MR::default_type value_type
Definition: typedefs.h:33
default_type condition_number(const M &data)
Eigen::Matrix< typename Function::value_type, Eigen::Dynamic, Eigen::Dynamic > check_function_gradient(Function &function, Eigen::Matrix< typename Function::value_type, Eigen::Dynamic, 1 > x, typename Function::value_type increment, bool show_hessian=false, Eigen::Matrix< typename Function::value_type, Eigen::Dynamic, 1 > conditioner=Eigen::Matrix< typename Function::value_type, Eigen::Dynamic, 1 >())
Definition: base.h:24
std::string str(const T &value, int precision=0)
Definition: mrtrix.h:247