NumCpp  2.11.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
dot.hpp
Go to the documentation of this file.
1 #pragma once
29 
#include <complex>
#include <numeric>
#include <string>

#include "NumCpp/NdArray.hpp"
33 
34 namespace nc
35 {
36  //============================================================================
37  // Method Description:
46  template<typename dtype>
47  NdArray<dtype> dot(const NdArray<dtype>& inArray1, const NdArray<dtype>& inArray2)
48  {
49  return inArray1.dot(inArray2);
50  }
51 
52  //============================================================================
53  // Method Description:
65  template<typename dtype>
66  NdArray<std::complex<dtype>> dot(const NdArray<dtype>& inArray1, const NdArray<std::complex<dtype>>& inArray2)
67  {
69 
70  const auto shape1 = inArray1.shape();
71  const auto shape2 = inArray2.shape();
72 
73  if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
74  {
75  const std::complex<dtype> dotProduct =
76  std::inner_product(inArray1.cbegin(), inArray1.cend(), inArray2.cbegin(), std::complex<dtype>{ 0 });
77  NdArray<std::complex<dtype>> returnArray = { dotProduct };
78  return returnArray;
79  }
80  if (shape1.cols == shape2.rows)
81  {
82  // 2D array, use matrix multiplication
83  NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
84  auto array2T = inArray2.transpose();
85 
86  for (uint32 i = 0; i < shape1.rows; ++i)
87  {
88  for (uint32 j = 0; j < shape2.cols; ++j)
89  {
90  returnArray(i, j) = std::inner_product(array2T.cbegin(j),
91  array2T.cend(j),
92  inArray1.cbegin(i),
93  std::complex<dtype>{ 0 });
94  }
95  }
96 
97  return returnArray;
98  }
99 
100  std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
101  errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
102  errStr += " are not consistent.";
104 
105  return NdArray<std::complex<dtype>>(); // get rid of compiler warning
106  }
107 
108  //============================================================================
109  // Method Description:
121  template<typename dtype>
122  NdArray<std::complex<dtype>> dot(const NdArray<std::complex<dtype>>& inArray1, const NdArray<dtype>& inArray2)
123  {
125 
126  const auto shape1 = inArray1.shape();
127  const auto shape2 = inArray2.shape();
128 
129  if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
130  {
131  const std::complex<dtype> dotProduct =
132  std::inner_product(inArray1.cbegin(), inArray1.cend(), inArray2.cbegin(), std::complex<dtype>{ 0 });
133  NdArray<std::complex<dtype>> returnArray = { dotProduct };
134  return returnArray;
135  }
136  if (shape1.cols == shape2.rows)
137  {
138  // 2D array, use matrix multiplication
139  NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
140  auto array2T = inArray2.transpose();
141 
142  for (uint32 i = 0; i < shape1.rows; ++i)
143  {
144  for (uint32 j = 0; j < shape2.cols; ++j)
145  {
146  returnArray(i, j) = std::inner_product(array2T.cbegin(j),
147  array2T.cend(j),
148  inArray1.cbegin(i),
149  std::complex<dtype>{ 0 });
150  }
151  }
152 
153  return returnArray;
154  }
155 
156  std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
157  errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
158  errStr += " are not consistent.";
160 
161  return NdArray<std::complex<dtype>>(); // get rid of compiler warning
162  }
163 } // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:39
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition: NdArrayCore.hpp:138
const_iterator cbegin() const noexcept
Definition: NdArrayCore.hpp:1318
self_type transpose() const
Definition: NdArrayCore.hpp:4837
self_type dot(const self_type &inOtherArray) const
Definition: NdArrayCore.hpp:2672
const_iterator cend() const noexcept
Definition: NdArrayCore.hpp:1626
const Shape & shape() const noexcept
Definition: NdArrayCore.hpp:4464
constexpr auto j
Definition: Core/Constants.hpp:42
std::string num2str(dtype inNumber)
Definition: num2str.hpp:44
Definition: Cartesian.hpp:40
NdArray< dtype > dot(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2)
Definition: dot.hpp:47
std::uint32_t uint32
Definition: Types.hpp:40