Example No. 1
def glue_compute_metrics(task_name, preds, labels):
    assert len(preds) == len(labels)
    if task_name == "cola":
        return {"mcc": matthews_corrcoef(labels, preds)}
    elif task_name == "sst-2":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "sst-2-orig":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "sst-2-glue":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "mrpc":
        return acc_and_f1(preds, labels)
    elif task_name == "sts-b":
        return pearson_and_spearman(preds, labels)
    elif task_name == "qqp":
        return acc_and_f1(preds, labels)
    elif task_name == "snli":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "mnli":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "mnli-mm":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "qnli":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "rte":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "wnli":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "hans":
        return {"acc": simple_accuracy(preds, labels)}
    else:
        raise KeyError(task_name)
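Every example in this section calls a simple_accuracy helper that is never shown. The sketch below is an assumption inferred from the call sites, not code taken from any of the examples; metric functions such as matthews_corrcoef and f1_score come from sklearn.metrics, while acc_and_f1 and pearson_and_spearman are companion helpers like the ones in the later examples.

import numpy as np

def simple_accuracy(preds, labels):
    # Fraction of positions where the predicted label equals the gold label.
    preds, labels = np.asarray(preds), np.asarray(labels)
    return float((preds == labels).mean())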
Example No. 2
def glue_compute_metrics(task_name: str, preds: List,
                         labels: List) -> Dict[str, float]:
    assert len(preds) == len(labels)
    if task_name not in glue_processors.keys():
        raise ValueError(f"Unrecognized {task_name}")

    return {"acc": simple_accuracy(preds, labels)}
Example No. 3
def evaluate(model: TransformerModelWrapper,
             eval_data: List[InputExample],
             config: EvalConfig) -> Dict:

    metrics = config.metrics if config.metrics else ['acc']
    results = model.eval(eval_data=eval_data,
                         per_gpu_eval_batch_size=config.per_gpu_eval_batch_size,
                         n_gpu=config.n_gpu)
    # print("results['logits'].shape=", results['logits'].shape)
    predictions = np.argmax(results['logits'], axis=1)
    scores = {}
    for metric in metrics:
        if metric == 'acc':
            scores[metric] = simple_accuracy(predictions, results['labels'])
        elif metric == 'f1':
            scores[metric] = f1_score(results['labels'], predictions)
        elif metric == 'f1-macro':
            scores[metric] = f1_score(results['labels'], predictions, average='macro')
        elif metric == 'em':
            scores[metric] = exact_match(predictions, results['labels'], results['question_ids'])
        else:
            raise ValueError(f"Metric '{metric}' not implemented")
    results['scores'] = scores
    results['predictions'] = predictions
    return results
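The 'em' branch relies on an exact_match helper that is also not shown. A plausible sketch, assuming each question_id groups several candidate rows and a question only counts as a match when every one of its rows is predicted correctly (that grouping rule is an assumption, not taken from the examples):

from collections import defaultdict

import numpy as np

def exact_match(predictions, labels, question_ids):
    # Group predictions and gold labels by question id.
    preds_per_q, labels_per_q = defaultdict(list), defaultdict(list)
    for qid, pred, label in zip(question_ids, predictions, labels):
        preds_per_q[qid].append(pred)
        labels_per_q[qid].append(label)
    # A question is an exact match only if all of its rows are predicted correctly.
    matches = [preds_per_q[qid] == labels_per_q[qid] for qid in preds_per_q]
    return float(np.mean(matches))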
Example No. 4
def acc_and_micro_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    micro_f1 = f1_score(y_true=labels, y_pred=preds, average='micro')
    return {
        "acc": acc,
        "micro_f1": micro_f1,
    }
Example No. 5
def multiclass_acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    macro_f1 = f1_score(y_true=labels, y_pred=preds, average='macro')
    macro_weighted_f1 = f1_score(y_true=labels,
                                 y_pred=preds,
                                 average='weighted')
    macro_precision = precision_score(y_true=labels,
                                      y_pred=preds,
                                      average='macro')
    macro_weighted_precision = precision_score(y_true=labels,
                                               y_pred=preds,
                                               average='weighted')
    macro_recall = recall_score(y_true=labels, y_pred=preds, average='macro')
    macro_weighted_recall = recall_score(y_true=labels,
                                         y_pred=preds,
                                         average='weighted')
    micro_f1 = f1_score(y_true=labels, y_pred=preds, average='micro')
    return {
        "acc": acc,
        'micro_f1': micro_f1,
        "macro_f1": macro_f1,
        "macro_weighted_f1": macro_weighted_f1,
        "macro_precision": macro_precision,
        "macro_weighted_precision": macro_weighted_precision,
        "macro_recall": macro_recall,
        "macro_weighted_recall": macro_weighted_recall,
    }
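A small usage sketch for the helper above, assuming multiclass_acc_and_f1 is defined as shown, the scikit-learn metric imports are in scope, and simple_accuracy is defined as sketched after Example No. 1:

import numpy as np

preds = np.array([0, 2, 1, 2, 0])
labels = np.array([0, 1, 1, 2, 0])
scores = multiclass_acc_and_f1(preds, labels)
print(scores["acc"], scores["macro_f1"])  # 0.8 and the macro-averaged F1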
Example No. 6
    def eval_dev(self, dev_data, eval_config, n_gpu):
        self.model.eval()
        results = self.eval(
            dev_data,
            per_gpu_eval_batch_size=eval_config.per_gpu_eval_batch_size,
            n_gpu=n_gpu)
        predictions = np.argmax(results['logits'], axis=1)
        scores = {}
        metrics = eval_config.metrics if eval_config.metrics else ['acc']
        for metric in metrics:
            if metric == 'acc':
                scores[metric] = simple_accuracy(predictions,
                                                 results['labels'])
            elif metric == 'f1':
                scores[metric] = f1_score(results['labels'], predictions)
            elif metric == 'f1-macro':
                scores[metric] = f1_score(results['labels'],
                                          predictions,
                                          average='macro')
            elif metric == 'em':
                scores[metric] = exact_match(predictions, results['labels'],
                                             results['question_ids'])
            else:
                raise ValueError(f"Metric '{metric}' not implemented")
        return scores
Example No. 7
    def _acc_and_f1(self, preds, labels):
        acc = simple_accuracy(preds, labels)
        f1 = f1_score(y_true=labels, y_pred=preds, average="weighted")
        return {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
        }
Example No. 8
def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds, average=None)
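    # Note: with average=None, f1_score returns a per-class array, so "acc_and_f1" below is element-wise.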
    return {
        "acc": acc,
        "f1": f1,
        "acc_and_f1": (acc + f1) / 2,
    }
Example No. 9
    def eval(self,
             eval_data: List[InputExample],
             device,
             per_gpu_eval_batch_size: int = 8,
             n_gpu: int = 1,
             output_logits: bool = False,
             **_):

        eval_dataset = self._generate_dataset(eval_data)
        eval_batch_size = per_gpu_eval_batch_size * max(1, n_gpu)
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=eval_batch_size)

        if n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            self.model.eval()
            batch = tuple(t.to(device) for t in batch)
            labels = batch[3]
            mlm_labels = batch[4]

            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2]
                    if self.config.model_type in ['bert', 'xlnet'] else None
                }
                outputs = self.model(**inputs)
                logits = outputs[0]
                if self.config.wrapper_type == MLM_WRAPPER:
                    logits = self.preprocessor.pvp.convert_mlm_logits_to_cls_logits(
                        mlm_labels, logits)
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = labels.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids,
                                          labels.detach().cpu().numpy(),
                                          axis=0)

        if output_logits:
            return preds

        preds = np.argmax(preds, axis=1)
        return {"acc": simple_accuracy(preds, out_label_ids)}
Example No. 10
def evaluate(
    model: TransformerModelWrapper,
    eval_data: List[InputExample],
    config: EvalConfig,
    priming_data: List[InputExample] = None,
    local_rank=-1,
) -> Dict:
    """
    Evaluate a model.

    :param model: the model to evaluate
    :param eval_data: the examples for evaluation
    :param config: the evaluation config
    :param priming_data: an optional list of priming data to use
    :return: a dictionary containing the model's logits, predictions and (if any metrics are given) scores
    """

    if config.priming:
        for example in eval_data:
            example.meta["priming_data"] = priming_data

    metrics = config.metrics if config.metrics else ["acc"]
    device = torch.device(config.device if config.device else
                          "cuda" if torch.cuda.is_available() else "cpu")

    model.model.to(device)
    results = model.eval(
        eval_data,
        device,
        per_gpu_eval_batch_size=config.per_gpu_eval_batch_size,
        n_gpu=config.n_gpu,
        decoding_strategy=config.decoding_strategy,
        priming=config.priming,
        local_rank=local_rank,
    )

    predictions = np.argmax(results["logits"], axis=1)
    scores = {}

    for metric in metrics:
        if metric == "acc":
            scores[metric] = simple_accuracy(predictions, results["labels"])
        elif metric == "f1":
            scores[metric] = f1_score(results["labels"], predictions)
        elif metric == "f1-macro":
            scores[metric] = f1_score(results["labels"],
                                      predictions,
                                      average="macro")
        elif metric == "em":
            scores[metric] = exact_match(predictions, results["labels"],
                                         results["question_ids"])
        else:
            raise ValueError(f"Metric '{metric}' not implemented")

    results["scores"] = scores
    results["predictions"] = predictions
    return results
Example No. 11
def compute_metrics(task_name, preds, labels, sample_ids=None, data_dir=None):
    assert len(preds) == len(labels), (
        f"Predictions and labels have mismatched lengths "
        f"{len(preds)} and {len(labels)}")
    if task_name == "qqp":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "entailment":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "sentiment":
        sentiment_acc, sentiment_macro_f1, aspect_macro_f1, aspect_strict_acc = absa_evaluation(
            data_dir, sample_ids, preds)
        return {
            "SA_acc": sentiment_acc,
            "SA_macro_f1": sentiment_macro_f1,
            "absa_macro_f1": aspect_macro_f1,
            "absa_strict_acc": aspect_strict_acc
        }
    else:
        raise KeyError(task_name)
Example No. 12
def arf(preds, labels):
    p, r, f, _ = precision_recall_fscore_support(labels, preds,
                                                 beta=1,
                                                 average='binary',
                                                 warn_for=('f-score',))
    return {
        "acc": simple_accuracy(preds, labels),
        "p": p,
        "recall": r,
        "f1": f
    }
Example No. 13
    def in_training_eval(self, eval_kwargs):
        eval_results = self.eval(**eval_kwargs)
        predictions = np.argmax(eval_results["logits"], axis=1)

        if eval_kwargs["metrics"]:
            if "f1" in eval_kwargs["metrics"]:
                score = f1_score(eval_results["labels"], predictions)
            elif "f1-macro" in eval_kwargs["metrics"]:
                score = f1_score(eval_results["labels"],
                                 predictions,
                                 average="macro")
            elif "em" in eval_kwargs["metrics"]:
                score = exact_match(predictions, eval_results["labels"],
                                    eval_results["question_ids"])
            else:
                score = simple_accuracy(predictions, eval_results["labels"])
        else:
            score = simple_accuracy(predictions, eval_results["labels"])

        return score
Example No. 14
def sst3_compute_metrics(task_name, preds, labels):
    """Compute metrics for the SST-3 task."""
    assert len(preds) == len(labels)
    if task_name == "sst-3":
        return {
            "acc": simple_accuracy(preds, labels),
            "pred": preds,
            "actual": labels
        }
    else:
        raise KeyError(task_name)
Example No. 15
def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    recall = recall_score(y_true=labels, y_pred=preds, average=None)
    precision = precision_score(y_true=labels, y_pred=preds, average=None)
    f1 = f1_score(y_true=labels, y_pred=preds, average=None)

    return {
        "acc": fix_np_types(acc),
        "f1": fix_np_types(f1),
        "acc_and_f1": fix_np_types((acc + f1) / 2),
        "recall": fix_np_types(recall),
        "precision": fix_np_types(precision)
    }
Example No. 16
def acc_p_r_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1 = f1_score(
        y_true=labels,
        y_pred=preds,
    )
    recall = recall_score(
        y_true=labels,
        y_pred=preds,
    )
    precision = precision_score(
        y_true=labels,
        y_pred=preds,
    )

    return {"acc": acc, "f1": f1, 'precision': precision, 'recall': recall}
Example No. 17
    def _eval_end(self, outputs) -> tuple:
        val_loss_mean = torch.stack(
            [x["val_loss"] for x in outputs]).mean().detach().cpu().item()
        preds = np.concatenate([x["pred"] for x in outputs], axis=0)

        preds = np.argmax(preds, axis=1)

        out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)
        out_label_list = [[] for _ in range(out_label_ids.shape[0])]
        preds_list = [[] for _ in range(out_label_ids.shape[0])]

        results = {
            "val_loss": val_loss_mean,
            "acc": simple_accuracy(preds, out_label_ids),
        }

        ret = {k: v for k, v in results.items()}
        ret["log"] = results
        return ret, preds_list, out_label_list
Example No. 18
def compute_metrics(preds, labels):
    return simple_accuracy(preds, labels)
Example No. 19
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_models", type=int, help="Number of models")
    parser.add_argument("--k",
                        type=int,
                        default=16,
                        help="Number of training instances per label")
    parser.add_argument(
        "--condition",
        type=str,
        help="A dictionary containing conditions that the experiment results "
             "need to fulfill (e.g., tag, task_name, few_shot_type)")

    # These options should usually be kept as their default values
    parser.add_argument("--data_dir",
                        type=str,
                        default="data/k-shot",
                        help="Data directory")
    parser.add_argument("--save_logit_dir",
                        type=str,
                        default="ensemble_predict_results",
                        help="Directory to store the logit file.")
    parser.add_argument("--log", type=str, default="log", help="Log path.")
    parser.add_argument("--key",
                        type=str,
                        default='',
                        help="Validation metric name")
    parser.add_argument("--test_key",
                        type=str,
                        default="",
                        help="Test metric name")
    parser.add_argument("--test_key2",
                        type=str,
                        default="",
                        help="Second test metric name")

    args = parser.parse_args()

    condition = eval(args.condition)

    if len(args.key) == 0:
        if condition['task_name'] == 'cola':
            args.key = 'cola_dev_eval_mcc'
            args.test_key = 'cola_test_eval_mcc'
        elif condition['task_name'] == 'mrpc/acc':
            args.key = 'mrpc_dev_eval_acc'
            args.test_key = 'mrpc_test_eval_acc'
            args.test_key2 = 'mrpc_test_eval_f1'
            condition['task_name'] = 'mrpc'
        elif condition['task_name'] == 'mrpc/f1':
            args.key = 'mrpc_dev_eval_f1'
            args.test_key2 = 'mrpc_test_eval_acc'
            args.test_key = 'mrpc_test_eval_f1'
            condition['task_name'] = 'mrpc'
        elif condition['task_name'] == 'qqp/acc':
            args.key = 'qqp_dev_eval_acc'
            args.test_key = 'qqp_test_eval_acc'
            args.test_key2 = 'qqp_test_eval_f1'
            condition['task_name'] = 'qqp'
        elif condition['task_name'] == 'qqp/f1':
            args.key = 'qqp_dev_eval_f1'
            args.test_key2 = 'qqp_test_eval_acc'
            args.test_key = 'qqp_test_eval_f1'
            condition['task_name'] = 'qqp'
        elif condition['task_name'] == 'sts-b/pearson':
            args.key = 'sts-b_dev_eval_pearson'
            args.test_key = 'sts-b_test_eval_pearson'
            args.test_key2 = 'sts-b_test_eval_spearmanr'
            condition['task_name'] = 'sts-b'
        elif condition['task_name'] == 'sts-b/spearmanr':
            args.key = 'sts-b_dev_eval_spearmanr'
            args.test_key2 = 'sts-b_test_eval_pearson'
            args.test_key = 'sts-b_test_eval_spearmanr'
            condition['task_name'] = 'sts-b'
        elif condition['task_name'] == 'qnli':
            args.key = 'qnli_dev_eval_acc'
            args.test_key = 'qnli_test_eval_acc'
        elif condition['task_name'] == 'sst-2':
            args.key = 'sst-2_dev_eval_acc'
            args.test_key = 'sst-2_test_eval_acc'
        elif condition['task_name'] == 'snli':
            args.key = 'snli_dev_eval_acc'
            args.test_key = 'snli_test_eval_acc'
        elif condition['task_name'] == 'mnli':
            args.key = 'mnli_dev_eval_mnli/acc'
            args.test_key = 'mnli_test_eval_mnli/acc'
        elif condition['task_name'] == 'mnli-mm':
            args.key = 'mnli_dev_eval_mnli/acc'
            args.test_key = 'mnli-mm_test_eval_mnli-mm/acc'
        elif condition['task_name'] == 'rte':
            args.key = 'rte_dev_eval_acc'
            args.test_key = 'rte_test_eval_acc'
        elif condition['task_name'] == 'ag_news':
            args.key = 'ag_news_dev_eval_acc'
            args.test_key = 'ag_news_test_eval_acc'
        elif condition['task_name'] == 'yahoo_answers':
            args.key = 'yahoo_answers_dev_eval_acc'
            args.test_key = 'yahoo_answers_test_eval_acc'
        elif condition['task_name'] == 'yelp_review_full':
            args.key = 'yelp_review_full_dev_eval_acc'
            args.test_key = 'yelp_review_full_test_eval_acc'
        elif condition['task_name'] == 'mr':
            args.key = 'mr_dev_eval_acc'
            args.test_key = 'mr_test_eval_acc'
        elif condition['task_name'] == 'sst-5':
            args.key = 'sst-5_dev_eval_acc'
            args.test_key = 'sst-5_test_eval_acc'
        elif condition['task_name'] == 'subj':
            args.key = 'subj_dev_eval_acc'
            args.test_key = 'subj_test_eval_acc'
        elif condition['task_name'] == 'trec':
            args.key = 'trec_dev_eval_acc'
            args.test_key = 'trec_test_eval_acc'
        elif condition['task_name'] == 'cr':
            args.key = 'cr_dev_eval_acc'
            args.test_key = 'cr_test_eval_acc'
        elif condition['task_name'] == 'mpqa':
            args.key = 'mpqa_dev_eval_acc'
            args.test_key = 'mpqa_test_eval_acc'
        else:
            raise NotImplementedError

    with open(args.log) as f:
        result_list = []
        for line in f:
            result_list.append(eval(line))

    seed_result = {}
    seed_best = {}

    # Gather all logs satisfying the conditions
    for item in result_list:
        ok = True
        for cond in condition:
            if cond == 'task_name' and condition['task_name'] == 'mnli-mm':
                if cond not in item or item[cond] != 'mnli':
                    ok = False
                    break
            else:
                if cond not in item or item[cond] != condition[cond]:
                    ok = False
                    break
        if 'model_id' not in item or 'array_id' not in item:
            ok = False

        if ok:
            seed = int(item['data_dir'].split('-')[-1])
            model_id = item['model_id']
            array_id = item['array_id']

            if model_id >= 0 and model_id < args.n_models:
                if seed not in seed_result:
                    seed_result[seed] = {}
                    seed_best[seed] = {}
                if model_id not in seed_result[seed]:
                    seed_result[seed][model_id] = []
                    seed_best[seed][model_id] = {args.key: -1e9}

                seed_result[seed][model_id].append(item)
                if item[args.key] > seed_best[seed][model_id][args.key]:
                    seed_best[seed][model_id] = item

    final_result_dev = np.zeros((len(seed_result), args.n_models))
    final_result_test = np.zeros((len(seed_result), args.n_models))
    final_result_test2 = np.zeros((len(seed_result), args.n_models))

    logit_file_list = {}
    for seed in seed_result:
        logit_file_list[seed] = []

    # Get the results for each model and pick the best dev trial for each model/seed
    for model_id in range(args.n_models):
        for i, seed in enumerate(seed_result):
            final_result_dev[i][model_id] = seed_best[seed][model_id][args.key]
            final_result_test[i][model_id] = seed_best[seed][model_id][
                args.test_key]
            if len(args.test_key2) > 0:
                final_result_test2[i][model_id] = seed_best[seed][model_id][
                    args.test_key2]

            logit_file_list[seed].append("{}-{}-{}.npy".format(
                condition['task_name'], model_id,
                seed_best[seed][model_id]["array_id"]))

        s = "Model %d | val: mean +- std: %.1f +- %.1f | test: mean +- std: %.1f (%.1f) (median %.1f)" % (
            model_id, final_result_dev[:, model_id].mean() * 100,
            final_result_dev[:, model_id].std() * 100,
            final_result_test[:, model_id].mean() * 100,
            final_result_test[:, model_id].std() * 100,
            np.median(final_result_test[:, model_id]) * 100)
        if len(args.test_key2) > 0:
            s += " / %.1f +- %.1f (median %.1f)" % (
                final_result_test2[:, model_id].mean() * 100,
                final_result_test2[:, model_id].std() * 100,
                np.median(final_result_test2[:, model_id]) * 100)
        print(s)

    # Map lower-case names to official names (data folder name)
    data_dir_mapping = {
        'cola': 'CoLA',
        'mrpc': 'MRPC',
        'qqp': 'QQP',
        'sts-b': 'STS-B',
        'sst-2': 'SST-2',
        'snli': 'SNLI',
        'mnli': 'MNLI',
        'mnli-mm': 'MNLI',
        'rte': 'RTE',
        'ag_news': 'ag_news',
        'yahoo_answers': 'yahoo_answers',
        'yelp_review_full': 'yelp_review_full',
        'sst-5': 'sst-5',
        'mr': 'mr',
        'cr': 'cr',
        'mpqa': 'mpqa',
        'subj': 'subj',
        'trec': 'trec'
    }

    tokenizer = AutoTokenizer.from_pretrained('roberta-large')
    ensemble_result = np.zeros((len(seed_result)))
    ensemble_result2 = np.zeros((len(seed_result)))  # for second metric

    # Ensemble for each seed
    for seed_id, seed in enumerate(seed_result):
        labels = get_labels(args.data_dir, args.k, seed,
                            condition['task_name'],
                            data_dir_mapping[condition['task_name']])

        # Logits
        mean_logits = None
        for fname in logit_file_list[seed]:
            logits = np.load(os.path.join(args.save_logit_dir, fname))
            if mean_logits is None:
                mean_logits = logits
            else:
                mean_logits += logits
        mean_logits /= len(logit_file_list[seed])

        # Compute metrics
        preds = mean_logits.argmax(-1)
        if condition['task_name'] in [
                'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec'
        ]:
            metric = {"acc": simple_accuracy(preds, labels)}
        else:
            metric = glue_compute_metrics(condition['task_name'], preds,
                                          labels)

        ensemble_result[seed_id] = metric[args.test_key.split('_')[-1]]
        if len(args.test_key2) > 0:
            ensemble_result2[seed_id] = metric[args.test_key2.split('_')[-1]]

    s = "mean +- std: %.1f (%.1f) (median %.1f)" % (
        ensemble_result.mean() * 100, ensemble_result.std() * 100,
        np.median(ensemble_result) * 100)
    if len(args.test_key2) > 0:
        s += " / %.1f (%.1f) (median %.1f)" % (
            ensemble_result2.mean() * 100, ensemble_result2.std() * 100,
            np.median(ensemble_result2) * 100)
    print(s)
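A hypothetical invocation of this ensembling script (the file name ensemble.py and the condition values are placeholders, not taken from the example; note that --condition is parsed with eval(), so it is passed as a Python dict literal):

python ensemble.py \
    --n_models 3 \
    --k 16 \
    --condition "{'task_name': 'sst-2', 'few_shot_type': 'prompt'}" \
    --log log \
    --save_logit_dir ensemble_predict_results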
Example No. 20
def evaluate(model: TransformerModelWrapper,
             eval_data: List[InputExample],
             config: EvalConfig,
             priming_data: List[InputExample] = None) -> Dict:
    """
    Evaluate a model.

    :param model: the model to evaluate
    :param eval_data: the examples for evaluation
    :param config: the evaluation config
    :param priming_data: an optional list of priming data to use
    :return: a dictionary containing the model's logits, predictions and (if any metrics are given) scores
    """

    if config.priming:
        for example in eval_data:
            example.meta['priming_data'] = priming_data

    metrics = config.metrics if config.metrics else ['acc']
    device = torch.device(config.device if config.device else
                          "cuda" if torch.cuda.is_available() else "cpu")

    model.model.to(device)
    results = model.eval(
        eval_data,
        device,
        per_gpu_eval_batch_size=config.per_gpu_eval_batch_size,
        n_gpu=config.n_gpu,
        decoding_strategy=config.decoding_strategy,
        priming=config.priming)

    predictions = np.argmax(results['logits'], axis=1)
    scores = {}

    for metric in metrics:
        if metric == 'acc':
            scores[metric] = simple_accuracy(predictions, results['labels'])
        elif metric == 'f1':
            scores[metric] = f1_score(results['labels'], predictions)
        elif metric == 'f1-macro':
            scores[metric] = f1_score(results['labels'],
                                      predictions,
                                      average='macro')
        elif metric == 'em':
            scores[metric] = exact_match(predictions, results['labels'],
                                         results['question_ids'])
        elif metric == 'dist-loss':
            if eval_data[0].logits is not None:
                scores[metric] = distillation_loss(
                    torch.tensor(results['logits']),
                    torch.stack([
                        torch.tensor(ex.logits, dtype=torch.float32)
                        for ex in eval_data
                    ]), config.temperature)
            else:
                scores[metric] = 0.
        else:
            raise ValueError(f"Metric '{metric}' not implemented")

    results['scores'] = scores
    results['predictions'] = predictions
    return results
Example No. 21
    def eval(self,
             eval_data: List[InputExample],
             device,
             per_gpu_eval_batch_size: int = 8,
             n_gpu: int = 1,
             output_logits: bool = False,
             **_):
        eval_dataset = self._generate_dataset(eval_data)
        eval_batch_size = per_gpu_eval_batch_size * max(1, n_gpu)
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=eval_batch_size)

        if n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        ### NEW ###
        losses = []
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        ### NEW ###

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            self.model.eval()
            batch = tuple(t.to(device) for t in batch)
            labels = batch[3]
            mlm_labels = batch[4]

            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'token_type_ids': batch[2]
                    if self.config.model_type in ['bert', 'xlnet'] else None
                }
                ### NEW ###
                if self.config.model_type == 'xlnet':
                    outputs = self.xlnet_forward(inputs, mlm_labels, device)
                else:
                    outputs = self.model(**inputs)
                ### NEW ###
                logits = outputs[0]
                ### NEW ###
                # logger.info(self.tokenizer.decode(inputs['input_ids'][0]))
                # for j in range(inputs['input_ids'].shape[0]):
                #     logger.info(self.tokenizer.decode(inputs['input_ids'][j]))
                # assert False
                ### NEW ###
                if self.config.wrapper_type == MLM_WRAPPER:
                    logits = self.preprocessor.pvp.convert_mlm_logits_to_cls_logits(
                        mlm_labels, logits)
                    ### NEW ###
                    loss = loss_fct(
                        logits.view(-1, len(self.config.label_list)),
                        labels.view(-1))
                    losses.extend(loss.tolist())
                    ### NEW ###
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = labels.detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids,
                                          labels.detach().cpu().numpy(),
                                          axis=0)

        if output_logits:
            return preds

        preds = np.argmax(preds, axis=1)
        ### NEW ###
        # logger.info(losses)
        # return {"acc": simple_accuracy(preds, out_label_ids),
        #         "loss": np.mean(losses)}
        return {
            "acc": simple_accuracy(preds, out_label_ids),
            "agreement": preds == out_label_ids
        }