Example #1
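What follows appears to be one step of GPT-2 beam-search decoding written as a single `forward` method, with the beam state (selected beam indices, running log-probabilities, unfinished flags, token histories, and scores) threaded through as explicit tensor inputs and outputs, a style typical of models prepared for tracing or ONNX export.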
    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        beam_select_idx,
        input_log_probs,
        input_unfinished_sents,
        prev_step_results,
        prev_step_scores,
        *past,
    ):
        input_ids = input_ids.view(self.config.batch_size, -1,
                                   input_ids.size(-1))
        # reorder each layer's cached key/values to follow the selected beams
        past = [
            p.index_select(1, beam_select_idx[0]) for p in past
        ]
        result = super().forward(
            input_ids.view(-1, input_ids.size(-1)),
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past,
            return_dict=False,
        )
        logits_flat, present_flat = MyGPT2Model.post_process(
            result, self.config.n_layer)
        # logits at the last position, regrouped as (batch, beams, vocab)
        next_token_logits = logits_flat[:, -1].view(self.config.batch_size, -1,
                                                    logits_flat.size(-1))
        next_token_log_probs = torch.log_softmax(next_token_logits, dim=-1)
        next_token_log_probs, next_token_ids = torch.topk(
            next_token_log_probs,
            self.config.beam_size,
            dim=-1,
            largest=True,
            sorted=True)

        # A finished sentence keeps emitting EOS; every beam except the first
        # gets -inf so the duplicate copies are dropped in the next round of
        # beam search, while the surviving copy keeps its score unchanged.
        finished_sents = ~input_unfinished_sents
        next_token_log_probs.masked_fill_(finished_sents.unsqueeze(-1),
                                          float("-inf"))
        next_token_log_probs[..., 0].masked_fill_(finished_sents, 0)
        next_token_ids.masked_fill_(finished_sents.unsqueeze(-1),
                                    self.config.eos_token_id)
        output_log_probs = input_log_probs.unsqueeze(-1) + next_token_log_probs

        # select beam_size sequences per input from the beam_size**2
        # candidates, sorted by sequence log-probability
        output_log_probs = output_log_probs.view(
            self.config.batch_size, -1)  # shape=(batch, beam_size^2)
        output_log_probs, selected_index_flat = output_log_probs.topk(
            self.config.beam_size, dim=-1, largest=True,
            sorted=True)  # output shape=(batch, beam_size)

        # map each flat candidate index back to its source beam and next token
        selected_input_seq = selected_index_flat // self.config.beam_size
        next_token_ids = next_token_ids.view(self.config.batch_size,
                                             -1).gather(
                                                 -1, selected_index_flat)

        prev_step_results = prev_step_results.view(self.config.batch_size, -1,
                                                   prev_step_results.size(-1))
        # carry over the token histories of the selected beams
        prev_step_results = prev_step_results.gather(
            1,
            selected_input_seq.unsqueeze(-1).repeat(
                1, 1, prev_step_results.size(-1)))

        # a beam stays unfinished only if its newly chosen token is not EOS
        output_unfinished_sents = input_unfinished_sents.gather(
            1, selected_input_seq)
        output_unfinished_sents = (output_unfinished_sents & next_token_ids.ne(
            self.config.eos_token_id))

        # append the chosen tokens to form the next step's full input_ids
        current_step_results = torch.cat(
            [prev_step_results,
             next_token_ids.unsqueeze(-1)], dim=-1).contiguous()

        prev_step_scores = prev_step_scores.view(self.config.batch_size, -1,
                                                 prev_step_scores.size(-1))
        prev_step_scores = prev_step_scores.gather(
            1,
            selected_input_seq.unsqueeze(-1).repeat(1, 1,
                                                    prev_step_scores.size(-1)))
        current_step_scores = torch.cat(
            [prev_step_scores,
             output_log_probs.unsqueeze(-1)], dim=-1).contiguous()

        return (
            next_token_ids,
            present_flat,
            selected_input_seq,
            output_log_probs,
            output_unfinished_sents,
            current_step_results.view(
                self.config.batch_size * self.config.beam_size, -1),
            current_step_scores.view(
                self.config.batch_size * self.config.beam_size, -1),
        )
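
The heart of this step is the reduction from beam_size**2 candidate extensions to beam_size survivors via one topk over the flattened scores. Below is a self-contained sketch of just that index arithmetic with made-up numbers; none of these tensors or values come from the original code.

    import torch

    batch_size, beam_size = 1, 2
    # per-beam top-k next-token log-probs and ids: shape (batch, beam, beam)
    next_log_probs = torch.tensor([[[-0.1, -1.2], [-0.5, -0.9]]])
    next_ids = torch.tensor([[[3, 1], [2, 4]]])
    input_log_probs = torch.tensor([[-0.3, -0.7]])  # running beam scores

    # every (source beam, token rank) pair becomes one flat candidate
    cand = (input_log_probs.unsqueeze(-1) + next_log_probs).view(batch_size, -1)
    scores, flat_idx = cand.topk(beam_size, dim=-1, largest=True, sorted=True)
    src_beam = flat_idx // beam_size  # which beam each survivor extends
    tokens = next_ids.view(batch_size, -1).gather(-1, flat_idx)
    # scores = [[-0.4, -1.2]], src_beam = [[0, 1]], tokens = [[3, 2]]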
Example #2
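This variant folds more of the search policy into the step itself: only still-unfinished sequences are run through the model, and the resulting logits pass through optional repetition, length, temperature, and excluded-token adjustments before either top-k/top-p sampling or deterministic top-k beam selection.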
    def forward(
        self,
        input_ids,
        beam_select_idx,
        input_log_probs,
        input_unfinished_sents,
        prev_step_scores,
        *past,
    ):
        input_ids = input_ids.view(self.config.batch_size, -1,
                                   input_ids.size(-1))
        input_num_seq_per_sample = input_ids.size(1)

        # keep only the rows of still-unfinished sequences
        input_ids_unfinished_flat = self.collapse_first_two_dims(
            input_ids).index_select(
                0,
                input_unfinished_sents.view(-1).nonzero(
                    as_tuple=False).view(-1))

        if self.config.ignore_eos:
            attention_mask = (input_ids_unfinished_flat !=
                              self.config.eos_token_id).float()
        else:
            attention_mask = torch.ones(
                input_ids_unfinished_flat.shape).float().to(
                    input_ids_unfinished_flat.device)
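        # position ids count attended tokens: e.g. a mask row [1., 1., 0., 1.]
        # yields [0, 1, 1, 2] (illustrative values, not from the original)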
        position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0).long()

        if past:
            # with a cached past, feed only the not-yet-seen suffix
            last_seq_len = past[0].size(-2)
            input_ids_unfinished_flat = input_ids_unfinished_flat[:,
                                                                  last_seq_len:]
            position_ids = position_ids[:, last_seq_len:]

            unfinished_index_relative_to_last_unfinished = beam_select_idx.view(
                -1)[input_unfinished_sents.view(-1).nonzero(
                    as_tuple=False).view(-1)]

            past = tuple([
                p.index_select(1, unfinished_index_relative_to_last_unfinished)
                for p in past
            ])

        result = super().forward(
            input_ids_unfinished_flat.view(-1,
                                           input_ids_unfinished_flat.size(-1)),
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past,
            return_dict=False,
        )
        logits_flat, present_flat = MyGPT2Model.post_process(
            result, self.config.n_layer)

        # scatter the model outputs back into a dense tensor of shape
        # (batch_size, beam_size, vocab); rows for finished sequences keep a
        # dummy distribution that strongly favors EOS (BIG_NEG being a large
        # negative constant, -BIG_NEG is a large positive logit)
        next_token_logits = logits_flat.new_zeros(input_ids.size()[:2] +
                                                  (logits_flat.size(-1), ))
        next_token_logits.index_fill_(
            2,
            torch.LongTensor([self.config.eos_token_id]).to(input_ids.device),
            -BIG_NEG)

        next_token_logits.masked_scatter_(
            input_unfinished_sents.unsqueeze(-1).expand_as(next_token_logits),
            logits_flat[:, -1])

        # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
        if self.config.repetition_penalty != 1.0:
            _pen = next_token_logits.gather(2, input_ids)
            _pen = torch.where(_pen > 0, _pen / self.config.repetition_penalty,
                               _pen * self.config.repetition_penalty)
            next_token_logits.scatter_(2, input_ids, _pen)
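        # illustrative numbers (not from the original): with
        # repetition_penalty = 1.2, a seen token's logit of 2.0 shrinks to
        # 2.0 / 1.2 ~= 1.67 and a logit of -1.0 grows to -1.2, so reusing
        # the token becomes less likely either way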

        # a similar penalty on the EOS logit to encourage shorter sentences
        if self.config.length_penalty != 1.0:
            _pen = next_token_logits[..., self.config.eos_token_id]
            # if the EOS logit is positive, scale it up; otherwise scale it down
            _pen = torch.where(_pen > 0, _pen * self.config.length_penalty,
                               _pen / self.config.length_penalty)
            next_token_logits[..., self.config.eos_token_id] = _pen

        if self.config.temperature != 1.0:
            next_token_logits = next_token_logits / self.config.temperature

        # suppress any explicitly excluded token ids
        if self.config.excluded_token_ids is not None:
            next_token_logits.index_fill_(
                2, self.config.excluded_token_ids.to(next_token_logits.device),
                BIG_NEG)  # batch x beams/sequences x vocab_size

        next_token_log_probs = torch.log_softmax(next_token_logits, dim=-1)

        if self.config.do_sample:
            vocab_size = next_token_log_probs.size(-1)
            _next_token_log_probs = self.top_k_top_p_filtering(
                next_token_log_probs.view(-1, vocab_size),
                top_k=self.config.do_sample_top_k,
                top_p=self.config.do_sample_top_p)
            next_token_ids = torch.multinomial(
                _next_token_log_probs.exp(),
                num_samples=self.config.beam_size,
                replacement=False)
            next_token_ids = next_token_ids.view(self.config.batch_size,
                                                 input_num_seq_per_sample, -1)
            next_token_log_probs = next_token_log_probs.gather(
                -1, next_token_ids)
        else:
            next_token_log_probs, next_token_ids = torch.topk(
                next_token_log_probs,
                self.config.beam_size,
                dim=-1,
                largest=True,
                sorted=True)

        output_log_probs = input_log_probs.unsqueeze(-1) + next_token_log_probs

        # select beam_size sequences per input from the beam_size**2
        # candidates, sorted by sequence log-probability
        output_log_probs = output_log_probs.view(
            self.config.batch_size, -1)  # shape=(batch, beam_size^2)
        output_log_probs, selected_index_flat = output_log_probs.topk(
            self.config.beam_size, dim=-1, largest=True,
            sorted=True)  # output shape=(batch, beam_size)

        # map each flat candidate index back to its source beam and next token
        selected_input_seq = selected_index_flat // self.config.beam_size
        next_token_ids = next_token_ids.view(self.config.batch_size,
                                             -1).gather(
                                                 -1, selected_index_flat)

        prev_step_results = input_ids.view(self.config.batch_size, -1,
                                           input_ids.size(-1)).contiguous()
        prev_step_results = prev_step_results.gather(
            1,
            selected_input_seq.unsqueeze(-1).expand(
                selected_input_seq.shape + (prev_step_results.size(-1), )))

        # a beam stays unfinished only if its newly chosen token is not EOS
        output_unfinished_sents = input_unfinished_sents.gather(
            1, selected_input_seq)
        output_unfinished_sents = (output_unfinished_sents & next_token_ids.ne(
            self.config.eos_token_id))

        current_step_results = torch.cat(
            [prev_step_results,
             next_token_ids.unsqueeze(-1)], dim=-1).contiguous()

        prev_step_scores = prev_step_scores.view(self.config.batch_size, -1,
                                                 prev_step_scores.size(-1))
        prev_step_scores = prev_step_scores.gather(
            1,
            selected_input_seq.unsqueeze(-1).expand(
                selected_input_seq.shape + (prev_step_scores.size(-1), )))
        current_step_scores = torch.cat(
            [prev_step_scores,
             output_log_probs.unsqueeze(-1)], dim=-1).contiguous()

        # map the selected beams to row indices within the unfinished-only
        # present state, so the next step can select the right past entries
        index_relative_to_last_unfinished = (
            input_unfinished_sents.view(-1).float().cumsum(-1) -
            1).clamp(min=0).long().reshape_as(input_unfinished_sents).gather(
                1, selected_input_seq)
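        # e.g. if input_unfinished_sents.view(-1) were [True, False, True,
        # True], global rows 0, 2, 3 map to rows 0, 1, 2 of the
        # unfinished-only present state (illustrative, not from the original)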

        return (
            current_step_results.view(
                self.config.batch_size * self.config.beam_size, -1),
            present_flat,
            index_relative_to_last_unfinished,
            output_log_probs,
            output_unfinished_sents,
            current_step_scores.view(
                self.config.batch_size * self.config.beam_size, -1),
        )
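
Example #2 calls `self.top_k_top_p_filtering` without showing its body. A common implementation of that filtering, following the widely circulated Hugging Face recipe, is sketched below; this is an assumption about what the helper does, not the source's actual code.

    import torch

    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0,
                              filter_value=-float("inf")):
        # assumed behavior: mask logits outside the top_k tokens and outside
        # the top_p cumulative probability mass (nucleus filtering)
        if top_k > 0:
            kth_best = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, filter_value)
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
            remove = cum_probs > top_p
            # shift right so the first token above the threshold is kept
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            remove = remove.scatter(-1, sorted_idx, remove)
            logits = logits.masked_fill(remove, filter_value)
        return logits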