Example 1
    # Requires math, torch, collections.defaultdict, typing.Dict, and the
    # repo-local helpers trim_input_ids() and chunks() to be in scope.
    def eval_step(self,
                  batch: Dict[str, torch.Tensor],
                  batch_size: int = 8,
                  decoding_strategy: str = 'default'):
        # sequence classifiers need no special multi-mask handling here
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for MultiMaskTaskHelper is only implemented for MLM models'
        assert batch['input_ids'].shape[0] == 1, "eval_step() for MultiMaskTaskHelper is only implemented for batch_size=1"

        all_choice_token_ids = batch['choice_token_ids'][0]

        # one log-probability slot per choice, initialized to -inf (log of 0)
        log_probabilities = torch.tensor(
            [[-math.inf] * len(all_choice_token_ids)],
            dtype=torch.float,
            device=all_choice_token_ids.device)

        # group choices by length to speed up decoding: -100 entries are
        # padding, so the number of tokens != -100 equals the number of mask
        # positions a choice requires
        choices_grouped_by_length = defaultdict(list)

        for idx, choice_token_ids in enumerate(all_choice_token_ids):
            num_masks = sum(1 for x in choice_token_ids if x != -100)
            choices_grouped_by_length[num_masks].append(
                (idx, choice_token_ids))

        input_ids = {}
        initial_outputs = {}

        for num_masks in choices_grouped_by_length:
            # modify the input ids to contain the correct number of masks
            input_ids[num_masks] = trim_input_ids(
                batch['input_ids'],
                num_masks=num_masks,
                pad_token_id=self.wrapper.tokenizer.pad_token_id,
                mask_token_id=self.wrapper.tokenizer.mask_token_id)

            # cache one initial forward pass per mask count; it is reused for
            # every choice of that length below
            initial_outputs[num_masks] = self.wrapper.model(
                input_ids[num_masks])

        for num_masks, choices_with_labels in choices_grouped_by_length.items():

            # score the choices of this length in chunks of up to batch_size;
            # 'chunk' avoids shadowing the outer 'batch' argument
            for chunk in chunks(choices_with_labels, batch_size):
                batch_input_ids = input_ids[num_masks].repeat(len(chunk), 1)
                choice_token_ids = torch.stack(
                    [token_ids for _, token_ids in chunk])

                batch_probabilities = self._get_choice_probabilities_batched(
                    choice_token_ids,
                    batch_input_ids,
                    initial_outputs[num_masks],
                    decoding_strategy=decoding_strategy)

                # write each choice's score back to its original position
                for chunk_idx, (idx, _) in enumerate(chunk):
                    log_probabilities[0][idx] = batch_probabilities[chunk_idx]

        return log_probabilities
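
Both examples lean on two small helpers, chunks and trim_input_ids, whose definitions are not shown on this page. The sketch below is a minimal, assumed reconstruction of their behavior (splitting a list into fixed-size pieces, and rewriting the single input sequence so it contains exactly num_masks mask tokens with padding stripped), inferred from how they are called above; it is not the library's actual implementation.

    from typing import List
    import torch

    def chunks(items: List, n: int):
        """Yield successive sub-lists of at most n items."""
        for i in range(0, len(items), n):
            yield items[i:i + n]

    def trim_input_ids(input_ids: torch.Tensor, num_masks: int,
                       pad_token_id: int, mask_token_id: int) -> torch.Tensor:
        """Keep only the first num_masks mask tokens and drop all padding.

        Assumed behavior, sketched from the call sites above.
        """
        assert input_ids.shape[0] == 1
        kept, masks_seen = [], 0
        for token_id in input_ids[0].tolist():
            if token_id == mask_token_id:
                masks_seen += 1
                if masks_seen > num_masks:
                    continue
            elif token_id == pad_token_id:
                continue
            kept.append(token_id)
        return torch.tensor([kept], dtype=input_ids.dtype,
                            device=input_ids.device)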
Example 2
    # Dependencies as in Example 1, plus self.original_choices and self.output.
    def eval_step(self, batch: Dict[str, torch.Tensor], batch_size: int = 8, decoding_strategy: str = 'default'):
        assert self.wrapper.config.wrapper_type == 'mlm', 'eval_step() for ReCoRD is only implemented for MLM models'
        assert batch['input_ids'].shape[0] == 1, "eval_step() for ReCoRD is only implemented for batch_size=1"

        # track only whether the highest-probability choice is the correct one
        best_choice_correct, max_prob = False, None
        question_idx = batch['question_idx'][0].item()
        output_line = {'idx': question_idx, 'choices': {}}

        # group choices by length to speed up decoding; candidates with a
        # negative label are padding and are skipped
        choices_grouped_by_length = defaultdict(list)

        for idx, (choice_ids, label) in enumerate(zip(batch['candidate_token_ids'][0], batch['candidate_labels'][0])):
            if label < 0:
                continue
            num_masks = sum(1 for x in choice_ids if x != -100)
            choice = self.original_choices[question_idx][idx]
            choices_grouped_by_length[num_masks].append((choice, choice_ids, label))

        input_ids = {}
        initial_outputs = {}

        for num_masks in choices_grouped_by_length:
            # modify the input ids to contain the correct number of masks
            input_ids[num_masks] = trim_input_ids(batch['input_ids'], num_masks=num_masks,
                                                  pad_token_id=self.wrapper.tokenizer.pad_token_id,
                                                  mask_token_id=self.wrapper.tokenizer.mask_token_id)

            # one cached forward pass per mask count, reused below
            initial_outputs[num_masks] = self.wrapper.model(input_ids[num_masks])

        for num_masks, choices_with_labels in choices_grouped_by_length.items():

            # 'chunk' avoids shadowing the outer 'batch' argument
            for chunk in chunks(choices_with_labels, batch_size):
                batch_input_ids = input_ids[num_masks].repeat(len(chunk), 1)
                choice_ids = torch.stack([token_ids for _, token_ids, _ in chunk])

                probs = self._get_choice_probabilities_batched(choice_ids, batch_input_ids, initial_outputs[num_masks],
                                                               decoding_strategy=decoding_strategy)

                for chunk_idx, (choice, _, label) in enumerate(chunk):
                    prob = probs[chunk_idx]
                    output_line['choices'][choice] = prob

                    # remember whether the most probable choice so far is correct
                    if max_prob is None or prob > max_prob:
                        best_choice_correct, max_prob = (label == 1), prob

        self.output.append(output_line)

        # return a pseudo-logit pair so that downstream accuracy computation
        # counts this question as correct iff the best choice was the answer
        if best_choice_correct:
            return torch.tensor([[0, 1]])
        return torch.tensor([[1, 0]])
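
Both helpers get their speed-up from the same two ideas: choices that need the same number of mask tokens share one trimmed input, so the model's expensive initial forward pass is computed once per length and reused, and the choices of each length are then scored batch_size at a time rather than one by one. A toy illustration of the grouping step follows (tensor values invented for the example):

    from collections import defaultdict
    import torch

    all_choice_token_ids = torch.tensor([
        [42, -100, -100],  # 1 real token  -> needs 1 mask
        [17, 23, -100],    # 2 real tokens -> needs 2 masks
        [8, 15, 16],       # 3 real tokens -> needs 3 masks
        [99, 7, -100],     # 2 real tokens -> needs 2 masks
    ])

    choices_grouped_by_length = defaultdict(list)
    for idx, choice_token_ids in enumerate(all_choice_token_ids):
        num_masks = sum(1 for x in choice_token_ids if x != -100)
        choices_grouped_by_length[num_masks].append((idx, choice_token_ids))

    # -> {1: [choice 0], 2: [choices 1 and 3], 3: [choice 2]}: one initial
    #    forward pass is cached per key instead of one per choice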