Ejemplo n.º 1
0
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        Load a pre-trained BART model and set up beam search and evaluation metrics.

        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. It is passed to
            `BartForConditionalGeneration.from_pretrained`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
            When omitted, a `PretrainedTransformerIndexer` for `model_name` using the
            `"tokens"` namespace is created.
        max_decoding_steps : `int`, optional (default = `140`)
            Maximum number of decoding steps, passed to `BeamSearch` as `max_steps`.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. A falsy value (`None` or `0`) falls back
            to a beam size of 1.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to use in BART. By default, the original BART encoder is used. A custom
            encoder must have input and output dimensions equal to the BART hidden size.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        if indexer is None:
            indexer = PretrainedTransformerIndexer(model_name, namespace="tokens")
        self._indexer = indexer

        config = self.bart.config
        self._start_id = config.bos_token_id  # CLS
        # Token id 0 is a legitimate id, so only fall back to the BOS id when the config
        # leaves the decoder start token unset (a plain `or` would also discard 0).
        decoder_start_id = config.decoder_start_token_id
        if decoder_start_id is None:
            decoder_start_id = self._start_id
        self._decoder_start_id = decoder_start_id
        self._end_id = config.eos_token_id  # SEP
        self._pad_id = config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(
            self._end_id,
            max_steps=max_decoding_steps,
            beam_size=beam_size or 1,
        )

        # Special tokens are left out of both metrics.
        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            hidden_size = config.hidden_size
            assert encoder.get_input_dim() == encoder.get_output_dim() == hidden_size, (
                "The replacement encoder's input and output dimensions must both equal "
                f"the BART hidden size ({hidden_size})."
            )
            # Grab the original encoder before it is overwritten so its embeddings can be reused.
            original_encoder = self.bart.model.encoder
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                original_encoder.embed_tokens,
                original_encoder.embed_positions,
            )
Ejemplo n.º 2
0
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
        indexer: PretrainedTransformerIndexer = None,
        encoder: Seq2SeqEncoder = None,
        **kwargs,
    ):
        """
        Load a pre-trained BART model and set up beam search and evaluation metrics.

        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model, passed to
            `BartForConditionalGeneration.from_pretrained`.
        vocab : `Vocabulary`, required
            Vocabulary handed to the parent constructor and to the beam search.
        beam_search : `Lazy[BeamSearch]`, optional (default = `Lazy(BeamSearch)`)
            Lazily constructed beam search. It is built here with `end_index` set to the
            model's EOS token id and `vocab` set to this model's vocabulary.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer stored on the model. When omitted, a `PretrainedTransformerIndexer` for
            `model_name` using the `"tokens"` namespace is created.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Replacement for the original BART encoder. Its input and output dimensions must
            both equal the BART hidden size.
        **kwargs
            `beam_size` and `max_decoding_steps` are still accepted for backwards
            compatibility; each emits a `DeprecationWarning` and is forwarded to the beam
            search as `beam_size` and `max_steps` respectively. Any other keyword argument
            is ignored.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        if indexer is None:
            indexer = PretrainedTransformerIndexer(model_name, namespace="tokens")
        self._indexer = indexer

        config = self.bart.config
        self._start_id = config.bos_token_id  # CLS
        # Token id 0 is a legitimate id, so only fall back to the BOS id when the config
        # leaves the decoder start token unset (a plain `or` would also discard 0).
        decoder_start_id = config.decoder_start_token_id
        if decoder_start_id is None:
            decoder_start_id = self._start_id
        self._decoder_start_id = decoder_start_id
        self._end_id = config.eos_token_id  # SEP
        self._pad_id = config.pad_token_id  # PAD

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
        # kwargs. If so, forward them to the BeamSearch constructor and emit a DeprecationWarning.
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead."
        )
        # Each deprecated keyword paired with the BeamSearch argument that replaces it.
        deprecated_to_beam_search_arg = (
            ("beam_size", "beam_size"),
            ("max_decoding_steps", "max_steps"),
        )
        beam_search_extras = {}
        for deprecated_name, beam_search_arg in deprecated_to_beam_search_arg:
            if deprecated_name in kwargs:
                beam_search_extras[beam_search_arg] = kwargs[deprecated_name]
                # stacklevel=2 attributes the warning to the code that called this
                # constructor rather than to this line.
                warnings.warn(
                    deprecation_warning.format(deprecated_name),
                    DeprecationWarning,
                    stacklevel=2,
                )
        self._beam_search = beam_search.construct(
            end_index=self._end_id, vocab=self.vocab, **beam_search_extras
        )

        # Special tokens are left out of both metrics.
        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            hidden_size = config.hidden_size
            assert encoder.get_input_dim() == encoder.get_output_dim() == hidden_size, (
                "The replacement encoder's input and output dimensions must both equal "
                f"the BART hidden size ({hidden_size})."
            )
            # Grab the original encoder before it is overwritten so its embeddings can be reused.
            original_encoder = self.bart.model.encoder
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                original_encoder.embed_tokens,
                original_encoder.embed_positions,
            )
Ejemplo n.º 3
0
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        Load a pre-trained BART model and set up beam search and evaluation metrics.

        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model, passed to
            `BartForConditionalGeneration.from_pretrained`.
        vocab : `Vocabulary`, required
            Vocabulary handed to the parent constructor.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer stored on the model. When omitted, a `PretrainedTransformerIndexer` for
            `model_name` using the `"tokens"` namespace is created.
        max_decoding_steps : `int`, optional (default = `140`)
            Maximum number of decoding steps, passed to `BeamSearch` as `max_steps`.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. A falsy value (`None` or `0`) falls back
            to a beam size of 1.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Replacement for the original BART encoder. Its input and output dimensions must
            both equal the BART hidden size.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = (
            indexer
            if indexer is not None
            else PretrainedTransformerIndexer(model_name, namespace="tokens")
        )

        bart_config = self.bart.config
        self._start_id = bart_config.bos_token_id  # CLS
        # A decoder start token id of 0 is valid, so test for `None` explicitly instead of
        # relying on truthiness; fall back to the BOS id only when the value is unset.
        configured_decoder_start = bart_config.decoder_start_token_id
        self._decoder_start_id = (
            self._start_id if configured_decoder_start is None else configured_decoder_start
        )
        self._end_id = bart_config.eos_token_id  # SEP
        self._pad_id = bart_config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(
            self._end_id,
            max_steps=max_decoding_steps,
            beam_size=beam_size or 1,
        )

        # Special tokens are left out of both metrics.
        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            expected_dim = bart_config.hidden_size
            assert encoder.get_input_dim() == encoder.get_output_dim() == expected_dim, (
                "The replacement encoder's input and output dimensions must both equal "
                f"the BART hidden size ({expected_dim})."
            )
            # Keep a handle on the stock encoder so its embedding layers survive the swap.
            stock_encoder = self.bart.model.encoder
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                stock_encoder.embed_tokens,
                stock_encoder.embed_positions,
            )