51 template<
typename dtype>
60 auto returnArray =
exp(inArray).template astype<double>();
61 returnArray /=
static_cast<double>(returnArray.sum().item());
66 auto returnArray =
exp(inArray).template astype<double>();
67 auto expSums = returnArray.sum(inAxis);
69 for (
uint32 row = 0; row < returnArray.shape().rows; ++row)
71 const auto rowExpSum =
static_cast<double>(expSums[row]);
73 [rowExpSum](
double& value) { value /= rowExpSum; });
80 auto returnArray =
exp(inArray.
transpose()).template astype<double>();
81 auto expSums = returnArray.sum(
Axis::COL);
83 for (
uint32 row = 0; row < returnArray.shape().rows; ++row)
85 const auto rowExpSum =
static_cast<double>(expSums[row]);
87 [rowExpSum](
double& value) { value /= rowExpSum; });
90 return returnArray.transpose();
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:36
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:37
NdArray< dtype > transpose() const
Definition: NdArrayCore.hpp:4830
NdArray< double > softmax(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: softmax.hpp:52
void for_each(InputIt first, InputIt last, UnaryFunction f)
Definition: StlAlgorithms.hpp:213
Definition: Coordinate.hpp:45
Axis
Enum To describe an axis.
Definition: Types.hpp:46
auto exp(dtype inValue) noexcept
Definition: exp.hpp:51
std::uint32_t uint32
Definition: Types.hpp:40