NumCpp  2.11.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
pivotLU_decomposition.hpp
Go to the documentation of this file.
1 #pragma once
34 
#include <cmath>
#include <tuple>
#include <utility>

#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Functions/eye.hpp"
#include "NumCpp/NdArray.hpp"
45 
46 namespace nc::linalg
47 {
48  //============================================================================
49  // Method Description:
56  template<typename dtype>
57  std::tuple<NdArray<double>, NdArray<double>, NdArray<double>> pivotLU_decomposition(const NdArray<dtype>& inMatrix)
58  {
60 
61  const auto shape = inMatrix.shape();
62 
63  if (!shape.issquare())
64  {
65  THROW_RUNTIME_ERROR("Input matrix should be square.");
66  }
67 
68  NdArray<double> lMatrix = zeros_like<double>(inMatrix);
69  NdArray<double> uMatrix = inMatrix.template astype<double>();
70  NdArray<double> pMatrix = eye<double>(shape.rows);
71 
72  for (uint32 k = 0; k < shape.rows; ++k)
73  {
74  double max = 0.;
75  uint32 pk = 0;
76  for (uint32 i = k; i < shape.rows; ++i)
77  {
78  double s = 0.;
79  for (uint32 j = k; j < shape.cols; ++j)
80  {
81  s += std::fabs(uMatrix(i, j));
82  }
83 
84  const double q = std::fabs(uMatrix(i, k)) / s;
85  if (q > max)
86  {
87  max = q;
88  pk = i;
89  }
90  }
91 
92  if (utils::essentiallyEqual(max, double{ 0. }))
93  {
94  THROW_RUNTIME_ERROR("Division by 0.");
95  }
96 
97  if (pk != k)
98  {
99  for (uint32 j = 0; j < shape.cols; ++j)
100  {
101  std::swap(pMatrix(k, j), pMatrix(pk, j));
102  std::swap(lMatrix(k, j), lMatrix(pk, j));
103  std::swap(uMatrix(k, j), uMatrix(pk, j));
104  }
105  }
106 
107  for (uint32 i = k + 1; i < shape.rows; ++i)
108  {
109  lMatrix(i, k) = uMatrix(i, k) / uMatrix(k, k);
110 
111  for (uint32 j = k; j < shape.cols; ++j)
112  {
113  uMatrix(i, j) = uMatrix(i, j) - lMatrix(i, k) * uMatrix(k, j);
114  }
115  }
116  }
117 
118  for (uint32 k = 0; k < shape.rows; ++k)
119  {
120  lMatrix(k, k) = 1.;
121  }
122 
123  return std::make_tuple(lMatrix, uMatrix, pMatrix);
124  }
125 } // namespace nc::linalg
#define THROW_RUNTIME_ERROR(msg)
Definition: Error.hpp:40
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:39
const Shape & shape() const noexcept
Definition: NdArrayCore.hpp:4464
uint32 rows
Definition: Core/Shape.hpp:44
bool issquare() const noexcept
Definition: Core/Shape.hpp:125
uint32 cols
Definition: Core/Shape.hpp:45
constexpr auto j
Definition: Core/Constants.hpp:42
Definition: cholesky.hpp:41
std::tuple< NdArray< double >, NdArray< double >, NdArray< double > > pivotLU_decomposition(const NdArray< dtype > &inMatrix)
Definition: pivotLU_decomposition.hpp:57
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:48
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:42
std::uint32_t uint32
Definition: Types.hpp:40