NumCpp  2.9.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
Brent.hpp
Go to the documentation of this file.
1 #pragma once
34 
#include <cmath>
#include <functional>
#include <limits>
#include <utility>

#include "NumCpp/Core/DtypeInfo.hpp"
#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Roots/Iteration.hpp"
42 
43 namespace nc
44 {
45  namespace roots
46  {
47  //================================================================================
48  // Class Description:
51  class Brent : public Iteration
52  {
53  public:
54  //============================================================================
55  // Method Description:
61  Brent(const double epsilon, std::function<double(double)> f) noexcept :
62  Iteration(epsilon),
63  f_(std::move(f))
64  {
65  }
66 
67  //============================================================================
68  // Method Description:
75  Brent(const double epsilon, const uint32 maxNumIterations, std::function<double(double)> f) noexcept :
76  Iteration(epsilon, maxNumIterations),
77  f_(std::move(f))
78  {
79  }
80 
81  //============================================================================
82  // Method Description:
85  ~Brent() override = default;
86 
87  //============================================================================
88  // Method Description:
95  double solve(double a, double b)
96  {
98 
99  double fa = f_(a);
100  double fb = f_(b);
101 
102  checkAndFixAlgorithmCriteria(a, b, fa, fb);
103 
104  double lastB = a; // b_{k-1}
105  double lastFb = fa;
106  double s = DtypeInfo<double>::max();
107  double fs = DtypeInfo<double>::max();
108  double penultimateB = a; // b_{k-2}
109 
110  bool bisection = true;
111  while (std::fabs(fb) > epsilon_ && std::fabs(fs) > epsilon_ && std::fabs(b - a) > epsilon_)
112  {
113  if (useInverseQuadraticInterpolation(fa, fb, lastFb))
114  {
115  s = calculateInverseQuadraticInterpolation(a, b, lastB, fa, fb, lastFb);
116  }
117  else
118  {
119  s = calculateSecant(a, b, fa, fb);
120  }
121 
122  if (useBisection(bisection, b, lastB, penultimateB, s))
123  {
124  s = calculateBisection(a, b);
125  bisection = true;
126  }
127  else
128  {
129  bisection = false;
130  }
131 
132  fs = f_(s);
133  penultimateB = lastB;
134  lastB = b;
135 
136  if (fa * fs < 0)
137  {
138  b = s;
139  }
140  else
141  {
142  a = s;
143  }
144 
145  fa = f_(a);
146  lastFb = fb;
147  fb = f_(b);
148  checkAndFixAlgorithmCriteria(a, b, fa, fb);
149 
151  }
152 
153  return fb < fs ? b : s;
154  }
155 
156  private:
157  //============================================================================
158  const std::function<double(double)> f_;
159 
160  //============================================================================
161  // Method Description:
168  static double calculateBisection(const double a, const double b) noexcept
169  {
170  return 0.5 * (a + b);
171  }
172 
173  //============================================================================
174  // Method Description:
183  static double calculateSecant(const double a, const double b, const double fa, const double fb) noexcept
184  {
185  // No need to check division by 0, in this case the method returns NAN which is taken care by
186  // useSecantMethod method
187  return b - fb * (b - a) / (fb - fa);
188  }
189 
190  //============================================================================
191  // Method Description:
202  static double calculateInverseQuadraticInterpolation(const double a,
203  const double b,
204  const double lastB,
205  const double fa,
206  const double fb,
207  const double lastFb) noexcept
208  {
209  return a * fb * lastFb / ((fa - fb) * (fa - lastFb)) + b * fa * lastFb / ((fb - fa) * (fb - lastFb)) +
210  lastB * fa * fb / ((lastFb - fa) * (lastFb - fb));
211  }
212 
213  //============================================================================
214  // Method Description:
222  static bool useInverseQuadraticInterpolation(const double fa, const double fb, const double lastFb) noexcept
223  {
224  return fa != lastFb && fb != lastFb;
225  }
226 
227  //============================================================================
228  // Method Description:
236  static void checkAndFixAlgorithmCriteria(double &a, double &b, double &fa, double &fb) noexcept
237  {
238  // Algorithm works in range [a,b] if criteria f(a)*f(b) < 0 and f(a) > f(b) is fulfilled
239  if (std::fabs(fa) < std::fabs(fb))
240  {
241  std::swap(a, b);
242  std::swap(fa, fb);
243  }
244  }
245 
246  //============================================================================
247  // Method Description:
257  bool useBisection(const bool bisection,
258  const double b,
259  const double lastB,
260  const double penultimateB,
261  const double s) const noexcept
262  {
263  const double DELTA = epsilon_ + std::numeric_limits<double>::min();
264 
265  return (bisection &&
266  std::fabs(s - b) >=
267  0.5 *
268  std::fabs(b - lastB)) || // Bisection was used in last step but |s-b|>=|b-lastB|/2 <-
269  // Interpolation step would be to rough, so still use bisection
270  (!bisection &&
271  std::fabs(s - b) >=
272  0.5 * std::fabs(lastB - penultimateB)) || // Interpolation was used in last step but
273  // |s-b|>=|lastB-penultimateB|/2 <- Interpolation
274  // step would be to small
275  (bisection &&
276  std::fabs(b - lastB) < DELTA) || // If last iteration was using bisection and difference between
277  // b and lastB is < delta use bisection for next iteration
278  (!bisection && std::fabs(lastB - penultimateB) <
279  DELTA); // If last iteration was using interpolation but difference between
280  // lastB ond penultimateB is < delta use biscetion for next iteration
281  }
282  };
283  } // namespace roots
284 } // namespace nc
static constexpr dtype max() noexcept
Definition: DtypeInfo.hpp:110
Definition: Brent.hpp:52
Brent(const double epsilon, const uint32 maxNumIterations, std::function< double(double)> f) noexcept
Definition: Brent.hpp:75
double solve(double a, double b)
Definition: Brent.hpp:95
~Brent() override=default
Brent(const double epsilon, std::function< double(double)> f) noexcept
Definition: Brent.hpp:61
ABC for iteration classes to derive from.
Definition: Iteration.hpp:48
Iteration(double epsilon) noexcept
Definition: Iteration.hpp:56
const double epsilon_
Definition: Iteration.hpp:118
void resetNumberOfIterations() noexcept
Definition: Iteration.hpp:96
void incrementNumberOfIterations()
Definition: Iteration.hpp:107
dtype f(GeneratorType &generator, dtype inDofN, dtype inDofD)
Definition: f.hpp:58
Definition: Coordinate.hpp:45
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:44
std::uint32_t uint32
Definition: Types.hpp:40