NumCpp  2.7.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
SVDClass.hpp
Go to the documentation of this file.
1 #pragma once
30 
32 #include "NumCpp/Core/Types.hpp"
33 #include "NumCpp/NdArray.hpp"
34 
35 #include <cmath>
36 #include <limits>
37 #include <string>
38 
39 namespace nc
40 {
41  namespace linalg
42  {
43  // =============================================================================
44  // Class Description:
47  class SVD
48  {
49  public:
50  // =============================================================================
51  // Description:
56  explicit SVD(const NdArray<double>& inMatrix) :
57  m_(inMatrix.shape().rows),
58  n_(inMatrix.shape().cols),
59  u_(inMatrix),
60  v_(n_, n_),
61  s_(1, n_),
62  eps_(std::numeric_limits<double>::epsilon())
63  {
64  decompose();
65  reorder();
66  tsh_ = 0.5 * std::sqrt(m_ + n_ + 1.) * s_.front() * eps_;
67  }
68 
69  // =============================================================================
70  // Description:
75  const NdArray<double>& u() noexcept
76  {
77  return u_;
78  }
79 
80  // =============================================================================
81  // Description:
86  const NdArray<double>& v() noexcept
87  {
88  return v_;
89  }
90 
91  // =============================================================================
92  // Description:
97  const NdArray<double>& s() noexcept
98  {
99  return s_;
100  }
101 
102  // =============================================================================
103  // Description:
111  NdArray<double> solve(const NdArray<double>& inInput, double inThresh = -1.0)
112  {
113  double ss = 0.0;
114 
115  if (inInput.size() != m_)
116  {
117  THROW_INVALID_ARGUMENT_ERROR("bad sizes.");
118  }
119 
120  NdArray<double> returnArray(1, n_);
121 
122  NdArray<double> tmp(1, n_);
123 
124  tsh_ = (inThresh >= 0. ? inThresh : 0.5 * sqrt(m_ + n_ + 1.) * s_.front() * eps_);
125 
126  for (uint32 j = 0; j < n_; j++)
127  {
128  ss = 0.0;
129  if (s_[j] > tsh_)
130  {
131  for (uint32 i = 0; i < m_; i++)
132  {
133  ss += u_(i, j) * inInput[i];
134  }
135  ss /= s_[j];
136  }
137  tmp[j] = ss;
138  }
139 
140  for (uint32 j = 0; j < n_; j++)
141  {
142  ss = 0.0;
143  for (uint32 jj = 0; jj < n_; jj++)
144  {
145  ss += v_(j, jj) * tmp[jj];
146  }
147 
148  returnArray[j] = ss;
149  }
150 
151  return returnArray;
152  }
153 
154  private:
155  // =============================================================================
156  // Description:
164  static double SIGN(double inA, double inB) noexcept
165  {
166  return inB >= 0 ? (inA >= 0 ? inA : -inA) : (inA >= 0 ? -inA : inA);
167  }
168 
169  // =============================================================================
170  // Description:
173  void decompose()
174  {
175  bool flag = true;
176  uint32 i = 0;
177  uint32 its = 0;
178  uint32 j = 0;
179  uint32 jj = 0;
180  uint32 k = 0;
181  uint32 l = 0;
182  uint32 nm = 0;
183 
184  double anorm = 0.0;
185  double c = 0.0;
186  double f = 0.0;
187  double g = 0.0;
188  double h = 0.0;
189  double ss = 0.0;
190  double scale = 0.0;
191  double x = 0.0;
192  double y = 0.0;
193  double z = 0.0;
194 
195  NdArray<double> rv1(n_, 1);
196 
197  for (i = 0; i < n_; ++i)
198  {
199  l = i + 2;
200  rv1[i] = scale * g;
201  g = ss = scale = 0.0;
202 
203  if (i < m_)
204  {
205  for (k = i; k < m_; ++k)
206  {
207  scale += std::abs(u_(k, i));
208  }
209 
210  if (scale != 0.0)
211  {
212  for (k = i; k < m_; ++k)
213  {
214  u_(k, i) /= scale;
215  ss += u_(k, i) * u_(k, i);
216  }
217 
218  f = u_(i, i);
219  g = -SIGN(std::sqrt(ss), f);
220  h = f * g - ss;
221  u_(i, i) = f - g;
222 
223  for (j = l - 1; j < n_; ++j)
224  {
225  for (ss = 0.0, k = i; k < m_; ++k)
226  {
227  ss += u_(k, i) * u_(k, j);
228  }
229 
230  f = ss / h;
231 
232  for (k = i; k < m_; ++k)
233  {
234  u_(k, j) += f * u_(k, i);
235  }
236  }
237 
238  for (k = i; k < m_; ++k)
239  {
240  u_(k, i) *= scale;
241  }
242  }
243  }
244 
245  s_[i] = scale * g;
246  g = ss = scale = 0.0;
247 
248  if (i + 1 <= m_ && i + 1 != n_)
249  {
250  for (k = l - 1; k < n_; ++k)
251  {
252  scale += std::abs(u_(i, k));
253  }
254 
255  if (scale != 0.0)
256  {
257  for (k = l - 1; k < n_; ++k)
258  {
259  u_(i, k) /= scale;
260  ss += u_(i, k) * u_(i, k);
261  }
262 
263  f = u_(i, l - 1);
264  g = -SIGN(std::sqrt(ss), f);
265  h = f * g - ss;
266  u_(i, l - 1) = f - g;
267 
268  for (k = l - 1; k < n_; ++k)
269  {
270  rv1[k] = u_(i, k) / h;
271  }
272 
273  for (j = l - 1; j < m_; ++j)
274  {
275  for (ss = 0.0, k = l - 1; k < n_; ++k)
276  {
277  ss += u_(j, k) * u_(i, k);
278  }
279 
280  for (k = l - 1; k < n_; ++k)
281  {
282  u_(j, k) += ss * rv1[k];
283  }
284  }
285 
286  for (k = l - 1; k < n_; ++k)
287  {
288  u_(i, k) *= scale;
289  }
290  }
291  }
292 
293  anorm = std::max(anorm, (std::abs(s_[i]) + std::abs(rv1[i])));
294  }
295 
296  for (i = n_ - 1; i != static_cast<uint32>(-1); --i)
297  {
298  if (i < n_ - 1)
299  {
300  if (g != 0.0)
301  {
302  for (j = l; j < n_; ++j)
303  {
304  v_(j, i) = (u_(i, j) / u_(i, l)) / g;
305  }
306 
307  for (j = l; j < n_; ++j)
308  {
309  for (ss = 0.0, k = l; k < n_; ++k)
310  {
311  ss += u_(i, k) * v_(k, j);
312  }
313 
314  for (k = l; k < n_; ++k)
315  {
316  v_(k, j) += ss * v_(k, i);
317  }
318  }
319  }
320 
321  for (j = l; j < n_; ++j)
322  {
323  v_(i, j) = v_(j, i) = 0.0;
324  }
325  }
326 
327  v_(i, i) = 1.0;
328  g = rv1[i];
329  l = i;
330  }
331 
332  for (i = std::min(m_, n_) - 1; i != static_cast<uint32>(-1); --i)
333  {
334  l = i + 1;
335  g = s_[i];
336 
337  for (j = l; j < n_; ++j)
338  {
339  u_(i, j) = 0.0;
340  }
341 
342  if (g != 0.0)
343  {
344  g = 1.0 / g;
345 
346  for (j = l; j < n_; ++j)
347  {
348  for (ss = 0.0, k = l; k < m_; ++k)
349  {
350  ss += u_(k, i) * u_(k, j);
351  }
352 
353  f = (ss / u_(i, i)) * g;
354 
355  for (k = i; k < m_; ++k)
356  {
357  u_(k, j) += f * u_(k, i);
358  }
359  }
360 
361  for (j = i; j < m_; ++j)
362  {
363  u_(j, i) *= g;
364  }
365 
366  }
367  else
368  {
369  for (j = i; j < m_; ++j)
370  {
371  u_(j, i) = 0.0;
372  }
373  }
374 
375  ++u_(i, i);
376  }
377 
378  for (k = n_ - 1; k != static_cast<uint32>(-1); --k)
379  {
380  for (its = 0; its < 30; ++its)
381  {
382  flag = true;
383  for (l = k; l != static_cast<uint32>(-1); --l)
384  {
385  nm = l - 1;
386  if (l == 0 || std::abs(rv1[l]) <= eps_ * anorm)
387  {
388  flag = false;
389  break;
390  }
391 
392  if (std::abs(s_[nm]) <= eps_ * anorm)
393  {
394  break;
395  }
396  }
397 
398  if (flag)
399  {
400  c = 0.0;
401  ss = 1.0;
402  for (i = l; i < k + 1; ++i)
403  {
404  f = ss * rv1[i];
405  rv1[i] = c * rv1[i];
406 
407  if (std::abs(f) <= eps_ * anorm)
408  {
409  break;
410  }
411 
412  g = s_[i];
413  h = pythag(f, g);
414  s_[i] = h;
415  h = 1.0 / h;
416  c = g * h;
417  ss = -f * h;
418 
419  for (j = 0; j < m_; ++j)
420  {
421  y = u_(j, nm);
422  z = u_(j, i);
423  u_(j, nm) = y * c + z * ss;
424  u_(j, i) = z * c - y * ss;
425  }
426  }
427  }
428 
429  z = s_[k];
430  if (l == k)
431  {
432  if (z < 0.0)
433  {
434  s_[k] = -z;
435  for (j = 0; j < n_; ++j)
436  {
437  v_(j, k) = -v_(j, k);
438  }
439  }
440  break;
441  }
442 
443  if (its == 29)
444  {
445  THROW_INVALID_ARGUMENT_ERROR("no convergence in 30 svdcmp iterations");
446  }
447 
448  x = s_[l];
449  nm = k - 1;
450  y = s_[nm];
451  g = rv1[nm];
452  h = rv1[k];
453  f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2.0 * h * y);
454  g = pythag(f, 1.0);
455  f = ((x - z) * (x + z) + h * ((y / (f + SIGN(g, f))) - h)) / x;
456  c = ss = 1.0;
457 
458  for (j = l; j <= nm; j++)
459  {
460  i = j + 1;
461  g = rv1[i];
462  y = s_[i];
463  h = ss * g;
464  g = c * g;
465  z = pythag(f, h);
466  rv1[j] = z;
467  c = f / z;
468  ss = h / z;
469  f = x * c + g * ss;
470  g = g * c - x * ss;
471  h = y * ss;
472  y *= c;
473 
474  for (jj = 0; jj < n_; ++jj)
475  {
476  x = v_(jj, j);
477  z = v_(jj, i);
478  v_(jj, j) = x * c + z * ss;
479  v_(jj, i) = z * c - x * ss;
480  }
481 
482  z = pythag(f, h);
483  s_[j] = z;
484 
485  if (z != 0.0)
486  {
487  z = 1.0 / z;
488  c = f * z;
489  ss = h * z;
490  }
491 
492  f = c * g + ss * y;
493  x = c * y - ss * g;
494 
495  for (jj = 0; jj < m_; ++jj)
496  {
497  y = u_(jj, j);
498  z = u_(jj, i);
499  u_(jj, j) = y * c + z * ss;
500  u_(jj, i) = z * c - y * ss;
501  }
502  }
503  rv1[l] = 0.0;
504  rv1[k] = f;
505  s_[k] = x;
506  }
507  }
508  }
509 
510  // =============================================================================
511  // Description:
514  void reorder()
515  {
516  uint32 i = 0;
517  uint32 j = 0;
518  uint32 k = 0;
519  uint32 ss = 0;
520  uint32 inc = 1;
521 
522  double sw = 0.0;
523  NdArray<double> su(m_, 1);
524  NdArray<double> sv(n_, 1);
525 
526  do
527  {
528  inc *= 3;
529  ++inc;
530  } while (inc <= n_);
531 
532  do
533  {
534  inc /= 3;
535  for (i = inc; i < n_; ++i)
536  {
537  sw = s_[i];
538  for (k = 0; k < m_; ++k)
539  {
540  su[k] = u_(k, i);
541  }
542 
543  for (k = 0; k < n_; ++k)
544  {
545  sv[k] = v_(k, i);
546  }
547 
548  j = i;
549  while (s_[j - inc] < sw)
550  {
551  s_[j] = s_[j - inc];
552 
553  for (k = 0; k < m_; ++k)
554  {
555  u_(k, j) = u_(k, j - inc);
556  }
557 
558  for (k = 0; k < n_; ++k)
559  {
560  v_(k, j) = v_(k, j - inc);
561  }
562 
563  j -= inc;
564 
565  if (j < inc)
566  {
567  break;
568  }
569  }
570 
571  s_[j] = sw;
572 
573  for (k = 0; k < m_; ++k)
574  {
575  u_(k, j) = su[k];
576  }
577 
578  for (k = 0; k < n_; ++k)
579  {
580  v_(k, j) = sv[k];
581  }
582 
583  }
584  } while (inc > 1);
585 
586  for (k = 0; k < n_; ++k)
587  {
588  ss = 0;
589 
590  for (i = 0; i < m_; i++)
591  {
592  if (u_(i, k) < 0.)
593  {
594  ss++;
595  }
596  }
597 
598  for (j = 0; j < n_; ++j)
599  {
600  if (v_(j, k) < 0.)
601  {
602  ss++;
603  }
604  }
605 
606  if (ss > (m_ + n_) / 2)
607  {
608  for (i = 0; i < m_; ++i)
609  {
610  u_(i, k) = -u_(i, k);
611  }
612 
613  for (j = 0; j < n_; ++j)
614  {
615  v_(j, k) = -v_(j, k);
616  }
617  }
618  }
619  }
620 
621  // =============================================================================
622  // Description:
630  static double pythag(double inA, double inB) noexcept
631  {
632  const double absa = std::abs(inA);
633  const double absb = std::abs(inB);
634  return (absa > absb ? absa * std::sqrt(1.0 + utils::sqr(absb / absa)) :
635  (absb == 0.0 ? 0.0 : absb * std::sqrt(1.0 + utils::sqr(absa / absb))));
636  }
637 
638  private:
639  // ===============================Attributes====================================
640  const uint32 m_;
641  const uint32 n_;
642  NdArray<double> u_;
643  NdArray<double> v_;
644  NdArray<double> s_;
645  double eps_;
646  double tsh_;
647  };
648  } // namespace linalg
649 } // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:36
size_type size() const noexcept
Definition: NdArrayCore.hpp:4296
const_reference front() const noexcept
Definition: NdArrayCore.hpp:2772
Definition: SVDClass.hpp:48
const NdArray< double > & u() noexcept
Definition: SVDClass.hpp:75
const NdArray< double > & s() noexcept
Definition: SVDClass.hpp:97
const NdArray< double > & v() noexcept
Definition: SVDClass.hpp:86
NdArray< double > solve(const NdArray< double > &inInput, double inThresh=-1.0)
Definition: SVDClass.hpp:111
SVD(const NdArray< double > &inMatrix)
Definition: SVDClass.hpp:56
constexpr auto j
Definition: Constants.hpp:45
constexpr double c
speed of light
Definition: Constants.hpp:40
dtype f(dtype inDofN, dtype inDofD)
Definition: f.hpp:55
constexpr dtype sqr(dtype inValue) noexcept
Definition: sqr.hpp:44
Definition: Coordinate.hpp:45
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44
auto abs(dtype inValue) noexcept
Definition: abs.hpp:49
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:44
auto sqrt(dtype inValue) noexcept
Definition: sqrt.hpp:48
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:42
std::uint32_t uint32
Definition: Types.hpp:40