NumCpp  2.5.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
SVDClass.hpp
Go to the documentation of this file.
1 #pragma once
30 
32 #include "NumCpp/Core/Types.hpp"
33 #include "NumCpp/NdArray.hpp"
34 
35 #include <cmath>
36 #include <limits>
37 #include <string>
38 
39 namespace nc
40 {
41  namespace linalg
42  {
43  // =============================================================================
44  // Class Description:
47  class SVD
48  {
49  public:
50  // =============================================================================
51  // Description:
57  explicit SVD(const NdArray<double>& inMatrix) :
58  m_(inMatrix.shape().rows),
59  n_(inMatrix.shape().cols),
60  u_(inMatrix),
61  v_(n_, n_),
62  s_(1, n_),
63  eps_(std::numeric_limits<double>::epsilon())
64  {
65  decompose();
66  reorder();
67  tsh_ = 0.5 * std::sqrt(m_ + n_ + 1.) * s_.front() * eps_;
68  }
69 
70  // =============================================================================
71  // Description:
77  const NdArray<double>& u() noexcept
78  {
79  return u_;
80  }
81 
82  // =============================================================================
83  // Description:
89  const NdArray<double>& v() noexcept
90  {
91  return v_;
92  }
93 
94  // =============================================================================
95  // Description:
101  const NdArray<double>& s() noexcept
102  {
103  return s_;
104  }
105 
106  // =============================================================================
107  // Description:
116  NdArray<double> solve(const NdArray<double>& inInput, double inThresh = -1.0)
117  {
118  double ss = 0.0;
119 
120  if (inInput.size() != m_)
121  {
122  THROW_INVALID_ARGUMENT_ERROR("bad sizes.");
123  }
124 
125  NdArray<double> returnArray(1, n_);
126 
127  NdArray<double> tmp(1, n_);
128 
129  tsh_ = (inThresh >= 0. ? inThresh : 0.5 * sqrt(m_ + n_ + 1.) * s_.front() * eps_);
130 
131  for (uint32 j = 0; j < n_; j++)
132  {
133  ss = 0.0;
134  if (s_[j] > tsh_)
135  {
136  for (uint32 i = 0; i < m_; i++)
137  {
138  ss += u_(i, j) * inInput[i];
139  }
140  ss /= s_[j];
141  }
142  tmp[j] = ss;
143  }
144 
145  for (uint32 j = 0; j < n_; j++)
146  {
147  ss = 0.0;
148  for (uint32 jj = 0; jj < n_; jj++)
149  {
150  ss += v_(j, jj) * tmp[jj];
151  }
152 
153  returnArray[j] = ss;
154  }
155 
156  return returnArray;
157  }
158 
159  private:
160  // =============================================================================
161  // Description:
170  static double SIGN(double inA, double inB) noexcept
171  {
172  return inB >= 0 ? (inA >= 0 ? inA : -inA) : (inA >= 0 ? -inA : inA);
173  }
174 
175  // =============================================================================
176  // Description:
179  void decompose()
180  {
181  bool flag = true;
182  uint32 i = 0;
183  uint32 its = 0;
184  uint32 j = 0;
185  uint32 jj = 0;
186  uint32 k = 0;
187  uint32 l = 0;
188  uint32 nm = 0;
189 
190  double anorm = 0.0;
191  double c = 0.0;
192  double f = 0.0;
193  double g = 0.0;
194  double h = 0.0;
195  double ss = 0.0;
196  double scale = 0.0;
197  double x = 0.0;
198  double y = 0.0;
199  double z = 0.0;
200 
201  NdArray<double> rv1(n_, 1);
202 
203  for (i = 0; i < n_; ++i)
204  {
205  l = i + 2;
206  rv1[i] = scale * g;
207  g = ss = scale = 0.0;
208 
209  if (i < m_)
210  {
211  for (k = i; k < m_; ++k)
212  {
213  scale += std::abs(u_(k, i));
214  }
215 
216  if (scale != 0.0)
217  {
218  for (k = i; k < m_; ++k)
219  {
220  u_(k, i) /= scale;
221  ss += u_(k, i) * u_(k, i);
222  }
223 
224  f = u_(i, i);
225  g = -SIGN(std::sqrt(ss), f);
226  h = f * g - ss;
227  u_(i, i) = f - g;
228 
229  for (j = l - 1; j < n_; ++j)
230  {
231  for (ss = 0.0, k = i; k < m_; ++k)
232  {
233  ss += u_(k, i) * u_(k, j);
234  }
235 
236  f = ss / h;
237 
238  for (k = i; k < m_; ++k)
239  {
240  u_(k, j) += f * u_(k, i);
241  }
242  }
243 
244  for (k = i; k < m_; ++k)
245  {
246  u_(k, i) *= scale;
247  }
248  }
249  }
250 
251  s_[i] = scale * g;
252  g = ss = scale = 0.0;
253 
254  if (i + 1 <= m_ && i + 1 != n_)
255  {
256  for (k = l - 1; k < n_; ++k)
257  {
258  scale += std::abs(u_(i, k));
259  }
260 
261  if (scale != 0.0)
262  {
263  for (k = l - 1; k < n_; ++k)
264  {
265  u_(i, k) /= scale;
266  ss += u_(i, k) * u_(i, k);
267  }
268 
269  f = u_(i, l - 1);
270  g = -SIGN(std::sqrt(ss), f);
271  h = f * g - ss;
272  u_(i, l - 1) = f - g;
273 
274  for (k = l - 1; k < n_; ++k)
275  {
276  rv1[k] = u_(i, k) / h;
277  }
278 
279  for (j = l - 1; j < m_; ++j)
280  {
281  for (ss = 0.0, k = l - 1; k < n_; ++k)
282  {
283  ss += u_(j, k) * u_(i, k);
284  }
285 
286  for (k = l - 1; k < n_; ++k)
287  {
288  u_(j, k) += ss * rv1[k];
289  }
290  }
291 
292  for (k = l - 1; k < n_; ++k)
293  {
294  u_(i, k) *= scale;
295  }
296  }
297  }
298 
299  anorm = std::max(anorm, (std::abs(s_[i]) + std::abs(rv1[i])));
300  }
301 
302  for (i = n_ - 1; i != static_cast<uint32>(-1); --i)
303  {
304  if (i < n_ - 1)
305  {
306  if (g != 0.0)
307  {
308  for (j = l; j < n_; ++j)
309  {
310  v_(j, i) = (u_(i, j) / u_(i, l)) / g;
311  }
312 
313  for (j = l; j < n_; ++j)
314  {
315  for (ss = 0.0, k = l; k < n_; ++k)
316  {
317  ss += u_(i, k) * v_(k, j);
318  }
319 
320  for (k = l; k < n_; ++k)
321  {
322  v_(k, j) += ss * v_(k, i);
323  }
324  }
325  }
326 
327  for (j = l; j < n_; ++j)
328  {
329  v_(i, j) = v_(j, i) = 0.0;
330  }
331  }
332 
333  v_(i, i) = 1.0;
334  g = rv1[i];
335  l = i;
336  }
337 
338  for (i = std::min(m_, n_) - 1; i != static_cast<uint32>(-1); --i)
339  {
340  l = i + 1;
341  g = s_[i];
342 
343  for (j = l; j < n_; ++j)
344  {
345  u_(i, j) = 0.0;
346  }
347 
348  if (g != 0.0)
349  {
350  g = 1.0 / g;
351 
352  for (j = l; j < n_; ++j)
353  {
354  for (ss = 0.0, k = l; k < m_; ++k)
355  {
356  ss += u_(k, i) * u_(k, j);
357  }
358 
359  f = (ss / u_(i, i)) * g;
360 
361  for (k = i; k < m_; ++k)
362  {
363  u_(k, j) += f * u_(k, i);
364  }
365  }
366 
367  for (j = i; j < m_; ++j)
368  {
369  u_(j, i) *= g;
370  }
371 
372  }
373  else
374  {
375  for (j = i; j < m_; ++j)
376  {
377  u_(j, i) = 0.0;
378  }
379  }
380 
381  ++u_(i, i);
382  }
383 
384  for (k = n_ - 1; k != static_cast<uint32>(-1); --k)
385  {
386  for (its = 0; its < 30; ++its)
387  {
388  flag = true;
389  for (l = k; l != static_cast<uint32>(-1); --l)
390  {
391  nm = l - 1;
392  if (l == 0 || std::abs(rv1[l]) <= eps_ * anorm)
393  {
394  flag = false;
395  break;
396  }
397 
398  if (std::abs(s_[nm]) <= eps_ * anorm)
399  {
400  break;
401  }
402  }
403 
404  if (flag)
405  {
406  c = 0.0;
407  ss = 1.0;
408  for (i = l; i < k + 1; ++i)
409  {
410  f = ss * rv1[i];
411  rv1[i] = c * rv1[i];
412 
413  if (std::abs(f) <= eps_ * anorm)
414  {
415  break;
416  }
417 
418  g = s_[i];
419  h = pythag(f, g);
420  s_[i] = h;
421  h = 1.0 / h;
422  c = g * h;
423  ss = -f * h;
424 
425  for (j = 0; j < m_; ++j)
426  {
427  y = u_(j, nm);
428  z = u_(j, i);
429  u_(j, nm) = y * c + z * ss;
430  u_(j, i) = z * c - y * ss;
431  }
432  }
433  }
434 
435  z = s_[k];
436  if (l == k)
437  {
438  if (z < 0.0)
439  {
440  s_[k] = -z;
441  for (j = 0; j < n_; ++j)
442  {
443  v_(j, k) = -v_(j, k);
444  }
445  }
446  break;
447  }
448 
449  if (its == 29)
450  {
451  THROW_INVALID_ARGUMENT_ERROR("no convergence in 30 svdcmp iterations");
452  }
453 
454  x = s_[l];
455  nm = k - 1;
456  y = s_[nm];
457  g = rv1[nm];
458  h = rv1[k];
459  f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2.0 * h * y);
460  g = pythag(f, 1.0);
461  f = ((x - z) * (x + z) + h * ((y / (f + SIGN(g, f))) - h)) / x;
462  c = ss = 1.0;
463 
464  for (j = l; j <= nm; j++)
465  {
466  i = j + 1;
467  g = rv1[i];
468  y = s_[i];
469  h = ss * g;
470  g = c * g;
471  z = pythag(f, h);
472  rv1[j] = z;
473  c = f / z;
474  ss = h / z;
475  f = x * c + g * ss;
476  g = g * c - x * ss;
477  h = y * ss;
478  y *= c;
479 
480  for (jj = 0; jj < n_; ++jj)
481  {
482  x = v_(jj, j);
483  z = v_(jj, i);
484  v_(jj, j) = x * c + z * ss;
485  v_(jj, i) = z * c - x * ss;
486  }
487 
488  z = pythag(f, h);
489  s_[j] = z;
490 
491  if (z != 0.0)
492  {
493  z = 1.0 / z;
494  c = f * z;
495  ss = h * z;
496  }
497 
498  f = c * g + ss * y;
499  x = c * y - ss * g;
500 
501  for (jj = 0; jj < m_; ++jj)
502  {
503  y = u_(jj, j);
504  z = u_(jj, i);
505  u_(jj, j) = y * c + z * ss;
506  u_(jj, i) = z * c - y * ss;
507  }
508  }
509  rv1[l] = 0.0;
510  rv1[k] = f;
511  s_[k] = x;
512  }
513  }
514  }
515 
516  // =============================================================================
517  // Description:
520  void reorder()
521  {
522  uint32 i = 0;
523  uint32 j = 0;
524  uint32 k = 0;
525  uint32 ss = 0;
526  uint32 inc = 1;
527 
528  double sw = 0.0;
529  NdArray<double> su(m_, 1);
530  NdArray<double> sv(n_, 1);
531 
532  do
533  {
534  inc *= 3;
535  ++inc;
536  } while (inc <= n_);
537 
538  do
539  {
540  inc /= 3;
541  for (i = inc; i < n_; ++i)
542  {
543  sw = s_[i];
544  for (k = 0; k < m_; ++k)
545  {
546  su[k] = u_(k, i);
547  }
548 
549  for (k = 0; k < n_; ++k)
550  {
551  sv[k] = v_(k, i);
552  }
553 
554  j = i;
555  while (s_[j - inc] < sw)
556  {
557  s_[j] = s_[j - inc];
558 
559  for (k = 0; k < m_; ++k)
560  {
561  u_(k, j) = u_(k, j - inc);
562  }
563 
564  for (k = 0; k < n_; ++k)
565  {
566  v_(k, j) = v_(k, j - inc);
567  }
568 
569  j -= inc;
570 
571  if (j < inc)
572  {
573  break;
574  }
575  }
576 
577  s_[j] = sw;
578 
579  for (k = 0; k < m_; ++k)
580  {
581  u_(k, j) = su[k];
582  }
583 
584  for (k = 0; k < n_; ++k)
585  {
586  v_(k, j) = sv[k];
587  }
588 
589  }
590  } while (inc > 1);
591 
592  for (k = 0; k < n_; ++k)
593  {
594  ss = 0;
595 
596  for (i = 0; i < m_; i++)
597  {
598  if (u_(i, k) < 0.)
599  {
600  ss++;
601  }
602  }
603 
604  for (j = 0; j < n_; ++j)
605  {
606  if (v_(j, k) < 0.)
607  {
608  ss++;
609  }
610  }
611 
612  if (ss > (m_ + n_) / 2)
613  {
614  for (i = 0; i < m_; ++i)
615  {
616  u_(i, k) = -u_(i, k);
617  }
618 
619  for (j = 0; j < n_; ++j)
620  {
621  v_(j, k) = -v_(j, k);
622  }
623  }
624  }
625  }
626 
627  // =============================================================================
628  // Description:
637  static double pythag(double inA, double inB) noexcept
638  {
639  const double absa = std::abs(inA);
640  const double absb = std::abs(inB);
641  return (absa > absb ? absa * std::sqrt(1.0 + utils::sqr(absb / absa)) :
642  (absb == 0.0 ? 0.0 : absb * std::sqrt(1.0 + utils::sqr(absa / absb))));
643  }
644 
645  private:
646  // ===============================Attributes====================================
647  const uint32 m_;
648  const uint32 n_;
649  NdArray<double> u_;
650  NdArray<double> v_;
651  NdArray<double> s_;
652  double eps_;
653  double tsh_;
654  };
655  } // namespace linalg
656 } // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:36
size_type size() const noexcept
Definition: NdArrayCore.hpp:4497
const_reference front() const noexcept
Definition: NdArrayCore.hpp:2933
Definition: SVDClass.hpp:48
const NdArray< double > & u() noexcept
Definition: SVDClass.hpp:77
const NdArray< double > & s() noexcept
Definition: SVDClass.hpp:101
const NdArray< double > & v() noexcept
Definition: SVDClass.hpp:89
NdArray< double > solve(const NdArray< double > &inInput, double inThresh=-1.0)
Definition: SVDClass.hpp:116
SVD(const NdArray< double > &inMatrix)
Definition: SVDClass.hpp:57
constexpr auto j
Definition: Constants.hpp:45
constexpr double c
speed of light
Definition: Constants.hpp:40
dtype f(dtype inDofN, dtype inDofD)
Definition: f.hpp:56
constexpr dtype sqr(dtype inValue) noexcept
Definition: sqr.hpp:44
Definition: Coordinate.hpp:45
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:45
auto abs(dtype inValue) noexcept
Definition: abs.hpp:51
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:45
auto sqrt(dtype inValue) noexcept
Definition: sqrt.hpp:50
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:44
std::uint32_t uint32
Definition: Types.hpp:40