NumCpp  2.9.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
pivotLU_decomposition.hpp
Go to the documentation of this file.
1 #pragma once
34 
#include <cmath>
#include <tuple>
#include <utility>

#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Functions/eye.hpp"
#include "NumCpp/NdArray.hpp"
45 
46 namespace nc
47 {
48  namespace linalg
49  {
50  //============================================================================
51  // Method Description:
58  template<typename dtype>
59  std::tuple<NdArray<double>, NdArray<double>, NdArray<double>>
61  {
63 
64  const auto shape = inMatrix.shape();
65 
66  if (!shape.issquare())
67  {
68  THROW_RUNTIME_ERROR("Input matrix should be square.");
69  }
70 
71  NdArray<double> lMatrix = zeros_like<double>(inMatrix);
72  NdArray<double> uMatrix = inMatrix.template astype<double>();
73  NdArray<double> pMatrix = eye<double>(shape.rows);
74 
75  for (uint32 k = 0; k < shape.rows; ++k)
76  {
77  double max = 0.;
78  uint32 pk = 0;
79  for (uint32 i = k; i < shape.rows; ++i)
80  {
81  double s = 0.;
82  for (uint32 j = k; j < shape.cols; ++j)
83  {
84  s += std::fabs(uMatrix(i, j));
85  }
86 
87  const double q = std::fabs(uMatrix(i, k)) / s;
88  if (q > max)
89  {
90  max = q;
91  pk = i;
92  }
93  }
94 
95  if (utils::essentiallyEqual(max, double{ 0. }))
96  {
97  THROW_RUNTIME_ERROR("Division by 0.");
98  }
99 
100  if (pk != k)
101  {
102  for (uint32 j = 0; j < shape.cols; ++j)
103  {
104  std::swap(pMatrix(k, j), pMatrix(pk, j));
105  std::swap(lMatrix(k, j), lMatrix(pk, j));
106  std::swap(uMatrix(k, j), uMatrix(pk, j));
107  }
108  }
109 
110  for (uint32 i = k + 1; i < shape.rows; ++i)
111  {
112  lMatrix(i, k) = uMatrix(i, k) / uMatrix(k, k);
113 
114  for (uint32 j = k; j < shape.cols; ++j)
115  {
116  uMatrix(i, j) = uMatrix(i, j) - lMatrix(i, k) * uMatrix(k, j);
117  }
118  }
119  }
120 
121  for (uint32 k = 0; k < shape.rows; ++k)
122  {
123  lMatrix(k, k) = 1.;
124  }
125 
126  return std::make_tuple(lMatrix, uMatrix, pMatrix);
127  }
128  } // namespace linalg
129 } // namespace nc
#define THROW_RUNTIME_ERROR(msg)
Definition: Error.hpp:38
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:37
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition: NdArrayCore.hpp:72
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4092
uint32 rows
Definition: Core/Shape.hpp:44
bool issquare() const noexcept
Definition: Core/Shape.hpp:125
uint32 cols
Definition: Core/Shape.hpp:45
constexpr auto j
Definition: Constants.hpp:45
std::tuple< NdArray< double >, NdArray< double >, NdArray< double > > pivotLU_decomposition(const NdArray< dtype > &inMatrix)
Definition: pivotLU_decomposition.hpp:60
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:51
Definition: Coordinate.hpp:45
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:42
std::uint32_t uint32
Definition: Types.hpp:40