check_err.hpp Source File

check_err.hpp Source File#

Composable Kernel: check_err.hpp Source File
library/utility/check_err.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <algorithm>
7#include <cmath>
8#include <cstdlib>
9#include <iostream>
10#include <iomanip>
11#include <iterator>
12#include <limits>
13#include <type_traits>
14#include <vector>
15
16#include "ck/ck.hpp"
18#include "ck/utility/type.hpp"
20
22
23namespace ck {
24namespace utils {
25
// Estimates a relative error threshold as the largest of three terms: a compute
// term, an output term and an accumulation term. Each term is
// 2^(-NumericUtils<T>::mant) * 0.5 for the respective type; the accumulation term
// is additionally multiplied by number_of_accumulations.
// NOTE(review): this listing is an extracted documentation page. The openings of
// the three static_assert statements and the conditions guarding the early
// return 0 branches are not visible here; confirm them against the real header.
// NOTE(review): mant is presumably the mantissa bit count of the type - confirm
// in numeric_utils.hpp.
26template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
27double get_relative_threshold(const int number_of_accumulations = 1)
28{
29 using F4 = ck::f4_t;
30 using F8 = ck::f8_t;
31 using F16 = ck::half_t;
32 using BF16 = ck::bhalf_t;
33 using F32 = float;
34 using TF32 = ck::tf32_t;
35 using I8 = int8_t;
36 using I32 = int32_t;
37
43 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
// Error contributed by ComputeDataType.
44 double compute_error = 0;
// The condition selecting this zero-threshold branch is missing from the listing.
47 {
48 return 0;
49 }
50 else
51 {
52 compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
53 }
54
60 "Warning: Unhandled OutDataType for setting up the relative threshold!");
// Error contributed by OutDataType.
61 double output_error = 0;
64 {
65 return 0;
66 }
67 else
68 {
69 output_error = std::pow(2, -NumericUtils<OutDataType>::mant) * 0.5;
70 }
// Larger of the compute and output terms.
71 double midway_error = std::max(compute_error, output_error);
72
78 "Warning: Unhandled AccDataType for setting up the relative threshold!");
// Error contributed by AccDataType, scaled linearly by number_of_accumulations.
79 double acc_error = 0;
82 {
83 return 0;
84 }
85 else
86 {
87 acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
88 }
89 return std::max(acc_error, midway_error);
90}
91
// Estimates an absolute error threshold for values whose magnitude is bounded by
// max_possible_num. It takes expo = log2(|max_possible_num|) and returns the
// largest of three terms of the form 2^(expo - NumericUtils<T>::mant) * 0.5 for
// the compute, output and accumulation types; the accumulation term is
// additionally multiplied by number_of_accumulations.
// NOTE(review): this listing is an extracted documentation page. The openings of
// the three static_assert statements and the conditions guarding the early
// return 0 branches are not visible here; confirm them against the real header.
92template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
93double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
94{
95 using F4 = ck::f4_t;
96 using F8 = ck::f8_t;
97 using F16 = ck::half_t;
98 using BF16 = ck::bhalf_t;
99 using F32 = float;
100 using TF32 = ck::tf32_t;
101 using I8 = int8_t;
102 using I32 = int32_t;
103
109 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
// Binary exponent of the largest magnitude; it shifts every term below.
110 auto expo = std::log2(std::abs(max_possible_num));
// Error contributed by ComputeDataType.
111 double compute_error = 0;
// The condition selecting this zero-threshold branch is missing from the listing.
114 {
115 return 0;
116 }
117 else
118 {
119 compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
120 }
121
127 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
// Error contributed by OutDataType.
128 double output_error = 0;
131 {
132 return 0;
133 }
134 else
135 {
136 output_error = std::pow(2, expo - NumericUtils<OutDataType>::mant) * 0.5;
137 }
// Larger of the compute and output terms.
138 double midway_error = std::max(compute_error, output_error);
139
145 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
// Error contributed by AccDataType, scaled linearly by number_of_accumulations.
146 double acc_error = 0;
149 {
150 return 0;
151 }
152 else
153 {
154 acc_error =
155 std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
156 }
157 return std::max(acc_error, midway_error);
158}
159
160template <typename Range,
161 typename RefRange,
162 typename ComputeDataType = ranges::range_value_t<Range>>
163typename std::enable_if<
164 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
165 std::is_same_v<ranges::range_value_t<Range>, float> &&
166 std::is_same_v<ComputeDataType, ck::tf32_t>,
167 bool>::type
168check_err(const Range& out,
169 const RefRange& ref,
170 const std::string& msg = "Error: Incorrect results!",
171 double rtol = 1e-5,
172 double atol = 3e-5)
173{
174 if(out.size() != ref.size())
175 {
176 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
177 << std::endl;
178 return false;
179 }
180
181 bool res{true};
182 int err_count = 0;
183 double err = 0;
184 double max_err = std::numeric_limits<double>::min();
185 for(std::size_t i = 0; i < ref.size(); ++i)
186 {
187 const double o = *std::next(std::begin(out), i);
188 const double r = *std::next(std::begin(ref), i);
189 err = std::abs(o - r);
190 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
191 {
192 max_err = err > max_err ? err : max_err;
193 if(err_count < 5)
194 {
195 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
196 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
197 }
198 res = false;
199 err_count++;
200 }
201 }
202 if(!res)
203 {
204 const float error_percent =
205 static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
206 std::cerr << "max err: " << max_err;
207 std::cerr << ", number of errors: " << err_count;
208 std::cerr << ", " << error_percent << "% wrong values" << std::endl;
209 }
210 return res;
211}
212
213template <typename Range,
214 typename RefRange,
215 typename ComputeDataType = ranges::range_value_t<Range>>
216typename std::enable_if<
217 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
218 std::is_floating_point_v<ranges::range_value_t<Range>> &&
219 !std::is_same_v<ranges::range_value_t<Range>, half_t> &&
220 !std::is_same_v<ComputeDataType, ck::tf32_t>,
221 bool>::type
222check_err(const Range& out,
223 const RefRange& ref,
224 const std::string& msg = "Error: Incorrect results!",
225 double rtol = 1e-5,
226 double atol = 3e-6)
227{
228 if(out.size() != ref.size())
229 {
230 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
231 << std::endl;
232 return false;
233 }
234
235 bool res{true};
236 int err_count = 0;
237 double err = 0;
238 double max_err = std::numeric_limits<double>::min();
239 for(std::size_t i = 0; i < ref.size(); ++i)
240 {
241 const double o = *std::next(std::begin(out), i);
242 const double r = *std::next(std::begin(ref), i);
243 err = std::abs(o - r);
244 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
245 {
246 max_err = err > max_err ? err : max_err;
247 if(err_count < 5)
248 {
249 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
250 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
251 }
252 res = false;
253 err_count++;
254 }
255 }
256 if(!res)
257 {
258 const float error_percent =
259 static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
260 std::cerr << "max err: " << max_err;
261 std::cerr << ", number of errors: " << err_count;
262 std::cerr << ", " << error_percent << "% wrong values" << std::endl;
263 }
264 return res;
265}
266
267template <typename Range,
268 typename RefRange,
269 typename ComputeDataType = ranges::range_value_t<Range>>
270typename std::enable_if<
271 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
272 std::is_same_v<ranges::range_value_t<Range>, bhalf_t>,
273 bool>::type
274check_err(const Range& out,
275 const RefRange& ref,
276 const std::string& msg = "Error: Incorrect results!",
277 double rtol = 1e-1,
278 double atol = 1e-3)
279{
280 if(out.size() != ref.size())
281 {
282 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
283 << std::endl;
284 return false;
285 }
286
287 bool res{true};
288 int err_count = 0;
289 double err = 0;
290 // TODO: This is a hack. We should have proper specialization for bhalf_t data type.
291 double max_err = std::numeric_limits<float>::min();
292 for(std::size_t i = 0; i < ref.size(); ++i)
293 {
294 const double o = type_convert<float>(*std::next(std::begin(out), i));
295 const double r = type_convert<float>(*std::next(std::begin(ref), i));
296 err = std::abs(o - r);
297 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
298 {
299 max_err = err > max_err ? err : max_err;
300 err_count++;
301 if(err_count < 5)
302 {
303 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
304 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
305 }
306 res = false;
307 }
308 }
309 if(!res)
310 {
311 const float error_percent =
312 static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
313 std::cerr << "max err: " << max_err;
314 std::cerr << ", number of errors: " << err_count;
315 std::cerr << ", " << error_percent << "% wrong values" << std::endl;
316 }
317 return res;
318}
319
320template <typename Range,
321 typename RefRange,
322 typename ComputeDataType = ranges::range_value_t<Range>>
323typename std::enable_if<
324 std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
325 std::is_same_v<ranges::range_value_t<Range>, half_t>,
326 bool>::type
327check_err(const Range& out,
328 const RefRange& ref,
329 const std::string& msg = "Error: Incorrect results!",
330 double rtol = 1e-3,
331 double atol = 1e-3)
332{
333 if(out.size() != ref.size())
334 {
335 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
336 << std::endl;
337 return false;
338 }
339
340 bool res{true};
341 int err_count = 0;
342 double err = 0;
343 double max_err = NumericLimits<ranges::range_value_t<Range>>::Min();
344 for(std::size_t i = 0; i < ref.size(); ++i)
345 {
346 const double o = type_convert<float>(*std::next(std::begin(out), i));
347 const double r = type_convert<float>(*std::next(std::begin(ref), i));
348 err = std::abs(o - r);
349 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
350 {
351 max_err = err > max_err ? err : max_err;
352 err_count++;
353 if(err_count < 5)
354 {
355 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
356 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
357 }
358 res = false;
359 }
360 }
361 if(!res)
362 {
363 const float error_percent =
364 static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
365 std::cerr << "max err: " << max_err;
366 std::cerr << ", number of errors: " << err_count;
367 std::cerr << ", " << error_percent << "% wrong values" << std::endl;
368 }
369 return res;
370}
371
372template <typename Range,
373 typename RefRange,
374 typename ComputeDataType = ranges::range_value_t<Range>>
375std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
376 std::is_integral_v<ranges::range_value_t<Range>> &&
377 !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
378 !std::is_same_v<ranges::range_value_t<Range>, f8_t> &&
379 !std::is_same_v<ranges::range_value_t<Range>, bf8_t>)
380#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
381 || std::is_same_v<ranges::range_value_t<Range>, int4_t>
382#endif
383 ,
384 bool>
385check_err(const Range& out,
386 const RefRange& ref,
387 const std::string& msg = "Error: Incorrect results!",
388 double = 0,
389 double atol = 0)
390{
391 if(out.size() != ref.size())
392 {
393 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
394 << std::endl;
395 return false;
396 }
397
398 bool res{true};
399 int err_count = 0;
400 int64_t err = 0;
401 int64_t max_err = std::numeric_limits<int64_t>::min();
402 for(std::size_t i = 0; i < ref.size(); ++i)
403 {
404 const int64_t o = *std::next(std::begin(out), i);
405 const int64_t r = *std::next(std::begin(ref), i);
406 err = std::abs(o - r);
407
408 if(err > atol)
409 {
410 max_err = err > max_err ? err : max_err;
411 err_count++;
412 if(err_count < 5)
413 {
414 std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
415 << std::endl;
416 }
417 res = false;
418 }
419 }
420 if(!res)
421 {
422 const float error_percent =
423 static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
424 std::cerr << "max err: " << max_err;
425 std::cerr << ", number of errors: " << err_count;
426 std::cerr << ", " << error_percent << "% wrong values" << std::endl;
427 }
428 return res;
429}
430
431template <typename Range,
432 typename RefRange,
433 typename ComputeDataType = ranges::range_value_t<Range>>
434std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
435 std::is_same_v<ranges::range_value_t<Range>, f8_t>),
436 bool>
437check_err(const Range& out,
438 const RefRange& ref,
439 const std::string& msg = "Error: Incorrect results!",
440 double rtol = 1e-3,
441 double atol = 1e-3)
442{
443 if(out.size() != ref.size())
444 {
445 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
446 << std::endl;
447 return false;
448 }
449
450 bool res{true};
451 int err_count = 0;
452 double err = 0;
453 double max_err = std::numeric_limits<float>::min();
454
455 for(std::size_t i = 0; i < ref.size(); ++i)
456 {
457 const double o = type_convert<float>(*std::next(std::begin(out), i));
458 const double r = type_convert<float>(*std::next(std::begin(ref), i));
459 err = std::abs(o - r);
460
461 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
462 {
463 max_err = err > max_err ? err : max_err;
464 err_count++;
465 if(err_count < 5)
466 {
467 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
468 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
469 }
470 res = false;
471 }
472 }
473
474 if(!res)
475 {
476 std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
477 << " number of errors: " << err_count << std::endl;
478 }
479 return res;
480}
481
482template <typename Range,
483 typename RefRange,
484 typename ComputeDataType = ranges::range_value_t<Range>>
485std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
486 std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
487 bool>
488check_err(const Range& out,
489 const RefRange& ref,
490 const std::string& msg = "Error: Incorrect results!",
491 double rtol = 1e-3,
492 double atol = 1e-3)
493{
494 if(out.size() != ref.size())
495 {
496 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
497 << std::endl;
498 return false;
499 }
500
501 bool res{true};
502 int err_count = 0;
503 double err = 0;
504 double max_err = std::numeric_limits<float>::min();
505 for(std::size_t i = 0; i < ref.size(); ++i)
506 {
507 const double o = type_convert<float>(*std::next(std::begin(out), i));
508 const double r = type_convert<float>(*std::next(std::begin(ref), i));
509 err = std::abs(o - r);
510 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
511 {
512 max_err = err > max_err ? err : max_err;
513 err_count++;
514 if(err_count < 5)
515 {
516 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
517 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
518 }
519 res = false;
520 }
521 }
522 if(!res)
523 {
524 std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
525 }
526 return res;
527}
528
529template <typename Range,
530 typename RefRange,
531 typename ComputeDataType = ranges::range_value_t<Range>>
532std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
533 std::is_same_v<ranges::range_value_t<Range>, f4_t>),
534 bool>
535check_err(const Range& out,
536 const RefRange& ref,
537 const std::string& msg = "Error: Incorrect results!",
538 double rtol = 0.5,
539 double atol = 0.5)
540{
541 if(out.size() != ref.size())
542 {
543 std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
544 << std::endl;
545 return false;
546 }
547
548 bool res{true};
549 int err_count = 0;
550 double err = 0;
551 double max_err = std::numeric_limits<float>::min();
552
553 for(std::size_t i = 0; i < ref.size(); ++i)
554 {
555 const double o = type_convert<float>(*std::next(std::begin(out), i));
556 const double r = type_convert<float>(*std::next(std::begin(ref), i));
557 err = std::abs(o - r);
558
559 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
560 {
561 max_err = err > max_err ? err : max_err;
562 err_count++;
563 if(err_count < 5)
564 {
565 std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
566 << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
567 }
568 res = false;
569 }
570 }
571
572 if(!res)
573 {
574 std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
575 << " number of errors: " << err_count << std::endl;
576 }
577 return res;
578}
579
580} // namespace utils
581} // namespace ck
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition library/utility/ranges.hpp:28
Definition library/utility/check_err.hpp:24
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations=1)
Definition library/utility/check_err.hpp:93
double get_relative_threshold(const int number_of_accumulations=1)
Definition library/utility/check_err.hpp:27
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_same_v< ranges::range_value_t< Range >, float > &&std::is_same_v< ComputeDataType, ck::tf32_t >, bool >::type check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-5)
Definition library/utility/check_err.hpp:168
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
_Float16 half_t
Definition data_type.hpp:31
long int64_t
Definition data_type.hpp:464
unsigned _BitInt(4) f4_t
Definition data_type.hpp:33
bf8_fnuz_t bf8_t
Definition amd_ck_fp8.hpp:1763
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
constexpr bool is_same_v
Definition type.hpp:283
_BitInt(4) int4_t
Definition data_type.hpp:32
_BitInt(19) tf32_t
Definition data_type.hpp:29
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
Definition numeric_limits.hpp:309
Definition numeric_utils.hpp:10