NumCpp  2.4.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
Brent.hpp
Go to the documentation of this file.
1 #pragma once
33 
#include "NumCpp/Core/DtypeInfo.hpp"
#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Roots/Iteration.hpp"

#include <cmath>
#include <functional>
#include <limits>
#include <utility>
41 
42 namespace nc
43 {
44  namespace roots
45  {
46  //================================================================================
47  // Class Description:
50  class Brent : public Iteration
51  {
52  public:
53  //============================================================================
54  // Method Description:
60  Brent(const double epsilon,
61  std::function<double(double)> f) noexcept :
62  Iteration(epsilon),
63  f_(std::move(f))
64  {}
65 
66  //============================================================================
67  // Method Description:
74  Brent(const double epsilon,
75  const uint32 maxNumIterations,
76  std::function<double(double)> f) noexcept :
77  Iteration(epsilon, maxNumIterations),
78  f_(std::move(f))
79  {}
80 
81  //============================================================================
82  // Method Description:
85  ~Brent() override = default;
86 
87  //============================================================================
88  // Method Description:
95  double solve(double a, double b)
96  {
98 
99  double fa = f_(a);
100  double fb = f_(b);
101 
102  checkAndFixAlgorithmCriteria(a, b, fa, fb);
103 
104  double lastB = a; // b_{k-1}
105  double lastFb = fa;
106  double s = DtypeInfo<double>::max();
107  double fs = DtypeInfo<double>::max();
108  double penultimateB = a; // b_{k-2}
109 
110  bool bisection = true;
111  while (std::fabs(fb) > epsilon_ && std::fabs(fs) > epsilon_ && std::fabs(b - a) > epsilon_)
112  {
113  if (useInverseQuadraticInterpolation(fa, fb, lastFb))
114  {
115  s = calculateInverseQuadraticInterpolation(a, b, lastB, fa, fb, lastFb);
116  }
117  else
118  {
119  s = calculateSecant(a, b, fa, fb);
120  }
121 
122  if (useBisection(bisection, b, lastB, penultimateB, s))
123  {
124  s = calculateBisection(a, b);
125  bisection = true;
126  }
127  else
128  {
129  bisection = false;
130  }
131 
132  fs = f_(s);
133  penultimateB = lastB;
134  lastB = b;
135 
136  if (fa * fs < 0)
137  {
138  b = s;
139  }
140  else {
141  a = s;
142  }
143 
144  fa = f_(a);
145  lastFb = fb;
146  fb = f_(b);
147  checkAndFixAlgorithmCriteria(a, b, fa, fb);
148 
150  }
151 
152  return fb < fs ? b : s;
153  }
154 
155  private:
156  //============================================================================
157  const std::function<double(double)> f_;
158 
159  //============================================================================
160  // Method Description:
167  static double calculateBisection(const double a, const double b) noexcept
168  {
169  return 0.5 * (a + b);
170  }
171 
172  //============================================================================
173  // Method Description:
182  static double calculateSecant(const double a, const double b, const double fa, const double fb) noexcept
183  {
184  //No need to check division by 0, in this case the method returns NAN which is taken care by useSecantMethod method
185  return b - fb * (b - a) / (fb - fa);
186  }
187 
188  //============================================================================
189  // Method Description:
200  static double calculateInverseQuadraticInterpolation(const double a, const double b, const double lastB,
201  const double fa, const double fb, const double lastFb) noexcept
202  {
203  return a * fb * lastFb / ((fa - fb) * (fa - lastFb)) +
204  b * fa * lastFb / ((fb - fa) * (fb - lastFb)) +
205  lastB * fa * fb / ((lastFb - fa) * (lastFb - fb));
206  }
207 
208  //============================================================================
209  // Method Description:
217  static bool useInverseQuadraticInterpolation(const double fa, const double fb, const double lastFb) noexcept
218  {
219  return fa != lastFb && fb != lastFb;
220  }
221 
222  //============================================================================
223  // Method Description:
231  static void checkAndFixAlgorithmCriteria(double &a, double &b, double &fa, double &fb) noexcept
232  {
233  //Algorithm works in range [a,b] if criteria f(a)*f(b) < 0 and f(a) > f(b) is fulfilled
234  if (std::fabs(fa) < std::fabs(fb))
235  {
236  std::swap(a, b);
237  std::swap(fa, fb);
238  }
239  }
240 
241  //============================================================================
242  // Method Description:
252  bool useBisection(const bool bisection, const double b, const double lastB,
253  const double penultimateB, const double s) const noexcept
254  {
255  const double DELTA = epsilon_ + std::numeric_limits<double>::min();
256 
257  return (bisection && std::fabs(s - b) >= 0.5 * std::fabs(b - lastB)) || //Bisection was used in last step but |s-b|>=|b-lastB|/2 <- Interpolation step would be to rough, so still use bisection
258  (!bisection && std::fabs(s - b) >= 0.5 * std::fabs(lastB - penultimateB)) || //Interpolation was used in last step but |s-b|>=|lastB-penultimateB|/2 <- Interpolation step would be to small
259  (bisection && std::fabs(b - lastB) < DELTA) || //If last iteration was using bisection and difference between b and lastB is < delta use bisection for next iteration
260  (!bisection && std::fabs(lastB - penultimateB) < DELTA); //If last iteration was using interpolation but difference between lastB ond penultimateB is < delta use biscetion for next iteration
261  }
262  };
263  } // namespace roots
264 } // namespace nc
nc::roots::Iteration::Iteration
Iteration(double epsilon) noexcept
Definition: Iteration.hpp:55
nc::roots::Brent
Definition: Brent.hpp:50
nc::roots::Iteration
ABC for iteration classes to derive from.
Definition: Iteration.hpp:46
nc::roots::Brent::Brent
Brent(const double epsilon, const uint32 maxNumIterations, std::function< double(double)> f) noexcept
Definition: Brent.hpp:74
nc::roots::Brent::~Brent
~Brent() override=default
nc::roots::Iteration::incrementNumberOfIterations
void incrementNumberOfIterations()
Definition: Iteration.hpp:104
nc::uint32
std::uint32_t uint32
Definition: Types.hpp:40
nc::roots::Iteration::epsilon_
const double epsilon_
Definition: Iteration.hpp:114
nc::roots::Brent::Brent
Brent(const double epsilon, std::function< double(double)> f) noexcept
Definition: Brent.hpp:60
nc::roots::Brent::solve
double solve(double a, double b)
Definition: Brent.hpp:95
Iteration.hpp
nc::roots::Iteration::resetNumberOfIterations
void resetNumberOfIterations() noexcept
Definition: Iteration.hpp:93
nc
Definition: Coordinate.hpp:44
nc::swap
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
DtypeInfo.hpp
Types.hpp
nc::random::f
dtype f(dtype inDofN, dtype inDofD)
Definition: f.hpp:56
nc::DtypeInfo::max
static constexpr dtype max() noexcept
Definition: DtypeInfo.hpp:110
nc::min
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:45