#!/usr/bin/env python3

import argparse
from typing import Iterable
import json
import logging
import platform
from pathlib import Path
import sys

from bmutils import Benchmarks, axes, build
from mdutils.mdutils import MdUtils
from mdutils.tools.TextUtils import TextUtils

# Work on a copy so the imported `axes` mapping is left untouched. The
# "dispatch" axis provides the rows of each table; the remaining axes are
# combined to select the report sections.
other_axes = axes.copy()
x_axis = other_axes.pop("dispatch")

parser = argparse.ArgumentParser()
# Build directory; the executable is looked up at <benchmarks>/tests/benchmarks.
parser.add_argument("benchmarks", nargs="?", type=Path, default=Path("build"))
# Output path prefix, or an existing ".json" results file to read instead of
# running the benchmarks.
parser.add_argument("--prefix", "-p", type=Path, default="benchmarks")
# Forwarded to Benchmarks.run() as `objects`.
parser.add_argument("--objects", "-O", type=int)
# When given, build() is called with this value before running.
parser.add_argument("--hierarchies", "-H", type=int)
# Stem of the reference results file to compare against.
parser.add_argument("--compare", "-c")
parser.add_argument("--log-level")
# Arguments the parser does not recognize are kept and forwarded to the
# benchmark run.
args, other_args = parser.parse_known_args()

if args.log_level:
    # Resolve the level name defensively: an unknown name (or one that maps to
    # a non-level attribute of the logging module) becomes a usage error
    # instead of an AttributeError/TypeError traceback.
    log_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(log_level, int):
        parser.error(f"invalid log level: {args.log_level}")
    logging.basicConfig(level=log_level)


def with_suffixes(path: Path, *suffixes: str):
    """Return `path` with its last suffix replaced by the joined `suffixes`.

    The given suffixes are concatenated in order (each is expected to start
    with a dot) and applied with a single `Path.with_suffix` call, so an
    existing final suffix of `path` is replaced, and passing no suffixes
    strips it.
    """
    combined = "".join(suffixes)
    return path.with_suffix(combined)


if args.prefix.suffix == ".json":
    # --prefix names an existing results file: load it instead of running.
    json_file = args.prefix
    print("reading results from", json_file)
    # JSON is read and written as UTF-8 explicitly rather than relying on the
    # platform's default encoding.
    with open(json_file, encoding="utf-8") as fh:
        results = Benchmarks.parse(json.load(fh))
    # Split the file name into a suffix-free prefix and the suffixes that sit
    # between it and the final ".json", so sibling files can be derived.
    prefix = json_file
    while prefix.suffix:
        prefix = prefix.with_suffix("")
    suffixes = json_file.suffixes[:-1]
else:
    if args.hierarchies is not None:
        build(args.benchmarks, args.hierarchies)
    exe = args.benchmarks.joinpath("tests", "benchmarks")
    print(f"running {exe}")
    results = Benchmarks.run(exe, *other_args, objects=args.objects)
    prefix = args.prefix
    # Output files are tagged with the host name and the compiler reported in
    # the results context.
    suffixes = [
        f".{suffix}" for suffix in (platform.node(), results.context.yomm2.compiler)
    ]
    json_file = with_suffixes(prefix, *suffixes, ".json")
    print("writing results to", json_file)

    with open(json_file, "w", encoding="utf-8") as fh:
        json.dump(results.data, fh, indent=4)

if args.compare is None:
    reference = None
else:
    # The reference file has the same directory and suffixes as the current
    # results, with the stem replaced by the --compare value.
    reference_file = with_suffixes(prefix.with_stem(args.compare), *suffixes, ".json")
    print("reading base results from", reference_file)

    try:
        with open(reference_file, encoding="utf-8") as fh:
            reference = Benchmarks.parse(json.load(fh))
    except FileNotFoundError:
        sys.exit(f"no reference file {reference_file}")


def display_cv(cv):
    """Append a "max cv" line to the Markdown report.

    `cv` is a fraction and is shown as a percentage. Values above 0.05 (5%)
    are rendered in bold with a "WARNING: " prefix. Writes to the
    module-level `md` document.
    """
    text = f"max cv = {cv * 100:5.1f}%"
    md.new_line(TextUtils.bold(f"WARNING: {text}") if cv > 0.05 else text)


# =============================================================

# The report shares the prefix and suffixes of the JSON results file, with a
# ".md" extension.
md_path = with_suffixes(prefix, *suffixes, ".md")
md_file = str(md_path)
md = MdUtils(file_name=md_file)
print("writing results to", md_file)

# Report preamble: benchmark context, the exact invocation, and the largest
# cv over all benchmarks in the results.
results.context.write_md(md)
command_line = " ".join(sys.argv)
md.new_line(f"command line: {command_line}")
overall_cv = max(bm.cv for bm in results.all)
display_cv(overall_cv)


def product(iterable: Iterable, *more: Iterable):
    """Return the Cartesian product of the given iterables as a list of lists.

    Each combination holds one element per input, in argument order. The
    FIRST iterable varies fastest (the reverse of `itertools.product`), e.g.
    `product([1, 2], [3, 4])` is `[[1, 3], [2, 3], [1, 4], [2, 4]]`.

    Args:
        iterable: the first (fastest-varying) input.
        *more: any further inputs.

    Returns:
        A list of combinations; empty if any input is empty.
    """
    # Materialize once: the elements are walked again for every
    # sub-combination, which would silently truncate the result if a one-shot
    # iterator or generator were passed.
    elements = list(iterable)
    if not more:
        return [[element] for element in elements]

    subproduct = product(*more)

    return [[element, *rest] for rest in subproduct for element in elements]


# One report section per combination of the non-dispatch axes. Each section
# holds a table with one row per dispatch kind, followed by the largest cv
# seen among the results shown in it.
for other_tags in product(*other_axes.values()):
    section_title = ", ".join(str(tag).replace("_", " ") for tag in other_tags)
    md.new_header(level=2, title=section_title, add_table_of_contents="")

    header = ["dispatch", "avg", "ratio"]
    if reference:
        header.append(f"/ {args.compare}")
    # mdutils takes the table as one flat list of cells, header row first.
    cells = list(header)
    worst_cv = results.get("virtual_function", *other_tags).cv

    for dispatch in x_axis:
        current = results.get(dispatch, *other_tags)
        worst_cv = max(worst_cv, current.cv)

        cells.append(current.dispatch.replace("_", " "))
        cells.append(f"{current.mean:6.1f}")
        # Ratio to the result's own base; a result without a base shows 1.00.
        if current.base is None:
            cells.append("1.00")
        else:
            cells.append(f"{current.mean / current.base.mean:5.2f}")

        if reference:
            previous = reference.get(dispatch, *other_tags)
            worst_cv = max(worst_cv, previous.cv)
            # Ratio to the same benchmark in the reference run; shown as 1
            # when the reference result has no base.
            versus = 1 if previous.base is None else current.mean / previous.mean
            cells.append(f"{versus:5.2f}")

    md.new_table(
        columns=len(header),
        rows=len(cells) // len(header),
        text=cells,
        text_align="right",
    )
    display_cv(worst_cv)


md.create_md_file()

# Example Google Benchmark options (unrecognized arguments are forwarded to the benchmark run):
# --benchmark_enable_random_interleaving=true --benchmark_repetitions=5 --benchmark_min_time=2 --benchmark_report_aggregates_only=true
