Coverage Report

Created: 2023-12-13 14:58

src/zserio/ReflectableUtil.h
Line
Count
Source (jump to first uncovered line)
1
#ifndef ZSERIO_REFLECTABLE_UTIL_H_INC
2
#define ZSERIO_REFLECTABLE_UTIL_H_INC
3
4
#include <algorithm>
5
#include <functional>
6
#include <cmath>
7
#include <limits>
8
9
#include "zserio/CppRuntimeException.h"
10
#include "zserio/IReflectable.h"
11
#include "zserio/ITypeInfo.h"
12
#include "zserio/StringView.h"
13
#include "zserio/Traits.h"
14
#include "zserio/TypeInfoUtil.h"
15
16
namespace zserio
17
{
18
19
namespace detail
20
{
21
22
template <typename T>
23
struct gets_value_by_value : std::integral_constant<bool,
24
        std::is_arithmetic<T>::value ||
25
        std::is_same<StringView, T>::value ||
26
        std::is_enum<T>::value ||
27
        is_bitmask<T>::value>
28
{};
29
30
} // namespace detail
31
32
/**
33
 * Utilities on zserio reflectable interface.
34
 */
35
class ReflectableUtil
36
{
37
public:
38
    /**
39
     * Makes "deep" comparison of given reflectables.
40
     *
41
     * \note Floating point values are compared using "almost equal" strategy.
42
     *
43
     * \param lhs Left-hand side reflectable.
44
     * \param rhs Right-hand side reflectable.
45
     *
46
     * \return True when the reflectables are equal, false otherwise.
47
     */
48
    template <typename ALLOC = std::allocator<uint8_t>>
49
    static bool equal(const IBasicReflectableConstPtr<ALLOC>& lhs,
50
            const IBasicReflectableConstPtr<ALLOC>& rhs);
51
52
    /**
53
     * Gets native value from the given reflectable.
54
     *
55
     * Overload for types where the value is returned by value:
56
     *
57
     * - arithmetic types, enums, bitmasks and strings (via string view).
58
     *
59
     * \param reflectable Reflectable to use for value extraction.
60
     *
61
     * \return Value of the type T.
62
     */
63
    template <typename T, typename ALLOC = std::allocator<uint8_t>,
64
            typename std::enable_if<detail::gets_value_by_value<T>::value, int>::type = 0>
65
    static T getValue(const IBasicReflectableConstPtr<ALLOC>& reflectable, const ALLOC& allocator = ALLOC())
66
32
    {
67
32
        return reflectable->getAnyValue(allocator).template get<T>();
68
32
    }
69
70
    /**
71
     * Gets constant reference to the native value from the given constant reflectable.
72
     *
73
     * Overload for types where the value is returned by const reference:
74
     *
75
     * - compound, bit buffers and arrays.
76
     *
77
     * \param reflectable Constant reflectable to use for value extraction.
78
     *
79
     * \return Constant reference to the value of the type T.
80
     *
81
     * \throw CppRuntimeException When wrong type is requested ("Bad type in AnyHolder").
82
     */
83
    template <typename T, typename ALLOC = std::allocator<uint8_t>,
84
            typename std::enable_if<!detail::gets_value_by_value<T>::value, int>::type = 0>
85
    static const T& getValue(const IBasicReflectableConstPtr<ALLOC>& reflectable,
86
            const ALLOC& allocator = ALLOC())
87
2
    {
88
2
        return reflectable->getAnyValue(allocator).template get<std::reference_wrapper<const T>>().get();
89
2
    }
90
91
    /**
92
     * Gets reference to the native value from the given reflectable.
93
     *
94
     * Overload for types where the value is returned by reference:
95
     *
96
     * - compound, arrays.
97
     *
98
     * \param reflectable Reflectable to use for value extraction.
99
     *
100
     * \return Reference to the value of the type T.
101
     *
102
     * \throw CppRuntimeException When wrong type is requested ("Bad type in AnyHolder").
103
     */
104
    template <typename T, typename ALLOC = std::allocator<uint8_t>,
105
            typename std::enable_if<
106
                    !detail::gets_value_by_value<T>::value &&
107
                    !std::is_same<BasicBitBuffer<ALLOC>, T>::value, int>::type = 0>
108
    static T& getValue(const IBasicReflectablePtr<ALLOC>& reflectable, const ALLOC& allocator = ALLOC())
109
20
    {
110
20
        return reflectable->getAnyValue(allocator).template get<std::reference_wrapper<T>>().get();
111
20
    }
112
113
    /**
114
     * Gets constant reference to the native value from the given reflectable.
115
     *
116
     * Overload for bit buffers which are currently returned only by constant reference.
117
     *
118
     * \param reflectable Reflectable to use for value extraction.
119
     *
120
     * \return Constant reference to the bit buffer value.
121
     *
122
     * \throw CppRuntimeException When wrong type is requested ("Bad type in AnyHolder").
123
     */
124
    template <typename T, typename ALLOC = std::allocator<uint8_t>,
125
            typename std::enable_if<std::is_same<BasicBitBuffer<ALLOC>, T>::value, int>::type = 0>
126
    static const T& getValue(const IBasicReflectablePtr<ALLOC>& reflectable, const ALLOC& allocator = ALLOC())
127
1
    {
128
1
        return reflectable->getAnyValue(allocator).template get<std::reference_wrapper<const T>>().get();
129
1
    }
130
131
private:
132
    template <typename ALLOC>
133
    static bool arraysEqual(const IBasicReflectableConstPtr<ALLOC>& lhsArray,
134
            const IBasicReflectableConstPtr<ALLOC>& rhsArray);
135
136
    template <typename ALLOC>
137
    static bool compoundsEqual(const IBasicReflectableConstPtr<ALLOC>& lhsCompound,
138
            const IBasicReflectableConstPtr<ALLOC>& rhsCompound);
139
140
    template <typename ALLOC>
141
    static bool valuesEqual(const IBasicReflectableConstPtr<ALLOC>& lhsValue,
142
            const IBasicReflectableConstPtr<ALLOC>& rhsValue);
143
144
    static bool doubleValuesAlmostEqual(double lhs, double rhs);
145
};
146
147
template <typename ALLOC>
148
bool ReflectableUtil::equal(const IBasicReflectableConstPtr<ALLOC>& lhs,
149
        const IBasicReflectableConstPtr<ALLOC>& rhs)
150
145
{
151
145
    if (lhs == nullptr || 
rhs == nullptr143
)
152
3
        return lhs == rhs;
153
154
142
    const auto& lhsTypeInfo = lhs->getTypeInfo();
155
142
    const auto& rhsTypeInfo = rhs->getTypeInfo();
156
157
142
    if (lhsTypeInfo.getSchemaType() != rhsTypeInfo.getSchemaType() ||
158
142
            
lhsTypeInfo.getSchemaName() != rhsTypeInfo.getSchemaName()134
)
159
9
        return false;
160
161
133
    if (lhs->isArray() || 
rhs->isArray()124
)
162
10
    {
163
10
        if (!lhs->isArray() || 
!rhs->isArray()9
)
164
2
            return false;
165
8
        return arraysEqual<ALLOC>(lhs, rhs);
166
10
    }
167
123
    else if (TypeInfoUtil::isCompound(lhsTypeInfo.getSchemaType()))
168
21
    {
169
21
        return compoundsEqual<ALLOC>(lhs, rhs);
170
21
    }
171
102
    else
172
102
    {
173
102
        return valuesEqual<ALLOC>(lhs, rhs);
174
102
    }
175
133
}
176
177
template <typename ALLOC>
178
bool ReflectableUtil::arraysEqual(const IBasicReflectableConstPtr<ALLOC>& lhsArray,
179
        const IBasicReflectableConstPtr<ALLOC>& rhsArray)
180
8
{
181
8
    if (lhsArray->size() != rhsArray->size())
182
2
        return false;
183
184
19
    
for (size_t i = 0; 6
i < lhsArray->size();
++i13
)
185
14
    {
186
14
        if (!equal<ALLOC>(lhsArray->at(i), rhsArray->at(i)))
187
1
            return false;
188
14
    }
189
190
5
    return true;
191
6
}
192
193
template <typename ALLOC>
194
bool ReflectableUtil::compoundsEqual(const IBasicReflectableConstPtr<ALLOC>& lhsCompound,
195
        const IBasicReflectableConstPtr<ALLOC>& rhsCompound)
196
21
{
197
21
    for (const auto& parameterInfo : lhsCompound->getTypeInfo().getParameters())
198
5
    {
199
5
        auto lhsParameter = lhsCompound->getParameter(parameterInfo.schemaName);
200
5
        auto rhsParameter = rhsCompound->getParameter(parameterInfo.schemaName);
201
5
        if (!equal<ALLOC>(lhsParameter, rhsParameter))
202
1
            return false;
203
5
    }
204
205
20
    if (TypeInfoUtil::hasChoice(lhsCompound->getTypeInfo().getSchemaType()))
206
16
    {
207
16
        if (lhsCompound->getChoice() != rhsCompound->getChoice())
208
4
            return false;
209
210
12
        if (!lhsCompound->getChoice().empty())
211
11
        {
212
11
            auto lhsField = lhsCompound->getField(lhsCompound->getChoice());
213
11
            auto rhsField = rhsCompound->getField(rhsCompound->getChoice());
214
11
            if (!equal<ALLOC>(lhsField, rhsField))
215
3
                return false;
216
11
        }
217
12
    }
218
4
    else
219
4
    {
220
4
        for (const auto& fieldInfo : lhsCompound->getTypeInfo().getFields())
221
8
        {
222
8
            auto lhsField = lhsCompound->getField(fieldInfo.schemaName);
223
8
            auto rhsField = rhsCompound->getField(fieldInfo.schemaName);
224
8
            if (!equal<ALLOC>(lhsField, rhsField))
225
1
                return false;
226
8
        }
227
4
    }
228
229
12
    return true;
230
20
}
231
232
template <typename ALLOC>
233
bool ReflectableUtil::valuesEqual(const IBasicReflectableConstPtr<ALLOC>& lhsValue,
234
        const IBasicReflectableConstPtr<ALLOC>& rhsValue)
235
102
{
236
102
    CppType cppType = lhsValue->getTypeInfo().getCppType();
237
102
    if (cppType == CppType::ENUM || 
cppType == CppType::BITMASK96
)
238
10
        cppType = lhsValue->getTypeInfo().getUnderlyingType().getCppType();
239
240
102
    switch (cppType)
241
102
    {
242
3
    case CppType::BOOL:
243
3
        return lhsValue->getBool() == rhsValue->getBool();
244
10
    case CppType::INT8:
245
14
    case CppType::INT16:
246
18
    case CppType::INT32:
247
22
    case CppType::INT64:
248
22
        return lhsValue->toInt() == rhsValue->toInt();
249
17
    case CppType::UINT8:
250
21
    case CppType::UINT16:
251
39
    case CppType::UINT32:
252
43
    case CppType::UINT64:
253
43
        return lhsValue->toUInt() == rhsValue->toUInt();
254
4
    case CppType::FLOAT:
255
21
    case CppType::DOUBLE:
256
21
        return doubleValuesAlmostEqual(lhsValue->toDouble(), rhsValue->toDouble());
257
4
    case CppType::BYTES:
258
4
        {
259
4
            Span<const uint8_t> lhs = lhsValue->getBytes();
260
4
            Span<const uint8_t> rhs = rhsValue->getBytes();
261
262
4
            return lhs.size() == rhs.size() && 
std::equal(lhs.begin(), lhs.end(), rhs.begin())3
;
263
4
        }
264
5
    case CppType::STRING:
265
5
        return lhsValue->getStringView() == rhsValue->getStringView();
266
3
    case CppType::BIT_BUFFER:
267
3
        return lhsValue->getBitBuffer() == rhsValue->getBitBuffer();
268
1
    default:
269
1
        throw CppRuntimeException("ReflectableUtil::valuesEqual - Unexpected C++ type!");
270
102
    }
271
102
}
272
273
inline bool ReflectableUtil::doubleValuesAlmostEqual(double lhs, double rhs)
{
    // An infinity matches only an infinity of the same sign; a finite value or NaN never matches it.
    if (std::isinf(lhs) || std::isinf(rhs))
        return std::isinf(lhs) && std::isinf(rhs) && std::signbit(lhs) == std::signbit(rhs);

    // Two NaNs are treated as equal; a NaN never equals a number.
    if (std::isnan(lhs) || std::isnan(rhs))
        return std::isnan(lhs) && std::isnan(rhs);

    // Relative tolerance scaled by the magnitude of the operands, with an absolute fallback for
    // differences smaller than the smallest normalized double.
    // see: https://en.cppreference.com/w/cpp/types/numeric_limits/epsilon
    const double difference = std::fabs(lhs - rhs);
    const double relativeTolerance = std::numeric_limits<double>::epsilon() * std::fabs(lhs + rhs);
    return difference <= relativeTolerance || difference < std::numeric_limits<double>::min();
}
285
286
} // namespace zserio
287
288
#endif // ZSERIO_REFLECTABLE_UTIL_H_INC