NumCpp  2.11.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
SVDClass.hpp
Go to the documentation of this file.
1 #pragma once
30 
31 #include <cmath>
32 #include <limits>
33 #include <string>
34 
36 #include "NumCpp/Core/Types.hpp"
37 #include "NumCpp/NdArray.hpp"
39 
40 namespace nc::linalg
41 {
42  // =============================================================================
43  // Class Description:
46  class SVD
47  {
48  public:
49  // =============================================================================
50  // Description:
55  explicit SVD(const NdArray<double>& inMatrix) :
56  m_(inMatrix.shape().rows),
57  n_(inMatrix.shape().cols),
58  u_(inMatrix),
59  v_(n_, n_),
60  s_(1, n_),
61  eps_(std::numeric_limits<double>::epsilon())
62  {
63  decompose();
64  reorder();
65  tsh_ = 0.5 * std::sqrt(m_ + n_ + 1.) * s_.front() * eps_;
66  }
67 
68  // =============================================================================
69  // Description:
74  const NdArray<double>& u() noexcept
75  {
76  return u_;
77  }
78 
79  // =============================================================================
80  // Description:
85  const NdArray<double>& v() noexcept
86  {
87  return v_;
88  }
89 
90  // =============================================================================
91  // Description:
96  const NdArray<double>& s() noexcept
97  {
98  return s_;
99  }
100 
101  // =============================================================================
102  // Description:
110  NdArray<double> solve(const NdArray<double>& inInput, double inThresh = -1.)
111  {
112  double ss{};
113 
114  if (inInput.size() != m_)
115  {
116  THROW_INVALID_ARGUMENT_ERROR("bad sizes.");
117  }
118 
119  NdArray<double> returnArray(1, n_);
120 
121  NdArray<double> tmp(1, n_);
122 
123  tsh_ = (inThresh >= 0. ? inThresh : 0.5 * sqrt(m_ + n_ + 1.) * s_.front() * eps_);
124 
125  for (uint32 j = 0; j < n_; j++)
126  {
127  ss = 0.;
128  if (s_[j] > tsh_)
129  {
130  for (uint32 i = 0; i < m_; i++)
131  {
132  ss += u_(i, j) * inInput[i];
133  }
134  ss /= s_[j];
135  }
136  tmp[j] = ss;
137  }
138 
139  for (uint32 j = 0; j < n_; j++)
140  {
141  ss = 0.;
142  for (uint32 jj = 0; jj < n_; jj++)
143  {
144  ss += v_(j, jj) * tmp[jj];
145  }
146 
147  returnArray[j] = ss;
148  }
149 
150  return returnArray;
151  }
152 
153  private:
154  // =============================================================================
155  // Description:
163  static double SIGN(double inA, double inB) noexcept
164  {
165  return inB >= 0 ? (inA >= 0 ? inA : -inA) : (inA >= 0 ? -inA : inA);
166  }
167 
168  // =============================================================================
169  // Description:
172  void decompose()
173  {
174  bool flag{};
175  uint32 i{};
176  uint32 its{};
177  uint32 j{};
178  uint32 jj{};
179  uint32 k{};
180  uint32 l{};
181  uint32 nm{};
182 
183  double anorm{};
184  double c{};
185  double f{};
186  double g{};
187  double h{};
188  double ss{};
189  double scale{};
190  double x{};
191  double y{};
192  double z{};
193 
194  NdArray<double> rv1(n_, 1);
195 
196  for (i = 0; i < n_; ++i)
197  {
198  l = i + 2;
199  rv1[i] = scale * g;
200  g = ss = scale = 0.;
201 
202  if (i < m_)
203  {
204  for (k = i; k < m_; ++k)
205  {
206  scale += std::abs(u_(k, i));
207  }
208 
209  if (!utils::essentiallyEqual(scale, 0.))
210  {
211  for (k = i; k < m_; ++k)
212  {
213  u_(k, i) /= scale;
214  ss += u_(k, i) * u_(k, i);
215  }
216 
217  f = u_(i, i);
218  g = -SIGN(std::sqrt(ss), f);
219  h = f * g - ss;
220  u_(i, i) = f - g;
221 
222  for (j = l - 1; j < n_; ++j)
223  {
224  for (ss = 0., k = i; k < m_; ++k)
225  {
226  ss += u_(k, i) * u_(k, j);
227  }
228 
229  f = ss / h;
230 
231  for (k = i; k < m_; ++k)
232  {
233  u_(k, j) += f * u_(k, i);
234  }
235  }
236 
237  for (k = i; k < m_; ++k)
238  {
239  u_(k, i) *= scale;
240  }
241  }
242  }
243 
244  s_[i] = scale * g;
245  g = ss = scale = 0.;
246 
247  if (i + 1 <= m_ && i + 1 != n_)
248  {
249  for (k = l - 1; k < n_; ++k)
250  {
251  scale += std::abs(u_(i, k));
252  }
253 
254  if (!utils::essentiallyEqual(scale, 0.))
255  {
256  for (k = l - 1; k < n_; ++k)
257  {
258  u_(i, k) /= scale;
259  ss += u_(i, k) * u_(i, k);
260  }
261 
262  f = u_(i, l - 1);
263  g = -SIGN(std::sqrt(ss), f);
264  h = f * g - ss;
265  u_(i, l - 1) = f - g;
266 
267  for (k = l - 1; k < n_; ++k)
268  {
269  rv1[k] = u_(i, k) / h;
270  }
271 
272  for (j = l - 1; j < m_; ++j)
273  {
274  for (ss = 0., k = l - 1; k < n_; ++k)
275  {
276  ss += u_(j, k) * u_(i, k);
277  }
278 
279  for (k = l - 1; k < n_; ++k)
280  {
281  u_(j, k) += ss * rv1[k];
282  }
283  }
284 
285  for (k = l - 1; k < n_; ++k)
286  {
287  u_(i, k) *= scale;
288  }
289  }
290  }
291 
292  anorm = std::max(anorm, (std::abs(s_[i]) + std::abs(rv1[i])));
293  }
294 
295  for (i = n_ - 1; i != static_cast<uint32>(-1); --i)
296  {
297  if (i < n_ - 1)
298  {
299  if (!utils::essentiallyEqual(g, 0.))
300  {
301  for (j = l; j < n_; ++j)
302  {
303  v_(j, i) = (u_(i, j) / u_(i, l)) / g;
304  }
305 
306  for (j = l; j < n_; ++j)
307  {
308  for (ss = 0., k = l; k < n_; ++k)
309  {
310  ss += u_(i, k) * v_(k, j);
311  }
312 
313  for (k = l; k < n_; ++k)
314  {
315  v_(k, j) += ss * v_(k, i);
316  }
317  }
318  }
319 
320  for (j = l; j < n_; ++j)
321  {
322  v_(i, j) = v_(j, i) = 0.;
323  }
324  }
325 
326  v_(i, i) = 1.;
327  g = rv1[i];
328  l = i;
329  }
330 
331  for (i = std::min(m_, n_) - 1; i != static_cast<uint32>(-1); --i)
332  {
333  l = i + 1;
334  g = s_[i];
335 
336  for (j = l; j < n_; ++j)
337  {
338  u_(i, j) = 0.;
339  }
340 
341  if (!utils::essentiallyEqual(g, 0.))
342  {
343  g = 1. / g;
344 
345  for (j = l; j < n_; ++j)
346  {
347  for (ss = 0., k = l; k < m_; ++k)
348  {
349  ss += u_(k, i) * u_(k, j);
350  }
351 
352  f = (ss / u_(i, i)) * g;
353 
354  for (k = i; k < m_; ++k)
355  {
356  u_(k, j) += f * u_(k, i);
357  }
358  }
359 
360  for (j = i; j < m_; ++j)
361  {
362  u_(j, i) *= g;
363  }
364  }
365  else
366  {
367  for (j = i; j < m_; ++j)
368  {
369  u_(j, i) = 0.;
370  }
371  }
372 
373  ++u_(i, i);
374  }
375 
376  for (k = n_ - 1; k != static_cast<uint32>(-1); --k)
377  {
378  for (its = 0; its < 30; ++its)
379  {
380  flag = true;
381  for (l = k; l != static_cast<uint32>(-1); --l)
382  {
383  nm = l - 1;
384  if (l == 0 || std::abs(rv1[l]) <= eps_ * anorm)
385  {
386  flag = false;
387  break;
388  }
389 
390  if (std::abs(s_[nm]) <= eps_ * anorm)
391  {
392  break;
393  }
394  }
395 
396  if (flag)
397  {
398  c = 0.;
399  ss = 1.;
400  for (i = l; i < k + 1; ++i)
401  {
402  f = ss * rv1[i];
403  rv1[i] = c * rv1[i];
404 
405  if (std::abs(f) <= eps_ * anorm)
406  {
407  break;
408  }
409 
410  g = s_[i];
411  h = pythag(f, g);
412  s_[i] = h;
413  h = 1. / h;
414  c = g * h;
415  ss = -f * h;
416 
417  for (j = 0; j < m_; ++j)
418  {
419  y = u_(j, nm);
420  z = u_(j, i);
421  u_(j, nm) = y * c + z * ss;
422  u_(j, i) = z * c - y * ss;
423  }
424  }
425  }
426 
427  z = s_[k];
428  if (l == k)
429  {
430  if (z < 0.)
431  {
432  s_[k] = -z;
433  for (j = 0; j < n_; ++j)
434  {
435  v_(j, k) = -v_(j, k);
436  }
437  }
438  break;
439  }
440 
441  if (its == 29)
442  {
443  THROW_INVALID_ARGUMENT_ERROR("no convergence in 30 svdcmp iterations");
444  }
445 
446  x = s_[l];
447  nm = k - 1;
448  y = s_[nm];
449  g = rv1[nm];
450  h = rv1[k];
451  f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2. * h * y);
452  g = pythag(f, 1.);
453  f = ((x - z) * (x + z) + h * ((y / (f + SIGN(g, f))) - h)) / x;
454  c = ss = 1.;
455 
456  for (j = l; j <= nm; j++)
457  {
458  i = j + 1;
459  g = rv1[i];
460  y = s_[i];
461  h = ss * g;
462  g = c * g;
463  z = pythag(f, h);
464  rv1[j] = z;
465  c = f / z;
466  ss = h / z;
467  f = x * c + g * ss;
468  g = g * c - x * ss;
469  h = y * ss;
470  y *= c;
471 
472  for (jj = 0; jj < n_; ++jj)
473  {
474  x = v_(jj, j);
475  z = v_(jj, i);
476  v_(jj, j) = x * c + z * ss;
477  v_(jj, i) = z * c - x * ss;
478  }
479 
480  z = pythag(f, h);
481  s_[j] = z;
482 
483  if (!utils::essentiallyEqual(z, 0.))
484  {
485  z = 1. / z;
486  c = f * z;
487  ss = h * z;
488  }
489 
490  f = c * g + ss * y;
491  x = c * y - ss * g;
492 
493  for (jj = 0; jj < m_; ++jj)
494  {
495  y = u_(jj, j);
496  z = u_(jj, i);
497  u_(jj, j) = y * c + z * ss;
498  u_(jj, i) = z * c - y * ss;
499  }
500  }
501  rv1[l] = 0.;
502  rv1[k] = f;
503  s_[k] = x;
504  }
505  }
506  }
507 
508  // =============================================================================
509  // Description:
512  void reorder()
513  {
514  uint32 i = 0;
515  uint32 j = 0;
516  uint32 k = 0;
517  uint32 ss = 0;
518  uint32 inc = 1;
519 
520  double sw{};
521  NdArray<double> su(m_, 1);
522  NdArray<double> sv(n_, 1);
523 
524  do
525  {
526  inc *= 3;
527  ++inc;
528  } while (inc <= n_);
529 
530  do
531  {
532  inc /= 3;
533  for (i = inc; i < n_; ++i)
534  {
535  sw = s_[i];
536  for (k = 0; k < m_; ++k)
537  {
538  su[k] = u_(k, i);
539  }
540 
541  for (k = 0; k < n_; ++k)
542  {
543  sv[k] = v_(k, i);
544  }
545 
546  j = i;
547  while (s_[j - inc] < sw)
548  {
549  s_[j] = s_[j - inc];
550 
551  for (k = 0; k < m_; ++k)
552  {
553  u_(k, j) = u_(k, j - inc);
554  }
555 
556  for (k = 0; k < n_; ++k)
557  {
558  v_(k, j) = v_(k, j - inc);
559  }
560 
561  j -= inc;
562 
563  if (j < inc)
564  {
565  break;
566  }
567  }
568 
569  s_[j] = sw;
570 
571  for (k = 0; k < m_; ++k)
572  {
573  u_(k, j) = su[k];
574  }
575 
576  for (k = 0; k < n_; ++k)
577  {
578  v_(k, j) = sv[k];
579  }
580  }
581  } while (inc > 1);
582 
583  for (k = 0; k < n_; ++k)
584  {
585  ss = 0;
586 
587  for (i = 0; i < m_; i++)
588  {
589  if (u_(i, k) < 0.)
590  {
591  ss++;
592  }
593  }
594 
595  for (j = 0; j < n_; ++j)
596  {
597  if (v_(j, k) < 0.)
598  {
599  ss++;
600  }
601  }
602 
603  if (ss > (m_ + n_) / 2)
604  {
605  for (i = 0; i < m_; ++i)
606  {
607  u_(i, k) = -u_(i, k);
608  }
609 
610  for (j = 0; j < n_; ++j)
611  {
612  v_(j, k) = -v_(j, k);
613  }
614  }
615  }
616  }
617 
618  // =============================================================================
619  // Description:
627  static double pythag(double inA, double inB) noexcept
628  {
629  const double absa = std::abs(inA);
630  const double absb = std::abs(inB);
631  return (absa > absb
632  ? absa * std::sqrt(1. + utils::sqr(absb / absa))
633  : (utils::essentiallyEqual(absb, 0.) ? 0. : absb * std::sqrt(1. + utils::sqr(absa / absb))));
634  }
635 
636  private:
637  // ===============================Attributes====================================
638  const uint32 m_{};
639  const uint32 n_{};
640  NdArray<double> u_{};
641  NdArray<double> v_{};
642  NdArray<double> s_{};
643  double eps_{};
644  double tsh_{};
645  };
646 } // namespace nc::linalg
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
size_type size() const noexcept
Definition: NdArrayCore.hpp:4477
const_reference front() const noexcept
Definition: NdArrayCore.hpp:2813
Definition: SVDClass.hpp:47
const NdArray< double > & u() noexcept
Definition: SVDClass.hpp:74
const NdArray< double > & s() noexcept
Definition: SVDClass.hpp:96
NdArray< double > solve(const NdArray< double > &inInput, double inThresh=-1.)
Definition: SVDClass.hpp:110
const NdArray< double > & v() noexcept
Definition: SVDClass.hpp:85
SVD(const NdArray< double > &inMatrix)
Definition: SVDClass.hpp:55
constexpr auto j
Definition: Core/Constants.hpp:42
constexpr double c
speed of light
Definition: Core/Constants.hpp:36
Definition: cholesky.hpp:41
dtype f(GeneratorType &generator, dtype inDofN, dtype inDofD)
Definition: f.hpp:56
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:48
constexpr dtype sqr(dtype inValue) noexcept
Definition: sqr.hpp:42
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44
auto abs(dtype inValue) noexcept
Definition: abs.hpp:49
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:44
auto sqrt(dtype inValue) noexcept
Definition: sqrt.hpp:48
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:42
std::uint32_t uint32
Definition: Types.hpp:40