def main():
    """CLI entry point: evaluate a reranking method on the Kaggle lit-review dataset.

    Parses command-line options, builds the requested reranker, runs the
    evaluator over the senticized CORD-19 examples, logs each metric, and
    prints a tab-separated ``name\\tvalue`` summary to stdout.
    """
    apb = ArgumentParserBuilder()
    apb.add_opts(
        opt('--dataset', type=Path, default='data/kaggle-lit-review-0.1.json'),
        opt('--method', required=True, type=str, choices=METHOD_CHOICES),
        opt('--model-name', type=str),
        opt('--split', type=str, default='nq', choices=('nq', 'kq')),
        opt('--batch-size', '-bsz', type=int, default=96),
        opt('--device', type=str, default='cuda:0'),
        opt('--tokenizer-name', type=str),
        opt('--do-lower-case', action='store_true'),
        opt('--metrics',
            type=str,
            nargs='+',
            default=metric_names(),
            choices=metric_names()))
    args = apb.parser.parse_args()
    options = KaggleEvaluationOptions(**vars(args))
    ds = LitReviewDataset.from_file(str(options.dataset))
    # Sentence-level examples built against the configured CORD-19 index.
    examples = ds.to_senticized_dataset(SETTINGS.cord19_index_path,
                                        split=options.split)
    # Dispatch table: CLI method name -> reranker constructor.
    construct_map = dict(transformer=construct_transformer,
                         bm25=construct_bm25,
                         t5=construct_t5,
                         seq_class_transformer=construct_seq_class_transformer,
                         qa_transformer=construct_qa_transformer,
                         random=lambda _: RandomReranker())
    reranker = construct_map[options.method](options)
    evaluator = RerankerEvaluator(reranker, options.metrics)
    # Column width for aligned logging; use options.metrics for consistency
    # with what the evaluator actually computes (was args.metrics).
    width = max(map(len, options.metrics)) + 1
    stdout = []
    for metric in evaluator.evaluate(examples):
        logging.info(f'{metric.name:<{width}}{metric.value:.5}')
        stdout.append(f'{metric.name}\t{metric.value}')
    print('\n'.join(stdout))
def main():
    """CLI entry point: evaluate a reranking method and record results to CSV.

    Parses command-line options, builds the requested reranker, evaluates it
    over the senticized dataset from ``--index-dir``, writes each metric (and
    the total wall-clock time) to ``<model-name>.csv``, and prints a
    tab-separated summary to stdout.
    """
    import time

    apb = ArgumentParserBuilder()
    apb.add_opts(
        opt('--dataset', type=Path, required=True),
        opt('--index-dir', type=Path, required=True),
        opt('--method', required=True, type=str, choices=METHOD_CHOICES),
        opt('--model-name', type=str),
        opt('--split', type=str, default='nq', choices=('nq', 'kq')),
        opt('--batch-size', '-bsz', type=int, default=96),
        opt('--device', type=str, default='cuda:0'),
        opt('--tokenizer-name', type=str),
        opt('--do-lower-case', action='store_true'),
        opt('--metrics',
            type=str,
            nargs='+',
            default=metric_names(),
            choices=metric_names()))
    args = apb.parser.parse_args()
    options = KaggleEvaluationOptions(**vars(args))
    ds = LitReviewDataset.from_file(str(options.dataset))
    examples = ds.to_senticized_dataset(str(options.index_dir),
                                        split=options.split)
    # Dispatch table: CLI method name -> reranker constructor.
    construct_map = dict(transformer=construct_transformer,
                         bm25=construct_bm25,
                         t5=construct_t5,
                         seq_class_transformer=construct_seq_class_transformer,
                         qa_transformer=construct_qa_transformer,
                         random=lambda _: RandomReranker())
    reranker = construct_map[options.method](options)
    evaluator = RerankerEvaluator(reranker, options.metrics)
    # Column width for aligned logging; use options.metrics for consistency
    # with what the evaluator actually computes (was args.metrics).
    width = max(map(len, options.metrics)) + 1
    stdout = []
    start = time.time()
    # BUG FIX: --model-name is optional, so options.model_name may be None and
    # the original `.replace` call would raise AttributeError. Fall back to the
    # (required) method name, and sanitize slashes so HF-style model ids such
    # as "org/model" produce a valid filename.
    csv_stem = (options.model_name or options.method).replace('/', '_')
    with open(f'{csv_stem}.csv', 'w') as fd:
        # Log the actual filename being written (post-sanitization).
        logging.info('writing %s.csv', csv_stem)
        for metric in evaluator.evaluate(examples):
            logging.info(f'{metric.name:<{width}}{metric.value:.5}')
            stdout.append(f'{metric.name}\t{metric.value:.3}')
            fd.write(f"{metric.name}\t{metric.value:.3}\n")
        end = time.time()
        # Record total evaluation wall-clock time alongside the metrics.
        fd.write(f"time\t{end-start:.3}\n")
    print('\n'.join(stdout))