#!/usr/bin/python3

import argparse
import logging
import shutil
from pathlib import Path

from bmutils import Benchmarks, build
from mdutils.mdutils import MdUtils
from mdutils.tools.TextUtils import TextUtils

# Command-line interface. Options that are not declared here are collected in
# ``other_args`` (see ``parse_known_args``) and passed on to every benchmark
# run further down.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--output",
    "-o",
    type=str,
    required=True,
)
parser.add_argument(
    "--build-dir",
    "-b",
    type=Path,
    default=Path("build"),
)
# Both sweeps are given as a comma-separated "first,step,count" triple.
parser.add_argument(
    "--hierarchies",
    type=str,
    default="10,10,10",
)
parser.add_argument(
    "--objects",
    type=str,
    default="1000,1000,10",
)
parser.add_argument("--log-level")
args, other_args = parser.parse_known_args()

# Logging is configured only when a level was requested; the name is looked
# up on the ``logging`` module, case-insensitively (e.g. "info" -> INFO).
requested_level = args.log_level
if requested_level:
    logging.basicConfig(level=getattr(logging, requested_level.upper()))

logger = logging.getLogger(__name__)

# Unpack the "first,step,count" triples; a malformed value raises ValueError.
n_first, h_step, h_count = tuple(map(int, args.hierarchies.split(",")))
o_first, o_step, o_count = tuple(map(int, args.objects.split(",")))

# Benchmark selection: results for ``policy`` are compared against the
# "virtual_function" results that carry the same ``other_tags``.
policy = "hash_factors_in_globals"
other_tags = ("arity_1", "ordinary_base", "no_work")

# Flat list of table cells, one row after another (row-major).
results = []

# Largest cv seen over the whole sweep. -1 is a "nothing measured yet"
# sentinel.
# NOTE(review): presumably cv is never negative, so any measurement replaces
# the sentinel -- confirm against bmutils.
max_max_cv = -1

# The filter does not depend on the loop variables, so build it once.
benchmark_filter = (
    f"--benchmark_filter=(arity_0)|((virtual_function|{policy})-{'-'.join(other_tags)})"
)

for nh in range(n_first, n_first + h_step * h_count, h_step):
    exe = build(args.build_dir, nh)
    # Keep a copy of the executable suffixed with the hierarchy count.
    # NOTE(review): presumably because the next build() overwrites `exe` --
    # confirm against bmutils.build.
    nh_exe = f"{exe}.{nh}"
    shutil.copy(exe, nh_exe)
    logger.info("saved to %s", nh_exe)

    # One table row per hierarchy count: the count itself, then for each
    # object count (virtual mean, method mean, ratio), then the row's max cv.
    row = [str(nh)]
    max_cv = -1

    for no in range(o_first, o_first + o_step * o_count, o_step):
        bm = Benchmarks.run(
            nh_exe,
            benchmark_filter,
            *other_args,
            objects=no,
        )

        virtual_result = bm.get("virtual_function", *other_tags)
        method_result = bm.get(policy, *other_tags)
        max_cv = max(max_cv, method_result.cv)
        row.extend(
            f"{x:.3f}"
            for x in (
                virtual_result.mean,
                method_result.mean,
                method_result.mean / virtual_result.mean,
            )
        )

    row.append(f"{100 * max_cv:.1f}%")
    # Pass the row as an argument: it ends with a literal "%", which must not
    # be treated as part of a format string.
    logger.info("%s", " ".join(row))
    results.extend(row)
    max_max_cv = max(max_max_cv, max_cv)

# The report name is the requested output prefix followed by the host name,
# compiler and compiler version taken from the context of the last benchmark
# run (`bm` is whatever the sweep above ran last).
file_name = (
    "-".join(
        (
            args.output,
            bm.context.host_name,
            bm.context.yomm2.compiler,
            bm.context.yomm2.compiler_version,
        )
    )
    + ".md"
)
logger.info("writing results to %s", file_name)

md = MdUtils(file_name=file_name)
bm.context.write_md(md)

# Header row: a corner cell, then three columns (virtual, method, ratio) per
# object count, then the per-row maximum cv. Only the first column of each
# triple carries the object count.
headers = ["objects&rarr;<br>hierarchies&darr;"]
for no in range(o_first, o_first + o_step * o_count, o_step):
    headers.extend([f"{no}<br>virtual", "<br>method", "<br>ratio"])
headers.append("max<br>cv")

# The table is passed as one flat, row-major list of cells: the header row
# followed by the data rows, so the row count is the cell count divided by
# the number of columns.
table = headers + results

md.new_table(
    columns=len(headers),
    rows=len(table) // len(headers),
    text=table,
    text_align="right",
)

# Report the worst cv of the whole sweep, highlighted when it exceeds 5%.
display_cv = f"max cv = {100 * max_max_cv:5.1f}%"
if max_max_cv > 0.05:
    display_cv = TextUtils.bold(f"WARNING: {display_cv}")
md.new_line(display_cv)

md.create_md_file()
