NumCpp  2.10.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
NdArrayBroadcast.hpp
Go to the documentation of this file.
1 #pragma once
29 
30 #include <cmath>
31 #include <utility>
32 
35 #include "NumCpp/Core/Types.hpp"
37 
38 namespace nc::broadcast
39 {
40  //============================================================================
41  // Method Description:
51  template<typename dtypeIn1, typename dtypeIn2, typename Function, typename... AdditionalFunctionArgs>
53  const NdArray<dtypeIn2>& inArray2,
54  const Function& function,
55  const AdditionalFunctionArgs&&... additionalFunctionArgs)
56  {
57  if (inArray1.shape() == inArray2.shape())
58  {
60  inArray1.cbegin(),
61  inArray1.cend(),
62  inArray2.cbegin(),
63  inArray1.begin(),
64  [&function, &additionalFunctionArgs...](const auto& inValue1, const auto& inValue2) -> dtypeIn1 {
65  return function(inValue1,
66  inValue2,
67  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
68  });
69  }
70  else if (inArray2.isscalar())
71  {
72  const auto value = inArray2.item();
74  inArray1.cbegin(),
75  inArray1.cend(),
76  inArray1.begin(),
77  [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeIn1
78  { return function(inValue, value, std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...); });
79  }
80  else if (inArray2.isflat())
81  {
82  if (inArray2.numRows() > 1 && inArray2.numRows() == inArray1.numRows())
83  {
84  for (uint32 row = 0; row < inArray1.numRows(); ++row)
85  {
86  const auto value = inArray2[row];
88  inArray1.cbegin(row),
89  inArray1.cend(row),
90  inArray1.begin(row),
91  [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeIn1 {
92  return function(inValue,
93  value,
94  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
95  });
96  }
97  }
98  else if (inArray2.numCols() > 1 && inArray2.numCols() == inArray1.numCols())
99  {
100  for (uint32 col = 0; col < inArray1.numCols(); ++col)
101  {
102  const auto value = inArray2[col];
104  inArray1.ccolbegin(col),
105  inArray1.ccolend(col),
106  inArray1.colbegin(col),
107  [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeIn1 {
108  return function(inValue,
109  value,
110  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
111  });
112  }
113  }
114  else
115  {
116  THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
117  }
118  }
119  else
120  {
121  THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
122  }
123 
124  return inArray1;
125  }
126 
127  //============================================================================
128  // Method Description:
138  template<typename dtypeOut,
139  typename dtypeIn1,
140  typename dtypeIn2,
141  typename Function,
142  typename... AdditionalFunctionArgs>
144  const NdArray<dtypeIn2>& inArray2,
145  const Function& function,
146  const AdditionalFunctionArgs&&... additionalFunctionArgs)
147  {
148  if (inArray1.shape() == inArray2.shape())
149  {
150  return [&inArray1, &inArray2, &function, &additionalFunctionArgs...]
151  {
152  NdArray<dtypeOut> returnArray(inArray1.shape());
154  inArray1.cbegin(),
155  inArray1.cend(),
156  inArray2.cbegin(),
157  returnArray.begin(),
158  [&function, &additionalFunctionArgs...](const auto& inValue1, const auto& inValue2) -> dtypeOut {
159  return function(inValue1,
160  inValue2,
161  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
162  });
163 
164  return returnArray;
165  }();
166  }
167  else if (inArray1.isscalar())
168  {
169  const auto value = inArray1.item();
170  return [&inArray2, &value, &function, &additionalFunctionArgs...]
171  {
172  NdArray<dtypeOut> returnArray(inArray2.shape());
174  inArray2.cbegin(),
175  inArray2.cend(),
176  returnArray.begin(),
177  [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeOut {
178  return function(inValue,
179  value,
180  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
181  });
182  return returnArray;
183  }();
184  }
185  else if (inArray2.isscalar())
186  {
187  const auto value = inArray2.item();
188  return [&inArray1, &value, &function, &additionalFunctionArgs...]
189  {
190  NdArray<dtypeOut> returnArray(inArray1.shape());
192  inArray1.cbegin(),
193  inArray1.cend(),
194  returnArray.begin(),
195  [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeOut {
196  return function(inValue,
197  value,
198  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
199  });
200  return returnArray;
201  }();
202  }
203  else if (inArray1.isflat() && inArray2.isflat())
204  {
205  return [&inArray1, &inArray2, &function, &additionalFunctionArgs...]
206  {
207  const auto numRows = std::max(inArray1.numRows(), inArray2.numRows());
208  const auto numCols = std::max(inArray1.numCols(), inArray2.numCols());
209  NdArray<dtypeOut> returnArray(numRows, numCols);
210  if (inArray1.numRows() > 1)
211  {
212  for (uint32 row = 0; row < inArray1.numRows(); ++row)
213  {
214  for (uint32 col = 0; col < inArray2.numCols(); ++col)
215  {
216  returnArray(row, col) =
217  function(inArray1[row],
218  inArray2[col],
219  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
220  }
221  }
222  }
223  else
224  {
225  for (uint32 row = 0; row < inArray2.numRows(); ++row)
226  {
227  for (uint32 col = 0; col < inArray1.numCols(); ++col)
228  {
229  returnArray(row, col) =
230  function(inArray1[col],
231  inArray2[row],
232  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
233  }
234  }
235  }
236  return returnArray;
237  }();
238  }
239  else if (inArray1.isflat())
240  {
241  return broadcaster<dtypeOut>(inArray2,
242  inArray1,
243  function,
244  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
245  }
246  else if (inArray2.isflat())
247  {
248  if (inArray2.numRows() > 1 && inArray2.numRows() == inArray1.numRows())
249  {
250  return [&inArray1, &inArray2, &function, &additionalFunctionArgs...]
251  {
252  NdArray<dtypeOut> returnArray(inArray1.shape());
253  for (uint32 row = 0; row < inArray1.numRows(); ++row)
254  {
255  const auto value = inArray2[row];
257  inArray1.cbegin(row),
258  inArray1.cend(row),
259  returnArray.begin(row),
260  [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeOut {
261  return function(inValue,
262  value,
263  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
264  });
265  }
266  return returnArray;
267  }();
268  }
269  else if (inArray2.numCols() > 1 && inArray2.numCols() == inArray1.numCols())
270  {
271  return broadcaster<dtypeOut>(inArray1.transpose(),
272  inArray2.transpose(),
273  function,
274  std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...)
275  .transpose();
276  }
277  else
278  {
279  THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
280  }
281  }
282  else
283  {
284  THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
285  }
286 
287  return {}; // get rid of compiler warning
288  }
289 } // namespace nc::broadcast
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition: NdArrayCore.hpp:138
const_iterator cbegin() const noexcept
Definition: NdArrayCore.hpp:1308
const_column_iterator ccolbegin() const noexcept
Definition: NdArrayCore.hpp:1385
self_type transpose() const
Definition: NdArrayCore.hpp:4775
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4402
bool isflat() const noexcept
Definition: NdArrayCore.hpp:2868
size_type numCols() const noexcept
Definition: NdArrayCore.hpp:3388
column_iterator colbegin() noexcept
Definition: NdArrayCore.hpp:1335
size_type numRows() const noexcept
Definition: NdArrayCore.hpp:3400
const_iterator cend() const noexcept
Definition: NdArrayCore.hpp:1616
iterator begin() noexcept
Definition: NdArrayCore.hpp:1258
bool isscalar() const noexcept
Definition: NdArrayCore.hpp:2879
value_type item() const
Definition: NdArrayCore.hpp:2945
const_column_iterator ccolend() const noexcept
Definition: NdArrayCore.hpp:1770
Definition: NdArrayBroadcast.hpp:39
NdArray< dtypeIn1 > & broadcaster(NdArray< dtypeIn1 > &inArray1, const NdArray< dtypeIn2 > &inArray2, const Function &function, const AdditionalFunctionArgs &&... additionalFunctionArgs)
Definition: NdArrayBroadcast.hpp:52
OutputIt transform(InputIt first, InputIt last, OutputIt destination, UnaryOperation unaryFunction)
Definition: StlAlgorithms.hpp:775
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44
std::uint32_t uint32
Definition: Types.hpp:40