NumCpp  2.9.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
matrix_power.hpp
Go to the documentation of this file.
1 #pragma once
29 
30 #include <string>
31 
34 #include "NumCpp/Core/Shape.hpp"
35 #include "NumCpp/Core/Types.hpp"
36 #include "NumCpp/Functions/dot.hpp"
38 #include "NumCpp/NdArray.hpp"
39 
40 namespace nc
41 {
42  namespace linalg
43  {
44  //============================================================================
45  // Method Description:
61  template<typename dtype>
63  {
65 
66  const Shape inShape = inArray.shape();
67  if (inShape.rows != inShape.cols)
68  {
69  THROW_INVALID_ARGUMENT_ERROR("input matrix must be square.");
70  }
71 
72  if (inPower == 0)
73  {
74  return identity<double>(inShape.rows);
75  }
76 
77  if (inPower == 1)
78  {
79  return inArray.template astype<double>();
80  }
81 
82  if (inPower == -1)
83  {
84  return inv(inArray);
85  }
86 
87  if (inPower > 1)
88  {
89  NdArray<double> inArrayDouble = inArray.template astype<double>();
90  NdArray<double> returnArray = dot(inArrayDouble, inArrayDouble);
91  for (int16 i = 2; i < inPower; ++i)
92  {
93  returnArray = dot(returnArray, inArrayDouble);
94  }
95  return returnArray;
96  }
97 
98  NdArray<double> inverse = inv(inArray);
99  NdArray<double> returnArray = dot(inverse, inverse);
100  inPower *= -1;
101  for (int16 i = 2; i < inPower; ++i)
102  {
103  returnArray = dot(returnArray, inverse);
104  }
105  return returnArray;
106  }
107  } // namespace linalg
108 } // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:36
#define STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype)
Definition: StaticAsserts.hpp:49
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4092
A Shape Class for NdArrays.
Definition: Core/Shape.hpp:41
uint32 rows
Definition: Core/Shape.hpp:44
uint32 cols
Definition: Core/Shape.hpp:45
NdArray< double > matrix_power(const NdArray< dtype > &inArray, int16 inPower)
Definition: matrix_power.hpp:62
NdArray< double > inv(const NdArray< dtype > &inArray)
Definition: inv.hpp:56
Definition: Coordinate.hpp:45
NdArray< dtype > dot(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2)
Definition: dot.hpp:47
std::int16_t int16
Definition: Types.hpp:37