def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        model_specific_kwargs,
    ):
        """Generate sequences for each example with beam search."""
        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams,
                           max_length,
                           length_penalty,
                           early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams),
                                  dtype=torch.float,
                                  device=input_ids.device)

        # for greedy decoding, make sure that only tokens of the first beam
        # are considered, to avoid sampling the exact same tokens num_beams
        # times
        if do_sample is False:
            beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
        beams_offset = (torch.arange(0, batch_size) *
                        num_beams).unsqueeze(1).type_as(input_ids)

        cand_size = 2 * num_beams
        cand_offsets = torch.arange(0, cand_size).type_as(input_ids)

        # cache compute states
        past = (encoder_outputs, None) if encoder_outputs is not None else None

        # done sentences
        done = [False for _ in range(batch_size)]
        """
        _reorder_cache_v2(past, batch_idxs, beam_idxs)
        Remove the finished batches during beam search, reorder_cache_v2 is used to support dynamic batch size.
        for cache tensors with shape like (batch_size, ~): tensor = tensor[batch_idxs]
        for cache tensors with shape like (batch_size * beam_size, ~): tensor = tensor[beam_idxs]
        """
        use_reorder_cache_v2 = hasattr(self, '_reorder_cache_v2')

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids,
                past=past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                **model_specific_kwargs)
            # (batch_size * num_beams, cur_len, vocab_size)
            outputs = self(**model_inputs)
            # (batch_size * num_beams, vocab_size)
            next_token_logits = outputs[0][:, -1, :]

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]
            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better
                # solution
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length)
            # (batch_size * num_beams, vocab_size)
            scores = F.log_softmax(next_token_logits, dim=-1)
            scores = self.postprocess_next_token_scores(
                scores=scores,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=num_beams,
            )

            assert scores.shape == (
                batch_size * num_beams,
                vocab_size), "Shapes of scores: {} != {}".format(
                    scores.shape, (batch_size * num_beams, vocab_size))

            if do_sample:
                # (batch_size * num_beams, vocab_size)
                curr_scores = scores + beam_scores[:, None].expand_as(scores)
                # Temperature
                if temperature != 1.0:
                    curr_scores = curr_scores / temperature
                # Top-p/top-k filtering
                curr_scores = top_k_top_p_filtering(
                    curr_scores,
                    top_k=top_k,
                    top_p=top_p,
                    min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together to sample from all
                # beam_idxs
                curr_scores = curr_scores.contiguous().view(
                    batch_size, num_beams *
                    vocab_size)  # (batch_size, num_beams * vocab_size)

                # Sample 2 next tokens for each beam (so we have some spare
                # tokens and match output of greedy beam search)
                probs = F.softmax(curr_scores, dim=-1)

                # (batch_size, num_beams * 2)
                next_tokens = torch.multinomial(probs,
                                                num_samples=2 * num_beams)

                # Compute next scores
                # (batch_size, num_beams * 2)
                next_scores = torch.gather(curr_scores, -1, next_tokens)
                # sort the sampled vector to make sure that the first num_beams
                # samples are the best
                next_scores, next_scores_indices = torch.sort(next_scores,
                                                              descending=True,
                                                              dim=1)
                # (batch_size, num_beams * 2)
                next_tokens = torch.gather(next_tokens, -1,
                                           next_scores_indices)

            else:
                # (batch_size * num_beams, vocab_size)
                next_scores = scores + beam_scores[:, None].expand_as(scores)

                # re-organize to group the beam together (we are keeping top
                # hypotheses across beams)
                next_scores = next_scores.view(
                    batch_size, num_beams *
                    vocab_size)  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = torch.topk(next_scores,
                                                      2 * num_beams,
                                                      dim=1,
                                                      largest=True,
                                                      sorted=True)

            assert next_scores.size() == next_tokens.size() == (batch_size,
                                                                2 * num_beams)
            # next batch beam content
            next_tokens_id = next_tokens % vocab_size
            next_beams_id = next_tokens // vocab_size
            effective_beam_id = next_beams_id + beams_offset
            if eos_token_id is not None:
                eos_mask = next_tokens_id.eq(eos_token_id)
            else:
                eos_mask = torch.zeros_like(next_tokens_id).bool()
            eos_effective_idx = torch.masked_select(
                effective_beam_id[:, :num_beams], mask=eos_mask[:, :num_beams])

            finished_batch_idxs = []
            if use_reorder_cache_v2 and eos_effective_idx.numel() > 0:
                eos_effective_scores = torch.masked_select(
                    next_scores[:, :num_beams], mask=eos_mask[:, :num_beams])
                input_clone = input_ids.index_select(0, eos_effective_idx)
                unfin_offset = np.array(list(
                    accumulate(done)))[np.array(done) == 0]
                for i in range(eos_effective_idx.size(0)):
                    eos_idx = eos_effective_idx[i]
                    eos_score = eos_effective_scores[i]
                    unfin_batch_idx = eos_idx // num_beams
                    batch_idx = unfin_batch_idx + unfin_offset[unfin_batch_idx]
                    if not done[batch_idx]:
                        generated_hyps[batch_idx.item()].add(
                            input_clone[i], eos_score.item())
                    is_done = done[batch_idx]
                    done[batch_idx] = (
                        done[batch_idx] or generated_hyps[batch_idx].is_done(
                            next_scores[unfin_batch_idx].max().item(),
                            cur_len))
                    if is_done != done[batch_idx]:
                        finished_batch_idxs.append(unfin_batch_idx)

            if not use_reorder_cache_v2:
                eos_effective_scores = torch.masked_select(
                    next_scores[:, :num_beams], mask=eos_mask[:, :num_beams])
                input_ids_cpu = input_ids.cpu()
                eos_effective_idx_cpu = eos_effective_idx.cpu()
                eos_effective_scores_cpu = eos_effective_scores.cpu()
                for i in range(0, eos_effective_idx_cpu.size()[-1]):
                    batch_idx = eos_effective_idx_cpu[i] // num_beams
                    if not done[batch_idx]:
                        generated_hyps[batch_idx.item()].add(
                            input_ids_cpu[eos_effective_idx_cpu[i]].clone(),
                            eos_effective_scores_cpu[i])
                    done[batch_idx] = (done[batch_idx]
                                       or generated_hyps[batch_idx].is_done(
                                           next_scores[batch_idx].max().item(),
                                           cur_len))

            if all(done):
                break

            if use_reorder_cache_v2 and len(finished_batch_idxs) > 0:
                new_batch_size = batch_size - len(finished_batch_idxs)
                batch_mask = torch.ones(batch_size).to(next_tokens_id)
                batch_mask[torch.tensor(finished_batch_idxs)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)
                eos_mask = eos_mask[batch_idxs]
                next_beams_id = next_beams_id[batch_idxs]
                beams_offset.resize_(new_batch_size, 1)
                effective_beam_id = next_beams_id.add(beams_offset)
                next_scores = next_scores[batch_idxs]
                next_tokens = next_tokens[batch_idxs]
                next_tokens_id = next_tokens_id[batch_idxs]
                input_ids = input_ids.view(batch_size, -1)[batch_idxs].view(
                    new_batch_size * num_beams, -1)
                before_batch_size = batch_size
                batch_size = new_batch_size
            else:
                before_batch_size = batch_size
                batch_idxs = None

            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)])
            _, active_hypos = torch.topk(active_mask,
                                         k=num_beams,
                                         dim=1,
                                         largest=False)
            active_effective_beam_id = torch.gather(effective_beam_id,
                                                    dim=1,
                                                    index=active_hypos)
            active_scores = torch.gather(next_scores,
                                         dim=1,
                                         index=active_hypos)
            active_tokens = torch.gather(next_tokens_id,
                                         dim=1,
                                         index=active_hypos)
            beam_idxs = active_effective_beam_id.view(-1)
            beam_scores = active_scores.view(-1)
            beam_tokens = active_tokens.view(-1)

            # re-order batch and update current length
            input_ids = input_ids[beam_idxs, :]
            input_ids = torch.cat(
                [input_ids, beam_tokens.unsqueeze(1)], dim=-1)
            cur_len = cur_len + 1

            # re-order internal states
            if past is not None:
                if use_reorder_cache_v2:
                    new_beam_idxs = torch.arange(
                        before_batch_size * num_beams).reshape(
                            before_batch_size, num_beams).to(input_ids)
                    beam_idxs = new_beam_idxs[batch_idxs].reshape(
                        -1)[beam_idxs]
                    past = self._reorder_cache_v2(past, batch_idxs, beam_idxs)
                else:
                    past = self._reorder_cache(past, beam_idxs)

            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat([
                    attention_mask,
                    attention_mask.new_ones((attention_mask.shape[0], 1))
                ],
                                           dim=-1)

        # finalize all open beam hypotheses and add to generated hypotheses
        unfin_offset = np.array(list(accumulate(done)))[np.array(done) == 0]
        if use_reorder_cache_v2:
            batch_size = len(unfin_offset)
        for batch_idx in range(batch_size):
            if not use_reorder_cache_v2 and done[batch_idx]:
                continue
            # test that beam scores match previously calculated scores if not
            # eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).item() != eos_token_id
                    for token_id in next_tokens[batch_idx]):
                assert torch.all(
                    next_scores[batch_idx, :num_beams] == beam_scores.view(
                        batch_size, num_beams)[batch_idx]
                ), "If batch_idx is not done, final next scores: \
                {} have to equal to accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx],
                    beam_scores.view(batch_size, num_beams)[batch_idx],
                )

            if use_reorder_cache_v2:
                final_batch_idx = batch_idx + unfin_offset[batch_idx]
            else:
                final_batch_idx = batch_idx
            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[final_batch_idx].add(final_tokens, final_score)

        batch_size = len(done)

        # depending on whether greedy generation is wanted or not define
        # different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample \
            else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 \
            if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths = input_ids.new(output_batch_size)
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                effective_batch_idx = \
                    output_num_return_sequences_per_batch * i + j
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[effective_batch_idx] = len(best_hyp)
                best.append(best_hyp)

        # shorter batches are padded
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, \
                "`pad_token_id` has to be defined"
            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
            decoded = input_ids.new(output_batch_size,
                                    sent_max_len).fill_(pad_token_id)

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, :sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert all(len(hypo) == max_length for hypo in best)
            decoded = torch.stack(best).type(torch.long)\
                    .to(next(self.parameters()).device)

        return decoded
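
The `_reorder_cache_v2` hook referenced in this example is model specific and is not shown here. The snippet below is only a minimal sketch of the contract described in the comment above, assuming a hypothetical cache laid out as `(encoder_states, decoder_past)`; the names and shapes are illustrative, not the actual model implementation.

import torch


def _reorder_cache_v2_sketch(past, batch_idxs, beam_idxs):
    """Illustrative only: reorder a hypothetical cache after finished batches are dropped.

    Tensors whose leading dimension is batch_size are indexed with batch_idxs;
    tensors whose leading dimension is batch_size * num_beams are indexed with beam_idxs.
    """
    encoder_states, decoder_past = past
    if encoder_states is not None and batch_idxs is not None:
        # shape (batch_size, src_len, hidden) -> keep only the unfinished batches
        encoder_states = encoder_states.index_select(0, batch_idxs)
    if decoder_past is not None:
        # each layer cache has shape (batch_size * num_beams, ...) -> follow the surviving beams
        decoder_past = tuple(layer.index_select(0, beam_idxs) for layer in decoder_past)
    return (encoder_states, decoder_past)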
Example #2
    def generate(self,
                 prompt: Union[str, List[str]],
                 max_len: int = 20,
                 sample: bool = True,
                 k: int = 0,
                 p: float = 0.9,
                 temperature: float = 1.0,
                 bad_words_ids: List[List[int]] = None,
                 **model_kwargs) -> List[str]:
        if isinstance(prompt, str):
            prompt = [prompt]

        encodings_dict = self.tokenizer.batch_encode_plus(prompt, pad_to_max_length=True, return_tensors='pt')

        input_ids = encodings_dict['input_ids'].to(self.device)
        attention_mask = encodings_dict['attention_mask'].to(self.device)
        batch_size, input_seq_len = input_ids.shape

        position_ids = attention_mask.cumsum(dim=1) - 1
        unfinished_sents = torch.ones(batch_size, dtype=torch.long, device=self.device)

        self.model.eval()
        with torch.no_grad():
            for step in range(max_len):
                logits, past = self.model(input_ids, attention_mask=attention_mask, position_ids=position_ids,
                                          **model_kwargs)

                # in the first decoding step, we want to use the 'real' last position for each sentence
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = logits[:, -1, :]

                if bad_words_ids is not None:
                    # calculate a list of banned tokens according to bad words
                    banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

                    # TODO: use a vectorized operation
                    for batch_idx in range(batch_size):
                        next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    if temperature != 1.0:
                        next_token_logits = next_token_logits / temperature
                    # Top-p/top-k filtering
                    next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=k, top_p=p)
                    # Sample
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_logits, dim=-1)

                # either append a padding token here if <EOS> has been seen or append next token
                tokens_to_add = next_tokens * unfinished_sents + self.tokenizer.pad_token_id * (1 - unfinished_sents)

                # this updates which sentences have not seen an EOS token so far
                # if one EOS token was seen the sentence is finished
                eos_in_sents = tokens_to_add == self.tokenizer.eos_token_id
                unfinished_sents.mul_((~eos_in_sents).long())

                # stop when there is an EOS in each sentence
                if unfinished_sents.max() == 0:
                    break

                # Update input_ids, attention_mask and position_ids
                input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
                attention_mask = torch.cat([attention_mask, attention_mask.new_ones((batch_size, 1))], dim=1)
                position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=1)

        decoded_outputs = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                           for output in input_ids[:, input_seq_len:]]
        return decoded_outputs
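
Example #2 builds position ids for a padded batch with `attention_mask.cumsum(dim=1) - 1` and, at the first decoding step, reads the logits at `last_non_masked_idx` rather than at position -1. The standalone toy below shows what those two quantities look like, assuming right padding as produced by `pad_to_max_length=True`; the values are made up.

import torch

# two prompts of different lengths, right-padded to length 5 (1 = real token, 0 = pad)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# position ids: 0..len-1 for real tokens; pad positions repeat the last real position
position_ids = attention_mask.cumsum(dim=1) - 1
print(position_ids)
# tensor([[0, 1, 2, 3, 4],
#         [0, 1, 2, 2, 2]])

# index of the last real token per sentence, used to pick next-token logits at step 0
last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
print(last_non_masked_idx)  # tensor([4, 2])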
Example #3
    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        attention_mask,
        use_cache,
        model_kwargs,
    ):
        """Generate sequences for each example without beam search (num_beams == 1).
        All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = None
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids,
                past=past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                **model_kwargs)

            outputs = self(**model_inputs, return_dict=True)
            next_token_logits = outputs.logits[:, -1, :]
            scores = self.postprocess_next_token_scores(
                scores=next_token_logits,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=1,
            )

            # if model has past, then set the past variable to speed up decoding
            if "past_key_values" in outputs:
                past = outputs.past_key_values
            elif "mems" in outputs:
                past = outputs.mems

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
                next_token_logscores = top_k_top_p_filtering(scores,
                                                             top_k=top_k,
                                                             top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logscores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

                # print(next_token_logits[0,next_token[0]], next_token_logits[0,eos_token_id])

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exist
                tokens_to_add = next_token * unfinished_sents + (
                    pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            # add token and increase length by one
            input_ids = torch.cat(
                [input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(
                    eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(
                    is_sents_unfinished_and_token_to_add_is_eos, cur_len)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

            # extend attention_mask for new generated input if only decoder
            # if self.config.is_encoder_decoder is False:
            #     attention_mask = torch.cat(
            #         [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            #     )

        return input_ids
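
The example above keeps finished sentences in the batch but forces them to emit `pad_token_id` through simple integer arithmetic on `unfinished_sents`. A small numeric illustration of that bookkeeping (token ids and flags are made up):

import torch

pad_token_id = 0
eos_token_id = 2

next_token = torch.tensor([17, 2, 23])       # token chosen per sentence at this step
unfinished_sents = torch.tensor([1, 1, 0])   # the third sentence already hit EOS earlier

# finished sentences get pad_token_id, unfinished ones keep their chosen token
tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
print(tokens_to_add)  # tensor([17,  2,  0])

# sentences that just produced EOS are now marked finished as well
eos_in_sents = tokens_to_add == eos_token_id
unfinished_sents = unfinished_sents * (~eos_in_sents).long()
print(unfinished_sents)  # tensor([1, 0, 0])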
Example #4
    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        attention_mask,
        use_cache,
        model_kwargs,
    ):
        """Generate sequences for each example with beam search."""

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams,
                           max_length,
                           length_penalty,
                           early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams),
                                  dtype=torch.float,
                                  device=input_ids.device)

        # for greedy decoding, make sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens num_beams times
        if do_sample is False:
            beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        past = None

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids,
                past=past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                **model_kwargs)
            outputs = self(**model_inputs, return_dict=True
                           )  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs.logits[:,
                                               -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if "past_key_values" in outputs:
                past = outputs.past_key_values
            elif "mems" in outputs:
                past = outputs.mems

            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better solution
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length)

            scores = F.log_softmax(
                next_token_logits,
                dim=-1)  # (batch_size * num_beams, vocab_size)

            scores = self.postprocess_next_token_scores(
                scores=scores,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=num_beams,
            )

            assert scores.shape == (
                batch_size * num_beams,
                vocab_size), "Shapes of scores: {} != {}".format(
                    scores.shape, (batch_size * num_beams, vocab_size))

            if do_sample:
                _scores = scores + beam_scores[:, None].expand_as(
                    scores)  # (batch_size * num_beams, vocab_size)
                # Temperature
                if temperature != 1.0:
                    _scores = _scores / temperature
                # Top-p/top-k filtering
                _scores = top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together to sample from all beam_idxs
                _scores = _scores.contiguous().view(
                    batch_size, num_beams *
                    vocab_size)  # (batch_size, num_beams * vocab_size)

                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                probs = F.softmax(_scores, dim=-1)
                next_tokens = torch.multinomial(
                    probs,
                    num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
                # Compute next scores
                next_scores = torch.gather(
                    _scores, -1, next_tokens)  # (batch_size, num_beams * 2)
                # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores, next_scores_indices = torch.sort(next_scores,
                                                              descending=True,
                                                              dim=1)
                next_tokens = torch.gather(
                    next_tokens, -1,
                    next_scores_indices)  # (batch_size, num_beams * 2)

            else:
                next_scores = scores + beam_scores[:, None].expand_as(
                    scores)  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping top hypotheses across beams)
                next_scores = next_scores.view(
                    batch_size, num_beams *
                    vocab_size)  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = torch.topk(next_scores,
                                                      2 * num_beams,
                                                      dim=1,
                                                      largest=True,
                                                      sorted=True)

            assert next_scores.size() == next_tokens.size() == (batch_size,
                                                                2 * num_beams)

            # next batch beam content
            next_batch_beam = []

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence, add a pad token
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(
                        num_beams)
                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] *
                                           num_beams)  # pad the batch
                    continue

                # next sentence beam content, this will get added to next_batch_beam
                next_sent_beam = []

                # next tokens for this sentence
                for beam_token_rank, (beam_token_id,
                                      beam_token_score) in enumerate(
                                          zip(next_tokens[batch_idx],
                                              next_scores[batch_idx])):
                    # get beam and token IDs
                    beam_id = beam_token_id // vocab_size
                    token_id = beam_token_id % vocab_size

                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence
                    if (eos_token_id is not None) and (token_id.item()
                                                       == eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
                            input_ids[effective_beam_id].clone(),
                            beam_token_score.item(),
                        )
                    else:
                        # add next predicted token since it is not eos_token
                        next_sent_beam.append(
                            (beam_token_score, token_id, effective_beam_id))

                    # once the beam for next step is full, don't add more tokens to it.
                    if len(next_sent_beam) == num_beams:
                        break

                # Check if we are done so that we can save a pad step if all(done)
                done[batch_idx] = done[
                    batch_idx] or generated_hyps[batch_idx].is_done(
                        next_scores[batch_idx].max().item(), cur_len)

                # update next beam content
                assert len(
                    next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (
                    batch_idx + 1), "We should have added num_beams each step"

            # stop when we are done with each sentence
            if all(done):
                break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch and update current length
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat(
                [input_ids, beam_tokens.unsqueeze(1)], dim=-1)
            cur_len = cur_len + 1

            # re-order internal states
            if past is not None:
                past = self._reorder_cache(past, beam_idx)

            # extend attention_mask for new generated input if only decoder
            # (huxu): move out since we trim attention_mask by ourselves.
            # if self.config.is_encoder_decoder is False:
            #    attention_mask = torch.cat(
            #        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            #    )

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                continue

            # test that beam scores match previously calculated scores if not eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).item() != eos_token_id
                    for token_id in next_tokens[batch_idx]):
                assert torch.all(
                    next_scores[batch_idx, :num_beams] == beam_scores.view(
                        batch_size, num_beams)[batch_idx]
                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx],
                    beam_scores.view(batch_size, num_beams)[batch_idx],
                )

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[batch_idx].add(final_tokens, final_score)

        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths = input_ids.new(output_batch_size)
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                effective_batch_idx = output_num_return_sequences_per_batch * i + j
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[effective_batch_idx] = len(best_hyp)
                best.append(best_hyp)

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
        decoded = input_ids.new(output_batch_size, sent_max_len)
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, :sent_lengths[i]] = hypo
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id

        return decoded
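
In the beam-search loops above, `torch.topk` runs over the flattened `(num_beams * vocab_size)` scores of each sentence, so every candidate index encodes both its source beam and its token; keeping `2 * num_beams` candidates leaves spare slots in case some of them are EOS. The toy below decomposes such indices with a tiny made-up vocabulary:

import torch

num_beams, vocab_size = 2, 5
# flattened scores for one sentence: beam 0 occupies columns 0..4, beam 1 columns 5..9
flat_scores = torch.tensor([[0.1, 0.3, 0.2, 0.0, 0.0,
                             0.7, 0.1, 0.0, 0.4, 0.0]])

# keep 2 * num_beams candidates so the beam can stay full even if some candidates are EOS
next_scores, next_tokens = torch.topk(flat_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
print(next_tokens)   # tensor([[5, 8, 1, 2]])

beam_id = next_tokens // vocab_size   # which beam each candidate came from
token_id = next_tokens % vocab_size   # which vocabulary item it proposes
print(beam_id)    # tensor([[1, 1, 0, 0]])
print(token_id)   # tensor([[0, 3, 1, 2]])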
Example #5
def beam_search(model, cfg, encoder_input_ids, encoder_attention_mask,
                decoder_input_id, num_beams):
    #modified from https://github.com/huggingface/transformers/blob/master/src/transformers/generation_utils.py
    pad_token_id = model.config.pad_token_id
    eos_token_id = model.config.eos_token_id
    unk_token_id = model.config.unk_token_id
    batch_size = cfg.predict_batch_size * cfg.candidate_num
    candidate_num = cfg.candidate_num
    no_repeat_ngram_size = cfg.no_repeat_ngram_size
    repetition_penalty = cfg.repetition_penalty

    do_sample = cfg.do_sample
    top_p = cfg.top_p
    top_k = cfg.top_k
    temperature = cfg.temperature

    vocab_size = cfg.vocab_size
    max_length = cfg.max_output_length
    min_length = cfg.min_output_length
    length_penalty = 1.0
    num_return_sequences = num_beams

    encoder_attention_mask = model.attach_visual_for_mask(
        encoder_attention_mask)

    encoder_outputs = model.encoder(input_ids=encoder_input_ids,
                                    attention_mask=encoder_attention_mask,
                                    output_hidden_states=True)

    encoder_hidden_states = encoder_outputs.last_hidden_state.repeat(
        [num_beams * candidate_num, 1, 1])  #hidden_states[0]
    encoder_padding_mask = encoder_attention_mask.repeat(
        [num_beams * candidate_num, 1])

    generated_hyps = [
        BeamHypotheses(num_beams,
                       max_length,
                       length_penalty,
                       early_stopping=False) for _ in range(batch_size)
    ]

    decoder_input_id = decoder_input_id.unsqueeze(dim=1)
    input_ids = torch.repeat_interleave(decoder_input_id,
                                        repeats=num_beams,
                                        dim=1)
    input_ids = input_ids.view(batch_size * num_beams, -1)

    cur_len = 2

    beam_scores = torch.zeros((batch_size, num_beams),
                              dtype=torch.float,
                              device=encoder_input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)

    done = [False for _ in range(batch_size)]

    while cur_len < max_length:

        decoder_padding_mask = make_padding_mask(input_ids, 0)
        bsz, tgt_len = input_ids.size()
        causal_mask = torch.triu(
            fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)),
            1).to(dtype=torch.float32, device=input_ids.device)

        decoder_outputs = model.model.decoder(
            input_ids,
            encoder_hidden_states,
            encoder_padding_mask=encoder_padding_mask,
            decoder_padding_mask=decoder_padding_mask,
            decoder_causal_mask=causal_mask)
        lm_logits = F.linear(decoder_outputs[0],
                             model.model.shared.weight,
                             bias=model.final_logits_bias)

        next_token_logits = lm_logits[:, -1, :]

        scores = F.log_softmax(next_token_logits, dim=-1)
        # set eos token prob to zero if min_length is not reached
        if cur_len < min_length:
            scores[:, eos_token_id] = -float("inf")
        scores[:, unk_token_id] = -float("inf")

        if repetition_penalty != 1.0:
            enforce_repetition_penalty_(
                scores,
                batch_size,
                num_beams,
                input_ids,
                repetition_penalty,
            )

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            num_batch_hypotheses = batch_size * num_beams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_batch_tokens = calc_banned_ngram_tokens(
                input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
            for i, banned_tokens in enumerate(banned_batch_tokens):
                scores[i, banned_tokens] = -float("inf")

        assert scores.shape == (
            batch_size * num_beams,
            vocab_size), "Shapes of scores: {} != {}".format(
                scores.shape, (batch_size * num_beams, vocab_size))

        if do_sample:
            _scores = scores + beam_scores[:, None].expand_as(
                scores)  # (batch_size * num_beams, vocab_size)
            # Temperature
            if temperature != 1.0:
                _scores = _scores / temperature
            # Top-p/top-k filtering
            _scores = top_k_top_p_filtering(
                _scores, top_k=top_k, top_p=top_p,
                min_tokens_to_keep=2)  # (batch_size * num_beams, vocab_size)
            # re-organize to group the beam together to sample from all beam_idxs
            _scores = _scores.contiguous().view(
                batch_size,
                num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
            probs = F.softmax(_scores, dim=-1)
            next_tokens = torch.multinomial(
                probs,
                num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
            # Compute next scores
            next_scores = torch.gather(
                _scores, -1, next_tokens)  # (batch_size, num_beams * 2)
            # sort the sampled vector to make sure that the first num_beams samples are the best
            next_scores, next_scores_indices = torch.sort(next_scores,
                                                          descending=True,
                                                          dim=1)
            next_tokens = torch.gather(
                next_tokens, -1,
                next_scores_indices)  # (batch_size, num_beams * 2)
        else:
            next_scores = scores + beam_scores[:, None].expand_as(
                scores)  # (batch_size * num_beams, vocab_size)
            # re-organize to group the beam together (we are keeping top hypotheses across beams)

            next_scores = next_scores.view(
                batch_size,
                num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)

            next_scores, next_tokens = torch.topk(next_scores,
                                                  2 * num_beams,
                                                  dim=1,
                                                  largest=True,
                                                  sorted=True)

        assert next_scores.size() == next_tokens.size() == (batch_size,
                                                            2 * num_beams)

        next_batch_beam = []

        # for each sentence
        for batch_idx in range(batch_size):

            # if we are done with this sentence, add a pad token
            if done[batch_idx]:
                assert (
                    len(generated_hyps[batch_idx]) >= num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(
                    num_beams)
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                next_batch_beam.extend([(0, pad_token_id, 0)] *
                                       num_beams)  # pad the batch
                continue

            # next sentence beam content, this will get added to next_batch_beam
            next_sent_beam = []

            # next tokens for this sentence
            for beam_token_rank, (beam_token_id,
                                  beam_token_score) in enumerate(
                                      zip(next_tokens[batch_idx],
                                          next_scores[batch_idx])):
                # get beam and token IDs
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size
                effective_beam_id = batch_idx * num_beams + beam_id

                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (token_id.item()
                                                   == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    generated_hyps[batch_idx].add(
                        input_ids[effective_beam_id].clone(),
                        beam_token_score.item(),
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_sent_beam.append(
                        (beam_token_score, token_id, effective_beam_id))
                # once the beam for next step is full, don't add more tokens to it.
                if len(next_sent_beam) == num_beams:
                    break
            # Check if we are done so that we can save a pad step if all(done)
            done[batch_idx] = done[
                batch_idx] or generated_hyps[batch_idx].is_done(
                    next_scores[batch_idx].max().item(), cur_len)

            # update next beam content
            assert len(
                next_sent_beam) == num_beams, "Beam should always be full"
            next_batch_beam.extend(next_sent_beam)
            assert len(next_batch_beam) == num_beams * (
                batch_idx + 1), "We should have added num_beams each step"

        # stop when we are done with each sentence
        if all(done):
            break

        # sanity check / prepare next batch
        assert len(next_batch_beam) == batch_size * num_beams
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
        beam_idx = input_ids.new([x[2] for x in next_batch_beam])

        # re-order batch and update current length
        input_ids = input_ids[beam_idx, :]
        input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
        cur_len = cur_len + 1

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx in range(batch_size):
        if done[batch_idx]:
            continue

        # test that beam scores match previously calculated scores if not eos and batch_idx not done
        if eos_token_id is not None and all(
            (token_id % vocab_size).item() != eos_token_id
                for token_id in next_tokens[batch_idx]):
            assert torch.all(
                next_scores[batch_idx, :num_beams] == beam_scores.view(
                    batch_size, num_beams)[batch_idx]
            ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                next_scores[:, :num_beams][batch_idx],
                beam_scores.view(batch_size, num_beams)[batch_idx],
            )

        # need to add best num_beams hypotheses to generated hyps
        for beam_id in range(num_beams):
            effective_beam_id = batch_idx * num_beams + beam_id
            final_score = beam_scores[effective_beam_id].item()
            final_tokens = input_ids[effective_beam_id]
            generated_hyps[batch_idx].add(final_tokens, final_score)

    # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
    output_batch_size = batch_size * num_return_sequences
    output_num_return_sequences_per_batch = num_return_sequences

    # select the best hypotheses
    sent_lengths = input_ids.new(output_batch_size)
    best = []

    # retrieve best hypotheses
    for i, hypotheses in enumerate(generated_hyps):
        sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
        for j in range(output_num_return_sequences_per_batch):
            effective_batch_idx = output_num_return_sequences_per_batch * i + j
            best_hyp = sorted_hyps.pop()[1]
            sent_lengths[effective_batch_idx] = len(best_hyp)
            best.append(best_hyp)

    # shorter batches are padded
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
        decoded = input_ids.new(output_batch_size,
                                sent_max_len).fill_(pad_token_id)

        # fill with hypothesis and eos_token_id if necessary
        for i, hypo in enumerate(best):
            decoded[i, :sent_lengths[i]] = hypo
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
    else:
        # none of the hypotheses have an eos_token
        assert all(len(hypo) == max_length for hypo in best)
        decoded = torch.stack(best).type(torch.long).to(
            next(model.parameters()).device)

    return decoded
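
Example #5 rebuilds the decoder's causal mask at every step from `fill_with_neg_inf` and `torch.triu`. The sketch below shows that mask for a 4-token prefix, with a stand-in definition for `fill_with_neg_inf` (in the original code it comes from the model's utility functions):

import torch


def fill_with_neg_inf(t):
    # stand-in for the helper used in Example #5: fill a tensor with -inf
    return t.float().fill_(float("-inf"))


tgt_len = 4
causal_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
print(causal_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])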
Example #6
def _generate_no_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    min_length,
    do_sample,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids,
    bos_token_id,
    pad_token_id,
    eos_token_id,
    decoder_start_token_id,
    batch_size,
    encoder_outputs,
    attention_mask,
    use_cache,
    partial_generation_transform,
):
    """ Generate sequences for each example without beam search (num_beams == 1).
        All returned sequences are generated independently.
    """
    # length of generated sentences / unfinished sentences
    unfinished_sents = input_ids.new(batch_size).fill_(1)
    sent_lengths = input_ids.new(batch_size).fill_(max_length)

    past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

    goodies = []
    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
        )

        outputs = self(**model_inputs)
        next_token_logits = outputs[0][:, -1, :]

        # if model has past, then set the past variable to speed up decoding
        if use_cache:  # and self._use_cache(outputs, use_cache): !! removed:tdimson
            past = outputs[1]

        # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        if bad_words_ids is not None:
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            next_token_logits[:, eos_token_id] = -float("inf")

        if do_sample:
            # Temperature (higher temperature => more likely to sample low probability tokens)
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            # Top-p/top-k filtering
            next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            # Sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy decoding
            next_token = torch.argmax(next_token_logits, dim=-1)

        # update generations and finished sentences
        if eos_token_id is not None:
            # pad finished sentences if eos_token_id exist
            tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
        else:
            tokens_to_add = next_token

        # tdimson: add a partial generation transform that can manipulate as we generate
        if partial_generation_transform is not None:
            tokens_to_add = partial_generation_transform(input_ids, tokens_to_add)

        input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

        if eos_token_id is not None:
            eos_in_sents = tokens_to_add == eos_token_id
            # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
            is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
            sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
            # unfinished_sents is set to zero if eos in sentence
            unfinished_sents.mul_((~eos_in_sents).long())

            # tdimson: stop computing for sentences that have terminated
            for i in range(input_ids.size()[0]):
                if unfinished_sents[i] == 0:
                    goodies.append(input_ids[i, :])

            idx = unfinished_sents == 1
            input_ids = input_ids[idx, :]
            attention_mask = attention_mask[idx, :]
            sent_lengths = sent_lengths[idx]
            new_past = []
            for item in past:
                new_past.append(item[:, idx, :, :, :])
            past = tuple(new_past)
            unfinished_sents = unfinished_sents[idx]

            if input_ids.size()[0] == 0:
                break

        # stop when there is a </s> in each sentence, or if we exceed the maximum length
        if unfinished_sents.max() == 0:
            break

        # extend attention_mask for new generated input if only decoder
        if self.config.is_encoder_decoder is False:
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        cur_len = cur_len + 1

    # tdimson: take our finished sentences and roll
    for i in range(input_ids.size()[0]):
        goodies.append(input_ids[i])

    from torch.nn.utils.rnn import pad_sequence

    return pad_sequence(goodies, batch_first=True, padding_value=pad_token_id)

    # NOTE: the original padding logic below is unreachable after the early return above;
    # it is kept here only for reference.
    # if there are different sentence lengths in the batch, some batches have to be padded
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
        # finished sents are filled with pad_token
        decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
    else:
        decoded = input_ids

    for hypo_idx, hypo in enumerate(input_ids):
        decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

    return decoded
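
The last example returns its finished sequences through `torch.nn.utils.rnn.pad_sequence` instead of the manual padding kept below the early return. A minimal standalone illustration of that call, with made-up token ids:

import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0
# finished sequences collected at different lengths (like the `goodies` list above)
sequences = [torch.tensor([5, 8, 2]),
             torch.tensor([7, 2]),
             torch.tensor([9, 4, 6, 2])]

decoded = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)
print(decoded)
# tensor([[5, 8, 2, 0],
#         [7, 2, 0, 0],
#         [9, 4, 6, 2]])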