NumCpp 2.4.0
A Templatized Header-Only C++ Implementation of the Python NumPy Library
dot.hpp
Go to the documentation of this file.
1 #pragma once
29 
#include "NumCpp/NdArray.hpp"

#include <complex>
#include <numeric>
#include <string>
33 
34 namespace nc
35 {
36  //============================================================================
37  // Method Description:
46  template<typename dtype>
47  NdArray<dtype> dot(const NdArray<dtype>& inArray1, const NdArray<dtype>& inArray2)
48  {
49  return inArray1.dot(inArray2);
50  }
51 
52  //============================================================================
53  // Method Description:
65  template<typename dtype>
66  NdArray<std::complex<dtype>> dot(const NdArray<dtype>& inArray1, const NdArray<std::complex<dtype>>& inArray2)
67  {
69 
70  const auto shape1 = inArray1.shape();
71  const auto shape2 = inArray2.shape();
72 
73  if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
74  {
75  const std::complex<dtype> dotProduct = std::inner_product(inArray1.cbegin(), inArray1.cend(),
76  inArray2.cbegin(), std::complex<dtype>{0});
77  NdArray<std::complex<dtype>> returnArray = { dotProduct };
78  return returnArray;
79  }
80  if (shape1.cols == shape2.rows)
81  {
82  // 2D array, use matrix multiplication
83  NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
84  auto array2T = inArray2.transpose();
85 
86  for (uint32 i = 0; i < shape1.rows; ++i)
87  {
88  for (uint32 j = 0; j < shape2.cols; ++j)
89  {
90  returnArray(i, j) = std::inner_product(array2T.cbegin(j), array2T.cend(j),
91  inArray1.cbegin(i), std::complex<dtype>{0});
92  }
93  }
94 
95  return returnArray;
96  }
97 
98  std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
99  errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
100  errStr += " are not consistent.";
102 
103  return NdArray<std::complex<dtype>>(); // get rid of compiler warning
104  }
105 
106  //============================================================================
107  // Method Description:
119  template<typename dtype>
120  NdArray<std::complex<dtype>> dot(const NdArray<std::complex<dtype>>& inArray1, const NdArray<dtype>& inArray2)
121  {
123 
124  const auto shape1 = inArray1.shape();
125  const auto shape2 = inArray2.shape();
126 
127  if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
128  {
129  const std::complex<dtype> dotProduct = std::inner_product(inArray1.cbegin(), inArray1.cend(),
130  inArray2.cbegin(), std::complex<dtype>{0});
131  NdArray<std::complex<dtype>> returnArray = { dotProduct };
132  return returnArray;
133  }
134  if (shape1.cols == shape2.rows)
135  {
136  // 2D array, use matrix multiplication
137  NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
138  auto array2T = inArray2.transpose();
139 
140  for (uint32 i = 0; i < shape1.rows; ++i)
141  {
142  for (uint32 j = 0; j < shape2.cols; ++j)
143  {
144  returnArray(i, j) = std::inner_product(array2T.cbegin(j), array2T.cend(j),
145  inArray1.cbegin(i), std::complex<dtype>{0});
146  }
147  }
148 
149  return returnArray;
150  }
151 
152  std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
153  errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
154  errStr += " are not consistent.";
156 
157  return NdArray<std::complex<dtype>>(); // get rid of compiler warning
158  }
159 } // namespace nc
nc::NdArray::shape
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4356
STATIC_ASSERT_ARITHMETIC
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:37
nc::NdArray::dot
NdArray< dtype > dot(const NdArray< dtype > &inOtherArray) const
Definition: NdArrayCore.hpp:2661
nc::utils::num2str
std::string num2str(dtype inNumber)
Definition: num2str.hpp:46
nc::NdArray::transpose
NdArray< dtype > transpose() const
Definition: NdArrayCore.hpp:4652
nc::dot
NdArray< dtype > dot(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2)
Definition: dot.hpp:47
nc::NdArray< dtype >
nc::constants::j
constexpr auto j
Definition: Constants.hpp:45
nc::uint32
std::uint32_t uint32
Definition: Types.hpp:40
NdArray.hpp
nc::NdArray::cend
const_iterator cend() const noexcept
Definition: NdArrayCore.hpp:1487
nc
Definition: Coordinate.hpp:44
THROW_INVALID_ARGUMENT_ERROR
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:36
nc::NdArray::cbegin
const_iterator cbegin() const noexcept
Definition: NdArrayCore.hpp:1143