Example #1
# Standard-library imports for the listing below. ``run_generate`` (from
# run_eval.py), ``T5_TINY``, ``datetime_now`` and ``task_score_names`` are
# assumed to be provided by the surrounding module and are not shown here.
import argparse
import operator
import os
import sys
import tempfile
from collections import OrderedDict
from pathlib import Path
from unittest.mock import patch


def run_eval_tester(model):
    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
    output_file_name = input_file_name.parent / "utest_output.txt"
    assert not output_file_name.exists()
    articles = [
        " New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."
    ]
    _dump_articles(input_file_name, articles)
    score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
    task = "translation_en_to_de" if model == T5_TINY else "summarization"
    testargs = f"""
        run_eval_search.py
        {model}
        {input_file_name}
        {output_file_name}
        --score_path {score_path}
        --task {task}
        --num_beams 2
        --length_penalty 2.0
        """.split()

    with patch.object(sys, "argv", testargs):
        run_generate()
        assert output_file_name.exists()
        os.remove(output_file_name)
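

# Assumption: ``_dump_articles`` is not shown in this listing. A minimal sketch,
# consistent with how it is called above (the test articles are written to
# ``input_file_name``, one per line, for run_eval.py to read):
def _dump_articles(path, articles):
    Path(path).write_text("\n".join(articles) + "\n", encoding="utf-8")
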
def run_search():
    """
    Run a parametric search over the desired hparam space with the help of ``run_eval.py``.

    All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of ``--search``
    are parsed, reformatted and fed to ``run_eval.py`` as additional args.

    The format for the ``--search`` value is a simple string with hparams and colon-separated values to try, e.g.:
    ```
    --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
    ```
    which will generate ``12`` ``(2*3*2)`` searches, one for each combination in the product of the hparams
    (a sketch of ``parse_search_arg``, which performs this expansion, follows this function). The example above
    will invoke ``run_eval.py`` repeatedly with:

    ```
    --num_beams 5 --length_penalty 0.8 --early_stopping true
    --num_beams 5 --length_penalty 0.8 --early_stopping false
    [...]
    --num_beams 10 --length_penalty 1.2 --early_stopping false
    ```

    On completion, this function prints a markdown table of the results sorted by the best score for the task
    (e.g. BLEU for translation, ROUGE for summarization), followed by the winning arguments.
    """
    prog = sys.argv[0]

    parser = argparse.ArgumentParser(
        usage=(
            "\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
            " refer to `run_eval.py -h` for the complete list."
        )
    )
    parser.add_argument(
        "--search",
        type=str,
        required=False,
        help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
    )
    parser.add_argument(
        "--bs",
        type=int,
        default=8,
        required=False,
        help="initial batch size (may get reduced if it's too big)",
    )
    parser.add_argument(
        "--task",
        type=str,
        help="used for task_specific_params + metrics",
    )
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help=(
            "add custom notes to be printed before the results table. If no value is passed, the current datetime"
            " string will be used."
        ),
    )
    args, args_main = parser.parse_known_args()
    # we share some of the args
    args_main.extend(["--task", args.task])
    args_normal = [prog] + args_main

    # to support variations like "translation_en_to_de"
    task = "translation" if "translation" in args.task else "summarization"

    matrix, col_names = parse_search_arg(args.search)
    col_names[0:0] = task_score_names[task]  # score cols first
    col_widths = {col: len(str(col)) for col in col_names}
    results = []
    for r in matrix:
        hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
        args_exp = " ".join(r).split()
        args_exp.extend(["--bs", str(args.bs)])  # in case we need to reduce its size due to CUDA OOM
        sys.argv = args_normal + args_exp

        # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry

        scores = run_generate(verbose=False)
        # make sure scores are first in the table
        result = OrderedDict()
        for score in task_score_names[task]:
            result[score] = scores[score]
        result.update(hparams)
        results.append(result)

        # find widest entries
        for k, v in result.items():
            l = len(str(v))
            if l > col_widths[k]:
                col_widths[k] = l

    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
    print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
    print(" | ".join([f"{'-' * col_widths[col]}" for col in col_names]))
    for row in results_sorted:
        print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

    best = results_sorted[0]
    for score in task_score_names[task]:
        del best[score]
    best_args = [f"--{k} {v}" for k, v in best.items()]
    dyn_args = ["--bs", str(args.bs)]
    if args.info:
        print(f"\nInfo: {args.info}")
    print("\nBest score args:")
    print(" ".join(args_main + best_args + dyn_args))

    return results_sorted
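

# Assumption: ``parse_search_arg`` is not shown in this listing. A minimal sketch
# of the expansion described in the ``run_search`` docstring, consistent with how
# ``run_search`` consumes its output: ``matrix`` is a list of hparam combinations,
# each a list of "--key value" strings, and ``col_names`` lists the hparam names.
import itertools


def parse_search_arg(search):
    # "num_beams=5:10 length_penalty=0.8:1.0:1.2" -> {"num_beams": "5:10", ...}
    entries = dict(group.split("=") for group in search.split())
    col_names = list(entries.keys())
    # one list of "--key value" strings per hparam ...
    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
    # ... expanded into their cartesian product, e.g. 2 * 3 * 2 = 12 combinations
    matrix = [list(combo) for combo in itertools.product(*sets)]
    return matrix, col_names


# Example: parse_search_arg("num_beams=5:10 early_stopping=true:false") would return
# ([["--num_beams 5", "--early_stopping true"], ["--num_beams 5", "--early_stopping false"],
#   ["--num_beams 10", "--early_stopping true"], ["--num_beams 10", "--early_stopping false"]],
#  ["num_beams", "early_stopping"])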