def run_eval_tester(model):
    """Smoke-test ``run_generate`` end-to-end on a single short article.

    A one-article input file is written into a temporary directory, then
    ``run_generate`` is invoked with a patched ``sys.argv`` (``--num_beams 2``,
    ``--length_penalty 2.0``, and a ``--score_path`` inside the same directory).
    The function asserts that the output file was produced. The whole temporary
    directory (input, output and score files) is removed afterwards, even if
    ``run_generate`` raises.

    Args:
        model: model identifier passed as the first positional CLI argument.
            When it equals ``T5_TINY`` the ``translation_en_to_de`` task is
            used; otherwise ``summarization``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)
        input_file_name = tmp_path / "utest_input.source"
        output_file_name = tmp_path / "utest_output.txt"
        assert not output_file_name.exists()
        articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
        _dump_articles(input_file_name, articles)
        score_path = str(tmp_path / "scores.json")
        task = "translation_en_to_de" if model == T5_TINY else "summarization"
        # Build argv as an explicit list: splitting a formatted string would
        # break any path that contains whitespace.
        testargs = [
            "run_eval_search.py",
            f"{model}",
            str(input_file_name),
            str(output_file_name),
            "--score_path",
            score_path,
            "--task",
            task,
            "--num_beams",
            "2",
            "--length_penalty",
            "2.0",
        ]

        with patch.object(sys, "argv", testargs):
            run_generate()
            assert output_file_name.exists()
def run_search():
    """
    Run parametric search over the desired hparam space with help of ``run_eval.py``.

    All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are
    parsed, reformatted and fed to ``run_eval.py`` as additional args.

    The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.:
    ```
     --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
    ```
    which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was
    just used will invoke ``run_eval.py`` repeatedly with:

    ```
     --num_beams 5 --length_penalty 0.8 --early_stopping true
     --num_beams 5 --length_penalty 0.8 --early_stopping false
     [...]
     --num_beams 10 --length_penalty 1.2 --early_stopping false
    ```

    On completion, this function prints a markdown table of the results sorted (best first) by the task's score
    columns as listed in ``task_score_names``, followed by the winning arguments.

    Returns:
        list of ``OrderedDict`` rows, one per searched hparam combination. Each row holds the score columns first,
        then the hparams for that run. Rows are sorted descending by the score columns.

    Notes:
        ``sys.argv`` is temporarily replaced for each ``run_generate`` call and restored before returning.
    """
    prog = sys.argv[0]

    parser = argparse.ArgumentParser(
        usage=(
            "\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
            " refer to `run_eval.py -h` for the complete list."
        )
    )
    parser.add_argument(
        "--search",
        type=str,
        required=False,
        help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
    )
    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help=(
            "add custom notes to be printed before the results table. If no value is passed, the current datetime"
            " string will be used."
        ),
    )
    args, args_main = parser.parse_known_args()
    if args.task is None:
        # --task is forwarded to run_eval.py and selects the score columns below;
        # fail with a usage message rather than a TypeError further down.
        parser.error("--task is required")

    # we share some of the args
    args_main.extend(["--task", args.task])
    args_normal = [prog] + args_main

    # to support variations like translation_en_to_de
    task = "translation" if "translation" in args.task else "summarization"
    score_names = task_score_names[task]

    matrix, col_names = parse_search_arg(args.search)
    col_names[0:0] = score_names  # score cols first
    col_widths = {col: len(str(col)) for col in col_names}

    results = []
    saved_argv = sys.argv
    try:
        for r in matrix:
            hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
            args_exp = " ".join(r).split()
            args_exp.extend(["--bs", str(args.bs)])  # in case we need to reduce its size due to CUDA OOM
            # run_generate parses its options from sys.argv.
            sys.argv = args_normal + args_exp

            # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry

            scores = run_generate(verbose=False)
            # make sure scores are first in the table
            result = OrderedDict((score, scores[score]) for score in score_names)
            result.update(hparams)
            results.append(result)

            # find widest entries
            for k, v in result.items():
                col_widths[k] = max(col_widths[k], len(str(v)))
    finally:
        # Don't leak the last experiment's argv to the caller.
        sys.argv = saved_argv

    results_sorted = sorted(results, key=operator.itemgetter(*score_names), reverse=True)
    print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
    print(" | ".join(["-" * col_widths[col] for col in col_names]))
    for row in results_sorted:
        print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

    # Skip the score columns without mutating the row: results_sorted is
    # returned to the caller and must keep its scores.
    best = results_sorted[0]
    best_args = [f"--{k} {v}" for k, v in best.items() if k not in score_names]
    dyn_args = ["--bs", str(args.bs)]
    if args.info:
        print(f"\nInfo: {args.info}")
    print("\nBest score args:")
    print(" ".join(args_main + best_args + dyn_args))

    return results_sorted