示例#1
0
def metrics_func_provider(args, tokenizer, is_test):
    """Assemble the metrics callback for blank-filling, extraction and seq2seq tasks.

    ``args.task`` (compared case-insensitively) selects the components:

    * ``'blank'`` / ``'extraction'``: a ``BlankLMEvaluater`` and an empty
      metric mapping.
    * any other task: a ``DecoderEvaluater`` scored with ROUGE-1/2/L. The
      ``dataset`` passed to ``rouge_metric`` is ``'cnn_dm'`` when
      ``args.tokenizer_type`` is ``"BertWordPieceTokenizer"``, ``'gigaword'``
      for the gigaword task, and ``'cnn_dm_org'`` otherwise.

    Args:
        args: Run configuration; ``task`` is always read, ``tokenizer_type``
            only for non blank/extraction tasks.
        tokenizer: Handed to every dataset and evaluator that gets built.
        is_test: Forwarded unchanged to ``accuracy_func_provider``.

    Returns:
        Whatever ``accuracy_func_provider`` returns (called with
        ``only_rank0=False``).
    """
    def single_dataset_provider(split):
        # args.task is consulted at call time, not captured up front.
        task_name = args.task.lower()
        if task_name == 'blank':
            dataset_cls = BlankLMDataset
        elif task_name == 'extraction':
            dataset_cls = ExtractionDataset
        else:
            dataset_cls = Seq2SeqDataset
        return dataset_cls(args, split=split, tokenizer=tokenizer)

    task_name = args.task.lower()
    if task_name in ('blank', 'extraction'):
        eval_func = BlankLMEvaluater(args, tokenizer).evaluate
        metric_dict = {}
    else:
        eval_func = DecoderEvaluater(args, tokenizer).evaluate
        if args.tokenizer_type == "BertWordPieceTokenizer":
            rouge_dataset = 'cnn_dm'
        elif task_name == 'gigaword':
            rouge_dataset = 'gigaword'
        else:
            rouge_dataset = 'cnn_dm_org'
        # One partial per ROUGE variant, keyed (and ordered) by its name.
        metric_dict = OrderedDict(
            (metric_name,
             functools.partial(rouge_metric,
                               metric=metric_name,
                               dataset=rouge_dataset))
            for metric_name in ("rouge-1", "rouge-2", "rouge-l"))

    def write_lines(path, lines):
        """Write every item of ``lines`` to ``path``, each followed by a newline."""
        with open(path, "w", encoding='utf-8') as handle:
            for line in lines:
                handle.write(line)
                handle.write("\n")

    def output_func(predictions, examples, output_file):
        # <output_file>.hyps: one prediction per line;
        # <output_file>.refs: each example's meta["ref"], in example order.
        write_lines(output_file + ".hyps", predictions)
        write_lines(output_file + ".refs",
                    (ex.meta["ref"] for ex in examples))
        if args.task.lower() == 'squad_generation':
            # Also dump the passage (newlines flattened) followed by its answer.
            write_lines(output_file + ".source",
                        (ex.text_a.replace("\n", " ") + " Answer: " +
                         ex.meta["answer"] for ex in examples))

    return accuracy_func_provider(single_dataset_provider,
                                  metric_dict,
                                  args,
                                  is_test=is_test,
                                  eval_func=eval_func,
                                  output_func=output_func,
                                  only_rank0=False)
示例#2
0
    def metrics_func_provider():
        """Return the accuracy callback over datasets built one datapath at a time.

        The inner provider fetches the global args and tokenizer on every
        call, derives the dataset name with ``name_from_datapath_func`` and
        builds a ``Dataset`` over the single datapath using
        ``args.seq_length``.
        """
        def single_dataset_provider(datapath):
            run_args = get_args()
            tok = get_tokenizer()
            dataset_name = name_from_datapath_func(datapath)
            return Dataset(dataset_name, [datapath], tok, run_args.seq_length)

        return accuracy_func_provider(single_dataset_provider)
示例#3
0
def metrics_func_provider():
    """Return the accuracy callback for RACE datasets.

    The global args and tokenizer are fetched once, when this function runs,
    and shared by every ``RaceDataset`` the inner provider creates.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        # Name the dataset after whatever follows the last 'RACE' in the
        # path, with '/' turned into '-', e.g. '.../RACE/dev/high' -> 'dev-high'.
        tail = datapath.rpartition('RACE')[2]
        dataset_name = tail.strip('/').replace('/', '-')
        return RaceDataset(dataset_name, [datapath], tokenizer,
                           args.seq_length)

    return accuracy_func_provider(single_dataset_provider)
示例#4
0
文件: finetune.py 项目: puraminy/GLM
def metrics_func_provider(args, tokenizer, is_test):
    """Assemble the metrics callback for a SuperGLUE task.

    Args:
        args: Run configuration. ``task`` (lower-cased) selects the output
            function and the ``DEFAULT_METRICS`` entry; ``cloze_eval`` and
            ``wsc_negative`` are consulted only for the ``'wsc'`` task.
        tokenizer: Handed to every ``SuperGlueDataset`` that gets built.
        is_test: Forwarded unchanged to ``accuracy_func_provider``.

    Returns:
        Whatever ``accuracy_func_provider`` returns.
    """
    def single_dataset_provider(split):
        return SuperGlueDataset(args, split, tokenizer)

    task_name = args.task.lower()
    output_func = get_output_func(task_name, args)

    # WSC with cloze_eval set and wsc_negative unset is evaluated with
    # classify_evaluate; every other configuration passes eval_func=None.
    eval_func = None
    if task_name == 'wsc' and args.cloze_eval and not args.wsc_negative:
        from tasks.language_model.finetune import classify_evaluate as eval_func

    return accuracy_func_provider(single_dataset_provider,
                                  OrderedDict(DEFAULT_METRICS[task_name]),
                                  args,
                                  is_test=is_test,
                                  eval_func=eval_func,
                                  output_func=output_func)
示例#5
0
def metrics_func_provider(args, tokenizer, is_test):
    """Assemble the metrics callback; only built for test runs.

    Returns ``None`` when ``is_test`` is false. Otherwise the ``'blank'``
    task (case-insensitive) gets a ``BlankLMEvaluater`` and every other task
    a ``DecoderEvaluater``. The metric mapping handed on is empty in both
    cases (a plain ``{}`` for blank, an empty ``OrderedDict`` otherwise).

    Args:
        args: Run configuration; ``task`` selects dataset and evaluator.
        tokenizer: Handed to every dataset and evaluator that gets built.
        is_test: Gate for building the callback; forwarded when true.

    Returns:
        ``None`` or the result of ``accuracy_func_provider`` (called with
        ``only_rank0=False``).
    """
    if not is_test:
        return None

    def single_dataset_provider(split):
        # args.task is consulted at call time, not captured up front.
        dataset_cls = (BlankLMDataset if args.task.lower() == 'blank'
                       else Seq2SeqDataset)
        return dataset_cls(args, split=split, tokenizer=tokenizer)

    if args.task.lower() == 'blank':
        eval_func = BlankLMEvaluater(args, tokenizer).evaluate
        metric_dict = {}
    else:
        eval_func = DecoderEvaluater(args, tokenizer).evaluate
        metric_dict = OrderedDict()

    def write_lines(path, lines):
        """Write every item of ``lines`` to ``path``, each followed by a newline."""
        with open(path, "w", encoding='utf-8') as handle:
            for line in lines:
                handle.write(line)
                handle.write("\n")

    def output_func(predictions, examples, output_file):
        # <output_file>.hyps: one prediction per line;
        # <output_file>.refs: each example's meta["ref"], in example order.
        write_lines(output_file + ".hyps", predictions)
        write_lines(output_file + ".refs",
                    (ex.meta["ref"] for ex in examples))

    return accuracy_func_provider(single_dataset_provider,
                                  metric_dict,
                                  args,
                                  is_test=is_test,
                                  eval_func=eval_func,
                                  output_func=output_func,
                                  only_rank0=False)
示例#6
0
def metrics_func_provider(args, tokenizer, is_test):
    """Assemble the metrics callback for cloze, extraction, SQuAD and seq2seq tasks.

    ``args.task`` (compared case-insensitively) selects the components:

    * ``'blank'`` / ``'extraction'``: a ``BlankLMEvaluater`` and an empty
      metric mapping.
    * ``'squad'`` / ``'squad_v1'``: a ``DecoderEvaluater`` scored with
      ``squad_exact_match`` ("EM") and ``squad_f1`` ("F1").
    * anything else: a ``DecoderEvaluater`` scored with ROUGE-1/2/L. The
      ``dataset`` passed to ``rouge_metric`` is ``'cnn_dm'`` when
      ``args.tokenizer_type`` is ``"BertWordPieceTokenizer"``, ``'gigaword'``
      for the gigaword task, and ``'cnn_dm_org'`` otherwise.

    Args:
        args: Run configuration; ``task`` is always read, ``tokenizer_type``
            for every task except blank/extraction.
        tokenizer: Handed to every dataset and evaluator that gets built.
        is_test: Forwarded unchanged to ``accuracy_func_provider``.

    Returns:
        Whatever ``accuracy_func_provider`` returns (called with
        ``only_rank0=False``).
    """
    def single_dataset_provider(split):
        # args.task is consulted at call time, not captured up front.
        task_name = args.task.lower()
        if task_name == 'blank':
            dataset_cls = BlankLMDataset
        elif task_name == 'extraction':
            dataset_cls = ExtractionDataset
        else:
            dataset_cls = Seq2SeqDataset
        return dataset_cls(args, split=split, tokenizer=tokenizer)

    task_name = args.task.lower()
    if task_name in ('blank', 'extraction'):
        eval_func = BlankLMEvaluater(args, tokenizer).evaluate
        metric_dict = {}
    else:
        eval_func = DecoderEvaluater(args, tokenizer).evaluate
        if args.tokenizer_type == "BertWordPieceTokenizer":
            rouge_dataset = 'cnn_dm'
        elif task_name == 'gigaword':
            rouge_dataset = 'gigaword'
        else:
            rouge_dataset = 'cnn_dm_org'
        if task_name in ('squad', 'squad_v1'):
            metric_dict = {"EM": squad_exact_match, "F1": squad_f1}
        else:
            # One partial per ROUGE variant, keyed (and ordered) by its name.
            metric_dict = OrderedDict(
                (metric_name,
                 functools.partial(rouge_metric,
                                   metric=metric_name,
                                   dataset=rouge_dataset))
                for metric_name in ("rouge-1", "rouge-2", "rouge-l"))

    def write_lines(path, lines):
        """Write every item of ``lines`` to ``path``, each followed by a newline."""
        with open(path, "w", encoding='utf-8') as handle:
            for line in lines:
                handle.write(line)
                handle.write("\n")

    def write_squad_outputs(predictions, examples, output_file):
        """Dump SQuAD answers as JSON and per-example records as JSON lines."""
        # output_file: JSON object mapping example.idx -> answer. A prediction
        # equal to 'n/a' (ignoring case and spaces) becomes ''; for repeated
        # ids the first non-empty answer is kept.
        with open(output_file, "w", encoding='utf-8') as answers_file:
            answers = {}
            for pred, ex in zip(predictions, examples):
                qid = ex.idx
                answer = '' if pred.lower().replace(' ', '') == 'n/a' else pred
                if answers.get(qid, '') == '':
                    answers[qid] = answer
            json.dump(answers, answers_file)
        # <output_file>.refs: one JSON record per example with the raw
        # (un-normalized) prediction and example.meta['answers'] as gold.
        with open(output_file + ".refs", "w", encoding='utf-8') as refs_file:
            for pred, ex in zip(predictions, examples):
                record = {
                    'id': ex.idx,
                    'pred': pred,
                    'gold': ex.meta['answers']
                }
                refs_file.write(json.dumps(record) + '\n')

    def output_func(predictions, examples, output_file):
        if args.task.lower() in ('squad', 'squad_v1'):
            write_squad_outputs(predictions, examples, output_file)
        else:
            # <output_file>.hyps: one prediction per line;
            # <output_file>.refs: each example's meta["ref"], in example order.
            write_lines(output_file + ".hyps", predictions)
            write_lines(output_file + ".refs",
                        (ex.meta["ref"] for ex in examples))
            if args.task.lower() == 'squad_generation':
                # Also dump the passage (newlines flattened) and its answer.
                write_lines(output_file + ".source",
                            (ex.text_a.replace("\n", " ") + " Answer: " +
                             ex.meta["answer"] for ex in examples))

    return accuracy_func_provider(single_dataset_provider,
                                  metric_dict,
                                  args,
                                  is_test=is_test,
                                  eval_func=eval_func,
                                  output_func=output_func,
                                  only_rank0=False)