Example #1
    def complete_image(self,
                       text_tokens,
                       image_tokens=None,
                       num_return_sequences=1,
                       num_beams=1,
                       top_k=None,
                       top_p=None,
                       temperature=None,
                       _verbose=False):
        config = self.config
        processors = None  # there are no pre-processors that we need
        batch_size = text_tokens.shape[0] * num_return_sequences
        device = text_tokens.device

        # we have an equivalent of is_beam_sample_gen_mode
        logits_warper_list = self._get_logits_warper_list(
            num_beams, temperature, top_k, top_p)

        # this model always generates to a fixed size
        steps_to_gen = config.total_context_len - config.text_context_len
        if image_tokens is not None:
            steps_to_gen -= image_tokens.shape[1]

        if _verbose:
            print("steps_to_gen:", steps_to_gen,
                  "image_tokens:", image_tokens,
                  "total_context_len:", config.total_context_len)

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=steps_to_gen,
            num_beams=num_beams,
            device=device,
            length_penalty=1.0,
            do_early_stopping=False,
        )

        # expand text_tokens and image_tokens to num_beams * num_return_sequences

        expanded_text_tokens = self._expand_tokens_for_generation(
            text_tokens, expand_size=num_beams * num_return_sequences)
        if image_tokens is not None:
            expanded_image_tokens = self._expand_tokens_for_generation(
                image_tokens, expand_size=num_beams * num_return_sequences)
        else:
            expanded_image_tokens = None

        outputs = self.beam_sample(text_tokens=expanded_text_tokens,
                                   image_tokens=expanded_image_tokens,
                                   beam_scorer=beam_scorer,
                                   logits_warper=logits_warper_list,
                                   steps_to_gen=steps_to_gen,
                                   batch_size=batch_size,
                                   num_beams=num_beams,
                                   _verbose=_verbose)
        if _verbose:
            print("final image tokens", outputs["sequences"],
                  outputs["sequences"][0].size())
        recons = self.vae._decode_ids(
            image_tokens=outputs["sequences"]).permute((0, 2, 3, 1))
        return recons, outputs["sequence_scores"]
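A minimal usage sketch for Example #1, assuming a hypothetical `model` exposing `complete_image` and a matching `tokenizer`; names, shapes, and sampling values are illustrative, not from the source:

    # placeholder tokenizer/model; text prompt is assumed to fit the model's text context length
    text_tokens = tokenizer.encode("a painting of a cat", return_tensors="pt")  # (1, text_len)
    recons, scores = model.complete_image(
        text_tokens,
        num_return_sequences=2,
        num_beams=4,
        top_k=64,
        temperature=1.0,
    )
    # recons: decoded images in (batch, height, width, channels) order
    # scores: beam scores of the returned sequences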
Example #2

    def prepare_beam_scorer(self, **kwargs):
        return BeamSearchScorer(
            batch_size=kwargs.get("batch_size", self.batch_size),
            num_beams=kwargs.get("num_beams", self.num_beams),
            device=torch_device,
            length_penalty=kwargs.get("length_penalty", self.length_penalty),
            do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
            num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
        )
Example #3

    def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
        beam_kwargs = {
            "early_stopping": False,
            "length_penalty": 2.0,
            "num_beams": 2,
            "num_return_sequences": num_return_sequences,
        }
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=beam_kwargs["num_beams"],
            device=torch_device,
            length_penalty=beam_kwargs["length_penalty"],
            do_early_stopping=beam_kwargs["early_stopping"],
            num_beam_hyps_to_keep=num_return_sequences,
        )
        return beam_kwargs, beam_scorer
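A short sketch of how Example #3 might drive a beam-search call; `model`, `input_ids`, and `max_length` are placeholders, and the helper is assumed to be callable as a plain function or static method (assumptions, not from the source):

    beam_kwargs, beam_scorer = _get_beam_scorer_and_kwargs(
        batch_size=input_ids.shape[0], max_length=max_length)
    # interleave the inputs with num_beams before calling beam_search
    expanded_input_ids = input_ids.repeat_interleave(beam_kwargs["num_beams"], dim=0)
    output = model.beam_search(expanded_input_ids, beam_scorer, max_length=max_length)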
Example #4
    def generate(self,
                 input_ids: Optional[torch.LongTensor] = None,
                 max_length: Optional[int] = None,
                 min_length: Optional[int] = None,
                 do_sample: Optional[bool] = None,
                 early_stopping: Optional[bool] = None,
                 num_beams: Optional[int] = None,
                 temperature: Optional[float] = None,
                 top_k: Optional[int] = None,
                 top_p: Optional[float] = None,
                 repetition_penalty: Optional[float] = None,
                 bad_words_ids: Optional[Iterable[int]] = None,
                 bos_token_id: Optional[int] = None,
                 pad_token_id: Optional[int] = None,
                 eos_token_id: Optional[int] = None,
                 length_penalty: Optional[float] = None,
                 no_repeat_ngram_size: Optional[int] = None,
                 num_return_sequences: Optional[int] = None,
                 decoder_start_token_id: Optional[int] = None,
                 use_cache: Optional[bool] = None,
                 num_beam_groups: Optional[int] = None,
                 diversity_penalty: Optional[float] = None,
                 postfix_additional_tokens_fn: Optional[Callable[
                     [torch.Tensor], torch.Tensor]] = None,
                 prefix_allowed_tokens_fn: Optional[Callable[
                     [int, torch.Tensor], List[int]]] = None,
                 **model_kwargs) -> torch.LongTensor:

        # set init values
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        max_length = max_length if max_length is not None else self.config.max_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (num_return_sequences
                                if num_return_sequences is not None else
                                self.config.num_return_sequences)

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if input_ids is None:
            # init `input_ids` with bos_token_id
            input_ids = self._prepare_input_ids_for_generation(bos_token_id)

        if model_kwargs.get("attention_mask", None) is None:
            # init `attention_mask` depending on `pad_token_id`
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                input_ids, pad_token_id, eos_token_id)

        # special case if pad_token_id is not defined
        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
            )
            pad_token_id = eos_token_id

        if self.config.is_encoder_decoder:
            # add encoder_outputs to model_kwargs
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                input_ids, model_kwargs)

            # set input_ids as decoder_input_ids
            input_ids = self._prepare_decoder_input_ids_for_generation(
                input_ids,
                decoder_start_token_id=decoder_start_token_id,
                bos_token_id=bos_token_id,
                **model_kwargs)

            if "encoder_outputs" not in model_kwargs or not isinstance(
                    model_kwargs["encoder_outputs"], ModelOutput):
                raise ValueError(
                    "Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`."
                )

        # determine generation mode
        is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
        is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
        is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
        is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)
        if num_beam_groups > num_beams:
            raise ValueError(
                "`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
            raise ValueError(
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

        # set model_kwargs
        model_kwargs["use_cache"] = use_cache

        # get distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            eos_token_id=eos_token_id,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            diversity_penalty=diversity_penalty,
        )

        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

            # greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                **model_kwargs,
            )

        elif is_sample_gen_mode:
            # get probability distribution warper
            logits_warper = self._get_logits_warper(top_k=top_k,
                                                    top_p=top_p,
                                                    temperature=temperature,
                                                    num_beams=num_beams)

            # expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                **model_kwargs,
            )

        elif is_beam_gen_mode:
            batch_size = input_ids.shape[0]

            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

            if num_return_sequences > num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
                )

            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                max_length=max_length,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            # interleave with `num_beams`
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs)
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                postfix_additional_tokens_fn=postfix_additional_tokens_fn,
                **model_kwargs,
            )

        elif is_beam_sample_gen_mode:
            logits_warper = self._get_logits_warper(top_k=top_k,
                                                    top_p=top_p,
                                                    temperature=temperature,
                                                    num_beams=num_beams)

            batch_size = input_ids.shape[0] * num_return_sequences

            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                max_length=max_length,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

            # interleave with `num_beams * num_return_sequences`
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            return self.beam_sample(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                postfix_additional_tokens_fn=postfix_additional_tokens_fn,
                **model_kwargs,
            )

        elif is_group_beam_gen_mode:
            batch_size = input_ids.shape[0]

            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

            if num_return_sequences > num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller or equal to `num_beams`."
                )

            if num_beams % num_beam_groups != 0:
                raise ValueError(
                    "`num_beams` should be divisible by `num_beam_groups` for group beam search."
                )

            diverse_beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                max_length=max_length,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
            # interleave with `num_beams`
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs)
            return self.group_beam_search(
                input_ids,
                diverse_beam_scorer,
                logits_processor=logits_processor,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                **model_kwargs,
            )
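A hedged end-to-end sketch for Example #4, showing a beam-sample call through the extended `generate`; `model` and `tokenizer` are placeholder names and the parameter values are illustrative only:

    # placeholder tokenizer/model for illustration
    input_ids = tokenizer("translate English to German: Hello world", return_tensors="pt").input_ids
    outputs = model.generate(
        input_ids,
        max_length=48,
        num_beams=4,
        do_sample=True,          # num_beams > 1 with do_sample=True selects the beam-sample branch
        top_p=0.9,
        temperature=0.8,
        num_return_sequences=2,
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))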