NumCpp  2.5.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
cross.hpp
Go to the documentation of this file.
1 #pragma once
29 
32 #include "NumCpp/Core/Shape.hpp"
33 #include "NumCpp/Core/Types.hpp"
34 #include "NumCpp/NdArray.hpp"
35 
36 #include <string>
37 
38 namespace nc
39 {
40  //============================================================================
41  // Method Description:
52  template<typename dtype>
53  NdArray<dtype> cross(const NdArray<dtype>& inArray1, const NdArray<dtype>& inArray2, Axis inAxis = Axis::NONE)
54  {
56 
57  if (inArray1.shape() != inArray2.shape())
58  {
59  THROW_INVALID_ARGUMENT_ERROR("the input array dimensions are not consistant.");
60  }
61 
62  switch (inAxis)
63  {
64  case Axis::NONE:
65  {
66  const uint32 arraySize = inArray1.size();
67  if (arraySize != inArray2.size() || arraySize < 2 || arraySize > 3)
68  {
69  THROW_INVALID_ARGUMENT_ERROR("incompatible dimensions for cross product (dimension must be 2 or 3)");
70  }
71 
72  NdArray<dtype> in1 = inArray1.flatten();
73  NdArray<dtype> in2 = inArray2.flatten();
74 
75  switch (arraySize)
76  {
77  case 2:
78  {
79  NdArray<dtype> returnArray = { in1[0] * in2[1] - in1[1] * in2[0] };
80  return returnArray;
81  }
82  case 3:
83  {
84  dtype i = in1[1] * in2[2] - in1[2] * in2[1];
85  dtype j = -(in1[0] * in2[2] - in1[2] * in2[0]);
86  dtype k = in1[0] * in2[1] - in1[1] * in2[0];
87 
88  NdArray<dtype> returnArray = { i, j, k };
89  return returnArray;
90  }
91  default:
92  {
93  THROW_INVALID_ARGUMENT_ERROR("Unimplemented array size.");
94  return {}; // get rid of compiler warning
95  }
96  }
97  }
98  case Axis::ROW:
99  {
100  const Shape arrayShape = inArray1.shape();
101  if (arrayShape != inArray2.shape() || arrayShape.rows < 2 || arrayShape.rows > 3)
102  {
103  THROW_INVALID_ARGUMENT_ERROR("incompatible dimensions for cross product (dimension must be 2 or 3)");
104  }
105 
106  Shape returnArrayShape;
107  returnArrayShape.cols = arrayShape.cols;
108  if (arrayShape.rows == 2)
109  {
110  returnArrayShape.rows = 1;
111  }
112  else
113  {
114  returnArrayShape.rows = 3;
115  }
116 
117  NdArray<dtype> returnArray(returnArrayShape);
118  for (uint32 col = 0; col < arrayShape.cols; ++col)
119  {
120  const auto theCol = static_cast<int32>(col);
121  NdArray<dtype> vec1 = inArray1(inArray1.rSlice(), { theCol, theCol + 1 });
122  NdArray<dtype> vec2 = inArray2(inArray2.rSlice(), { theCol, theCol + 1 });
123  NdArray<dtype> vecCross = cross(vec1, vec2, Axis::NONE);
124 
125  returnArray.put({ 0, static_cast<int32>(returnArrayShape.rows) }, { theCol, theCol + 1 }, vecCross);
126  }
127 
128  return returnArray;
129  }
130  case Axis::COL:
131  {
132  const Shape arrayShape = inArray1.shape();
133  if (arrayShape != inArray2.shape() || arrayShape.cols < 2 || arrayShape.cols > 3)
134  {
135  THROW_INVALID_ARGUMENT_ERROR("incompatible dimensions for cross product (dimension must be 2 or 3)");
136  }
137 
138  Shape returnArrayShape;
139  returnArrayShape.rows = arrayShape.rows;
140  if (arrayShape.cols == 2)
141  {
142  returnArrayShape.cols = 1;
143  }
144  else
145  {
146  returnArrayShape.cols = 3;
147  }
148 
149  NdArray<dtype> returnArray(returnArrayShape);
150  for (uint32 row = 0; row < arrayShape.rows; ++row)
151  {
152  const auto theRow = static_cast<int32>(row);
153  NdArray<dtype> vec1 = inArray1({ theRow, theRow + 1 }, inArray1.cSlice());
154  NdArray<dtype> vec2 = inArray2({ theRow, theRow + 1 }, inArray2.cSlice());
155  NdArray<dtype> vecCross = cross(vec1, vec2, Axis::NONE);
156 
157  returnArray.put({ theRow, theRow + 1 }, { 0, static_cast<int32>(returnArrayShape.cols) }, vecCross);
158  }
159 
160  return returnArray;
161  }
162  default:
163  {
164  THROW_INVALID_ARGUMENT_ERROR("Unimplemented axis type.");
165  return {}; // get rid of compiler warning
166  }
167  }
168  }
169 } // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:36
#define STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype)
Definition: StaticAsserts.hpp:50
Holds 1D and 2D arrays; the main workhorse of the NumCpp library.
Definition: NdArrayCore.hpp:72
NdArray< dtype > & put(int32 inIndex, value_type inValue)
Definition: NdArrayCore.hpp:3856
size_type size() const noexcept
Definition: NdArrayCore.hpp:4497
NdArray< dtype > flatten() const
Definition: NdArrayCore.hpp:2919
Slice cSlice(int32 inStartIdx=0, uint32 inStepSize=1) const noexcept
Definition: NdArrayCore.hpp:998
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4483
Slice rSlice(int32 inStartIdx=0, uint32 inStepSize=1) const noexcept
Definition: NdArrayCore.hpp:1013
A Shape Class for NdArrays.
Definition: Core/Shape.hpp:41
uint32 rows
Definition: Core/Shape.hpp:44
uint32 cols
Definition: Core/Shape.hpp:45
constexpr auto j
Definition: Constants.hpp:45
Definition: Coordinate.hpp:45
NdArray< dtype > cross(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2, Axis inAxis=Axis::NONE)
Definition: cross.hpp:53
Axis
Enum to describe an axis.
Definition: Types.hpp:46
std::int32_t int32
Definition: Types.hpp:36
std::uint32_t uint32
Definition: Types.hpp:40