Exemplo n.º 1
0
def get_evaluator(args, model, loss_fn):
    """Build an ignite ``Engine`` that runs inference over batches.

    Each step predicts an answer index per question and records it in
    ``engine.answers`` keyed by the batch's ``qid`` values.

    Args:
        args: runtime configuration forwarded to ``prepare_batch``.
        model: the network to evaluate; must expose ``.vocab``.
        loss_fn: unused here; kept for a uniform factory signature.

    Returns:
        The configured ``Engine`` with an empty ``answers`` dict attached.
    """
    def _eval_step(evaluator, batch):
        # Inference only: disable dropout/batch-norm updates and autograd.
        model.eval()
        with torch.no_grad():
            question_ids = batch["qid"]
            inputs, _ = prepare_batch(args, batch, model.vocab)
            logits = model(**inputs)
            predictions = logits.argmax(dim=-1)  # + 1  # 0~4 -> 1~5
            for question_id, prediction in zip(question_ids, predictions):
                engine.answers[question_id] = prediction.item()

    engine = Engine(_eval_step)
    engine.answers = {}

    return engine
Exemplo n.º 2
0
def get_evaluator(args, model, loss_fn):
    """Build an ignite ``Engine`` that runs inference and collects answers.

    Each step predicts an answer index per question and records it in
    ``engine.answers`` keyed by the batch's ``qid`` values.

    Args:
        args: runtime configuration forwarded to ``prepare_batch``.
        model: the network to evaluate; assumed to be wrapped (e.g. in
            ``DataParallel``) since the vocab is read from ``model.module``.
        loss_fn: unused here; kept for a uniform factory signature.

    Returns:
        The configured ``Engine`` with an empty ``answers`` dict attached.
    """
    def _inference(evaluator, batch):
        model.eval()
        with torch.no_grad():
            qids = batch["qid"]
            net_inputs, _ = prepare_batch(args, batch, model.module.vocab)
            # The model returns (logits, char_pred, mask_pred); only the
            # logits are needed to pick an answer.
            y_pred, _char_pred, _mask_pred = model(**net_inputs)
            # Leftover debug prints of y_pred.size() removed: they wrote to
            # stdout on every batch during evaluation.
            y_pred = y_pred.argmax(dim=-1)  # + 1  # 0~4 -> 1~5
            for qid, ans in zip(qids, y_pred):
                engine.answers[qid] = ans.item()

    engine = Engine(_inference)
    engine.answers = {}

    return engine
def get_evaluator(args, model, loss_fn):
    """Build an ignite ``Engine`` that runs inference and tracks accuracy.

    Each step predicts an answer index per question, records it in
    ``engine.answers`` keyed by ``new_qid``, and increments ``engine.count``
    whenever the prediction matches the batch's ``correct_idx``.

    Args:
        args: runtime configuration forwarded to ``prepare_batch``.
        model: the network to evaluate; must expose ``.vocab``.
        loss_fn: unused here; kept for a uniform factory signature.

    Returns:
        The configured ``Engine`` with ``answers`` ({}) and ``count`` (0)
        attributes attached.
    """
    def _eval_step(evaluator, batch):
        # Inference only: disable dropout/batch-norm updates and autograd.
        model.eval()
        with torch.no_grad():
            question_ids = batch["new_qid"]
            gold_answers = batch["correct_idx"]
            inputs, _ = prepare_batch(args, batch, model.vocab)

            logits = model(**inputs)
            predictions = logits.argmax(dim=-1)  # + 1  # 0~4 -> 1~5

            for question_id, prediction, gold in zip(
                question_ids, predictions, gold_answers
            ):
                chosen = prediction.item()
                engine.answers[question_id] = chosen
                if chosen == gold:
                    engine.count += 1

    engine = Engine(_eval_step)
    engine.answers = {}
    engine.count = 0
    return engine