import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertForMultipleChoice
from transformers.testing_utils import torch_device


def create_and_check_bert_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask,
        sequence_labels, token_labels, choice_labels):
    config.num_choices = self.num_choices
    model = BertForMultipleChoice(config=config)
    model.to(torch_device)
    model.eval()
    # Replicate each (batch, seq_len) tensor along a new choice axis,
    # giving the (batch, num_choices, seq_len) shape the model expects.
    multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(
        -1, self.num_choices, -1).contiguous()
    multiple_choice_token_type_ids = token_type_ids.unsqueeze(
        1).expand(-1, self.num_choices, -1).contiguous()
    multiple_choice_input_mask = input_mask.unsqueeze(1).expand(
        -1, self.num_choices, -1).contiguous()
    # Older transformers versions return a plain tuple (loss, logits),
    # so the outputs are unpacked positionally here.
    loss, logits = model(
        multiple_choice_inputs_ids,
        attention_mask=multiple_choice_input_mask,
        token_type_ids=multiple_choice_token_type_ids,
        labels=choice_labels,
    )
    result = {
        "loss": loss,
        "logits": logits,
    }
    self.parent.assertListEqual(list(result["logits"].size()),
                                [self.batch_size, self.num_choices])
    self.check_loss_output(result)
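

# calc_accuracy is called by evaluate() below but is not defined in this
# snippet; a minimal sketch of the assumed helper (the original may differ):
def calc_accuracy(pred_labels: np.ndarray, correct_labels: np.ndarray) -> float:
    """Fraction of predictions that match the gold labels."""
    return float((pred_labels == correct_labels).mean())
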
def evaluate(classifier_model: BertForMultipleChoice, dataloader: DataLoader,
             device: torch.device):
    """
    モデルの評価を行う。
    結果やラベルはDict形式で返される。
    """
    classifier_model.eval()

    count_steps = 0
    total_loss = 0

    all_logits = []
    all_labels = []
    for batch in tqdm(dataloader):
        with torch.no_grad():
            # Move the whole batch to the target device once, then unpack.
            batch = tuple(t.to(device) for t in batch)
            bert_inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "labels": batch[3],
            }

            classifier_outputs = classifier_model(**bert_inputs)
            loss, logits = classifier_outputs[:2]

            count_steps += 1
            total_loss += loss.item()

            # Collect per-batch arrays and concatenate once after the loop;
            # repeated np.append copies the whole growing array every step.
            all_logits.append(logits.detach().cpu().numpy())
            all_labels.append(bert_inputs["labels"].detach().cpu().numpy())

    preds = np.concatenate(all_logits, axis=0)
    correct_labels = np.concatenate(all_labels, axis=0)

    pred_labels = np.argmax(preds, axis=1)
    accuracy = calc_accuracy(pred_labels, correct_labels)
    eval_loss = total_loss / count_steps

    ret = {
        "pred_labels": pred_labels,
        "correct_labels": correct_labels,
        "logits": preds,
        "accuracy": accuracy,
        "eval_loss": eval_loss
    }

    return ret
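

# A minimal usage sketch for evaluate(); the tensor names and batch size here
# are illustrative assumptions, not taken from the original code. Each input
# tensor is expected to be (num_examples, num_choices, seq_len) and the labels
# tensor (num_examples,).
def run_evaluation_example(model, all_input_ids, all_attention_mask,
                           all_token_type_ids, all_labels):
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(all_input_ids, all_attention_mask,
                            all_token_type_ids, all_labels)
    dataloader = DataLoader(dataset, batch_size=8)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    results = evaluate(model, dataloader, device)
    print(f"accuracy={results['accuracy']:.4f}  eval_loss={results['eval_loss']:.4f}")
    return results
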
def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask,
        sequence_labels, token_labels, choice_labels):
    config.num_choices = self.num_choices
    model = BertForMultipleChoice(config=config)
    model.to(torch_device)
    model.eval()
    multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
    multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
    multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
    # Newer transformers versions return a ModelOutput, so fields are
    # accessed by attribute (result.logits) instead of tuple unpacking.
    result = model(
        multiple_choice_inputs_ids,
        attention_mask=multiple_choice_input_mask,
        token_type_ids=multiple_choice_token_type_ids,
        labels=choice_labels,
    )
    self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
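

# Standalone illustration of the unsqueeze/expand pattern used in both test
# methods above: a (batch, seq_len) tensor is replicated along a new choice
# axis into (batch, num_choices, seq_len). Shapes and values here are
# arbitrary, chosen only for the demonstration.
def _demo_expand_to_choices(num_choices: int = 4) -> None:
    x = torch.randint(0, 100, (2, 7))  # (batch=2, seq_len=7)
    y = x.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
    assert y.shape == (2, num_choices, 7)
    # Every choice slot holds a copy of the original sequence; .contiguous()
    # materializes the expanded view for ops that require contiguous memory.
    assert bool((y[:, 0, :] == x).all())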