NumCpp  2.4.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
pivotLU_decomposition.hpp
Go to the documentation of this file.
1 #pragma once
33 
36 #include "NumCpp/Core/Types.hpp"
37 #include "NumCpp/Functions/eye.hpp"
39 #include "NumCpp/NdArray.hpp"
41 
42 #include <cmath>
43 #include <tuple>
44 
45 namespace nc
46 {
47  namespace linalg
48  {
49  //============================================================================
50  // Method Description:
57  template<typename dtype>
58  std::tuple<NdArray<double>, NdArray<double>, NdArray<double> > pivotLU_decomposition(const NdArray<dtype>& inMatrix)
59  {
61 
62  const auto shape = inMatrix.shape();
63 
64  if(!shape.issquare())
65  {
66  THROW_RUNTIME_ERROR("Input matrix should be square.");
67  }
68 
69  NdArray<double> lMatrix = zeros_like<double>(inMatrix);
70  NdArray<double> uMatrix = inMatrix.template astype<double>();
71  NdArray<double> pMatrix = eye<double>(shape.rows);
72 
73  for(uint32 k = 0; k < shape.rows; ++k)
74  {
75  double max = 0.0;
76  uint32 pk = 0;
77  for(uint32 i = k; i < shape.rows; ++i)
78  {
79  double s = 0.0;
80  for(uint32 j = k; j < shape.cols; ++j)
81  {
82  s += std::fabs(uMatrix(i, j));
83  }
84 
85  const double q = std::fabs(uMatrix(i, k)) / s;
86  if(q > max)
87  {
88  max = q;
89  pk = i;
90  }
91  }
92 
93  if (utils::essentiallyEqual(max, double{0.0}))
94  {
95  THROW_RUNTIME_ERROR("Division by 0.");
96  }
97 
98  if(pk != k)
99  {
100  for(uint32 j = 0; j < shape.cols; ++j)
101  {
102  std::swap(pMatrix(k, j), pMatrix(pk, j));
103  std::swap(lMatrix(k, j), lMatrix(pk, j));
104  std::swap(uMatrix(k, j), uMatrix(pk, j));
105  }
106  }
107 
108  for(uint32 i = k+1; i < shape.rows; ++i)
109  {
110  lMatrix(i, k) = uMatrix(i, k) / uMatrix(k, k);
111 
112  for(uint32 j = k; j < shape.cols; ++j)
113  {
114  uMatrix(i, j) = uMatrix(i, j) - lMatrix(i, k) * uMatrix(k, j);
115  }
116  }
117  }
118 
119  for(uint32 k = 0; k < shape.rows; ++k)
120  {
121  lMatrix(k, k) = 1.0;
122  }
123 
124  return std::make_tuple(lMatrix, uMatrix, pMatrix);
125  }
126  } // namespace linalg
127 } // namespace nc
StaticAsserts.hpp
nc::Shape::issquare
bool issquare() const noexcept
Definition: Core/Shape.hpp:123
nc::NdArray::shape
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4356
Error.hpp
STATIC_ASSERT_ARITHMETIC
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:37
nc::shape
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:44
nc::utils::essentiallyEqual
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:52
nc::NdArray< double >
nc::constants::j
constexpr auto j
Definition: Constants.hpp:45
nc::uint32
std::uint32_t uint32
Definition: Types.hpp:40
zeros_like.hpp
NdArray.hpp
nc::Shape::cols
uint32 cols
Definition: Core/Shape.hpp:45
THROW_RUNTIME_ERROR
#define THROW_RUNTIME_ERROR(msg)
Definition: Error.hpp:37
nc
Definition: Coordinate.hpp:44
nc::swap
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
nc::Shape::rows
uint32 rows
Definition: Core/Shape.hpp:44
essentiallyEqual.hpp
nc::max
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:45
eye.hpp
Types.hpp
nc::linalg::pivotLU_decomposition
std::tuple< NdArray< double >, NdArray< double >, NdArray< double > > pivotLU_decomposition(const NdArray< dtype > &inMatrix)
Definition: pivotLU_decomposition.hpp:58