NumCpp  2.11.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
Brent.hpp
Go to the documentation of this file.
1 #pragma once
34 
35 #include <cmath>
36 #include <functional>
37 #include <utility>
38 
40 #include "NumCpp/Core/Types.hpp"
43 
44 namespace nc::roots
45 {
46  //================================================================================
47  // Class Description:
50  class Brent : public Iteration
51  {
52  public:
53  //============================================================================
54  // Method Description:
60  Brent(const double epsilon, std::function<double(double)> f) noexcept :
61  Iteration(epsilon),
62  f_(std::move(f))
63  {
64  }
65 
66  //============================================================================
67  // Method Description:
74  Brent(const double epsilon, const uint32 maxNumIterations, std::function<double(double)> f) noexcept :
75  Iteration(epsilon, maxNumIterations),
76  f_(std::move(f))
77  {
78  }
79 
80  //============================================================================
81  // Method Description:
84  ~Brent() override = default;
85 
86  //============================================================================
87  // Method Description:
94  double solve(double a, double b)
95  {
97 
98  double fa = f_(a);
99  double fb = f_(b);
100 
101  checkAndFixAlgorithmCriteria(a, b, fa, fb);
102 
103  double lastB = a; // b_{k-1}
104  double lastFb = fa;
105  double s = DtypeInfo<double>::max();
106  double fs = DtypeInfo<double>::max();
107  double penultimateB = a; // b_{k-2}
108 
109  bool bisection = true;
110  while (std::fabs(fb) > epsilon_ && std::fabs(fs) > epsilon_ && std::fabs(b - a) > epsilon_)
111  {
112  if (useInverseQuadraticInterpolation(fa, fb, lastFb))
113  {
114  s = calculateInverseQuadraticInterpolation(a, b, lastB, fa, fb, lastFb);
115  }
116  else
117  {
118  s = calculateSecant(a, b, fa, fb);
119  }
120 
121  if (useBisection(bisection, b, lastB, penultimateB, s))
122  {
123  s = calculateBisection(a, b);
124  bisection = true;
125  }
126  else
127  {
128  bisection = false;
129  }
130 
131  fs = f_(s);
132  penultimateB = lastB;
133  lastB = b;
134 
135  if (fa * fs < 0)
136  {
137  b = s;
138  }
139  else
140  {
141  a = s;
142  }
143 
144  fa = f_(a);
145  lastFb = fb;
146  fb = f_(b);
147  checkAndFixAlgorithmCriteria(a, b, fa, fb);
148 
150  }
151 
152  return fb < fs ? b : s;
153  }
154 
155  private:
156  //============================================================================
157  const std::function<double(double)> f_;
158 
159  //============================================================================
160  // Method Description:
167  static double calculateBisection(const double a, const double b) noexcept
168  {
169  return 0.5 * (a + b);
170  }
171 
172  //============================================================================
173  // Method Description:
182  static double calculateSecant(const double a, const double b, const double fa, const double fb) noexcept
183  {
184  // No need to check division by 0, in this case the method returns NAN which is taken care by
185  // useSecantMethod method
186  return b - fb * (b - a) / (fb - fa);
187  }
188 
189  //============================================================================
190  // Method Description:
201  static double calculateInverseQuadraticInterpolation(const double a,
202  const double b,
203  const double lastB,
204  const double fa,
205  const double fb,
206  const double lastFb) noexcept
207  {
208  return a * fb * lastFb / ((fa - fb) * (fa - lastFb)) + b * fa * lastFb / ((fb - fa) * (fb - lastFb)) +
209  lastB * fa * fb / ((lastFb - fa) * (lastFb - fb));
210  }
211 
212  //============================================================================
213  // Method Description:
221  static bool useInverseQuadraticInterpolation(const double fa, const double fb, const double lastFb) noexcept
222  {
223  return !utils::essentiallyEqual(fa, lastFb) && utils::essentiallyEqual(fb, lastFb);
224  }
225 
226  //============================================================================
227  // Method Description:
235  static void checkAndFixAlgorithmCriteria(double &a, double &b, double &fa, double &fb) noexcept
236  {
237  // Algorithm works in range [a,b] if criteria f(a)*f(b) < 0 and f(a) > f(b) is fulfilled
238  if (std::fabs(fa) < std::fabs(fb))
239  {
240  std::swap(a, b);
241  std::swap(fa, fb);
242  }
243  }
244 
245  //============================================================================
246  // Method Description:
256  [[nodiscard]] bool useBisection(const bool bisection,
257  const double b,
258  const double lastB,
259  const double penultimateB,
260  const double s) const noexcept
261  {
262  const double DELTA = epsilon_ + std::numeric_limits<double>::min();
263 
264  return (bisection &&
265  std::fabs(s - b) >=
266  0.5 * std::fabs(b - lastB)) || // Bisection was used in last step but |s-b|>=|b-lastB|/2 <-
267  // Interpolation step would be to rough, so still use bisection
268  (!bisection && std::fabs(s - b) >=
269  0.5 * std::fabs(lastB - penultimateB)) || // Interpolation was used in last step
270  // but |s-b|>=|lastB-penultimateB|/2 <-
271  // Interpolation step would be to small
272  (bisection &&
273  std::fabs(b - lastB) < DELTA) || // If last iteration was using bisection and difference between
274  // b and lastB is < delta use bisection for next iteration
275  (!bisection && std::fabs(lastB - penultimateB) <
276  DELTA); // If last iteration was using interpolation but difference between
277  // lastB ond penultimateB is < delta use biscetion for next iteration
278  }
279  };
280 } // namespace nc::roots
static constexpr dtype max() noexcept
Definition: DtypeInfo.hpp:110
Definition: Brent.hpp:51
Brent(const double epsilon, const uint32 maxNumIterations, std::function< double(double)> f) noexcept
Definition: Brent.hpp:74
double solve(double a, double b)
Definition: Brent.hpp:94
~Brent() override=default
Brent(const double epsilon, std::function< double(double)> f) noexcept
Definition: Brent.hpp:60
ABC for iteration classes to derive from.
Definition: Iteration.hpp:46
Iteration(double epsilon) noexcept
Definition: Iteration.hpp:54
const double epsilon_
Definition: Iteration.hpp:116
void resetNumberOfIterations() noexcept
Definition: Iteration.hpp:94
void incrementNumberOfIterations()
Definition: Iteration.hpp:105
dtype f(GeneratorType &generator, dtype inDofN, dtype inDofD)
Definition: f.hpp:56
Definition: Bisection.hpp:43
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:48
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:44
std::uint32_t uint32
Definition: Types.hpp:40