Signature | Description |
---|---|
enum class loss_function_type : unsigned char { // P = Probability(Actual), Q = Probability(Model) kullback_leibler = 1, // L = ∑[P(x) * log(P(x) / Q(x))] // y = Actual, ŷ = Model mean_abs_error = 2, // L = ∑[|yi - ŷi|] / N // y = Actual, ŷ = Model mean_sqr_error = 3, // L = ∑[(yi - ŷi)2] / N // y = Actual, ŷ = Model mean_sqr_log_error = 4, // L = ∑[(log(1 + yi) - log(1 + ŷi))2] / N // y = Actual, P(yi) = Model probability prediction cross_entropy = 5, // L = -∑[y |
Different loss function types |
Signature | Description | Parameters |
---|---|---|
#include <DataFrame/DataFrameMLVisitors.h> template<typename T, typename I = unsigned long, std::size_t A = 0> struct LossFunctionVisitor; // ------------------------------------- template<typename T, typename I = unsigned long, std::size_t A = 0> using loss_v = LossFunctionVisitor<T, I, A>; |
This is a “single action visitor”, meaning it is passed the whole data vector in one call and you must use the single_act_visit() interface. This visitor implements the loss functions listed above. It needs two columns: the actual values and the predicted (model) values. The result is a single number. explicit LossFunctionVisitor(loss_function_type lft); |
T: Column data type. I: Index type. A: Memory alignment boundary for vectors. The default is the system default alignment. |
static void test_LossFunctionVisitor() { std::cout << "\nTesting LossFunctionVisitor{ } ..." << std::endl; using IntDataFrame = StdDataFrame256<int>; IntDataFrame df; StlVecType<int> idxvec = { 1, 2, 3, 10, 5, 7, 8, 12, 9, 12, 10, 13, 10, 15, 14 }; StlVecType<double> actual = { 1.0, 15.0, 14.0, 2.0, 1.0, 12.0, 11.0, 8.0, 7.0, 4.0, 5.0, 4.0, 3.0, 9.0, 10.0 }; StlVecType<double> bin_actual = { 1, 0, 1, 1, 1.0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1 }; StlVecType<double> model = { 1.01, 14.908, 14.03, 1.0, 1.5, 12.0, 19.75, 8.6, 7.1, 4.8, 4.4, 4.0, 3.4, 9.0, 9.098 }; StlVecType<double> model_prob = { 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667, 0.06667 }; StlVecType<std::string> strvec = { "zz", "bb", "cc", "ww", "ee", "ff", "gg", "hh", "ii", "jj", "kk", "ll", "mm", "nn", "oo" }; df.load_data(std::move(idxvec), std::make_pair("actual", actual), std::make_pair("binary actual", bin_actual), std::make_pair("model", model), std::make_pair("model_prob", model_prob), std::make_pair("str_col", strvec)); loss_v<double, int> loss { loss_function_type::kullback_leibler }; df.single_act_visit<double, double>("actual", "model_prob", loss); assert(std::abs(loss.get_result() - 517.6888) < 0.0001); loss_v<double, int> loss2 { loss_function_type::mean_abs_error }; df.single_act_visit<double, double>("actual", "model", loss2); assert(std::abs(loss2.get_result() - 0.9189) < 0.0001); loss_v<double, int> loss3 { loss_function_type::mean_sqr_error }; df.single_act_visit<double, double>("actual", "model", loss3); assert(std::abs(loss3.get_result() - 5.3444) < 0.0001); loss_v<double, int> loss4 { loss_function_type::mean_sqr_log_error }; df.single_act_visit<double, double>("actual", "model", loss4); assert(std::abs(loss4.get_result() - 0.0379) < 0.0001); loss_v<double, int> loss5 { loss_function_type::categorical_hinge }; df.single_act_visit<double, double>("actual", "model", loss5); assert(std::abs(loss5.get_result() - 
0) < 0.0001); loss_v<double, int> loss6 { loss_function_type::cosine_similarity }; df.single_act_visit<double, double>("actual", "model", loss6); assert(std::abs(loss6.get_result() - 0.9722) < 0.0001); loss_v<double, int> loss7 { loss_function_type::log_cosh }; df.single_act_visit<double, double>("actual", "model", loss7); assert(std::abs(loss7.get_result() - 0.646) < 0.0001); loss_v<double, int> loss8 { loss_function_type::binary_cross_entropy }; df.single_act_visit<double, double>("binary actual", "model_prob", loss8); assert(std::abs(loss8.get_result() - 1.5972) < 0.0001); loss_v<double, int> loss9 { loss_function_type::cross_entropy }; df.single_act_visit<double, double>("actual", "model_prob", loss9); assert(std::abs(loss9.get_result() - 19.1365) < 0.0001); }