Example #1
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
            `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
        max_decoding_steps : `int`, optional (default = `140`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. The default is from the BART paper.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to be used in BART. By default, the original BART encoder is used.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
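
The constructor above mostly does bookkeeping around the Hugging Face checkpoint: it caches BART's special-token ids and builds a default indexer. A minimal, self-contained sketch of that token-id lookup using only transformers (the checkpoint name "facebook/bart-base" is an assumption; the snippet receives it as `model_name`):

    # Hedged sketch of the special-token-id bookkeeping above, outside AllenNLP.
    from transformers import BartForConditionalGeneration

    bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    start_id = bart.config.bos_token_id                         # <s>
    decoder_start_id = bart.config.decoder_start_token_id or start_id
    end_id = bart.config.eos_token_id                           # </s>
    pad_id = bart.config.pad_token_id                           # <pad>
    print(start_id, decoder_start_id, end_id, pad_id)           # typically 0, 2, 2, 1 for BART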
Example #2
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
        indexer: PretrainedTransformerIndexer = None,
        encoder: Seq2SeqEncoder = None,
        **kwargs,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
        # kwargs. If so, fold them into the BeamSearch constructor arguments and emit a DeprecationWarning.
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead."
        )
        beam_search_extras = {}
        if "beam_size" in kwargs:
            beam_search_extras["beam_size"] = kwargs["beam_size"]
            warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
        if "max_decoding_steps" in kwargs:
            beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
            warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
        self._beam_search = beam_search.construct(
            end_index=self._end_id, vocab=self.vocab, **beam_search_extras
        )

        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (
                encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
            )
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
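
Compared with Example #1, this version defers beam-search construction through `Lazy[BeamSearch]` so deprecated kwargs can still be folded in. A hedged, stand-alone sketch of that pattern (values are illustrative):

    # Hedged sketch of the Lazy[BeamSearch] pattern: nothing is constructed until
    # the end index (and any deprecated kwargs) are known. Values are illustrative.
    from allennlp.common.lazy import Lazy
    from allennlp.nn.beam_search import BeamSearch

    beam_search = Lazy(BeamSearch)                             # no BeamSearch object yet
    beam_search_extras = {"beam_size": 4, "max_steps": 140}    # what the kwargs shim collects
    searcher = beam_search.construct(end_index=2, **beam_search_extras)  # 2 = BART's eos id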
Example #3
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'TreeAttention':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)

        premise_encoder_params = params.pop("premise_encoder", None)
        premise_encoder = Seq2SeqEncoder.from_params(premise_encoder_params)

        attention_similarity = SimilarityFunction.from_params(params.pop('attention_similarity'))
        phrase_probability = FeedForward.from_params(params.pop('phrase_probability'))
        edge_probability = FeedForward.from_params(params.pop('edge_probability'))

        edge_embedding = Embedding.from_params(vocab, params.pop('edge_embedding'))
        use_encoding_for_node = params.pop('use_encoding_for_node')
        ignore_edges = params.pop('ignore_edges', False)

        init_params = params.pop('initializer', None)
        initializer = (InitializerApplicator.from_params(init_params)
                       if init_params is not None
                       else InitializerApplicator())

        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   phrase_probability=phrase_probability,
                   edge_probability=edge_probability,
                   premise_encoder=premise_encoder,
                   edge_embedding=edge_embedding,
                   use_encoding_for_node=use_encoding_for_node,
                   attention_similarity=attention_similarity,
                   ignore_edges=ignore_edges,
                   initializer=initializer)
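
This is the older, manual `from_params` style: every sub-component is popped off a `Params` object by key. A hedged illustration of that convention with placeholder keys and values (AllenNLP 0.x-era API):

    # Hedged illustration of the Params.pop convention used above; keys/values are placeholders.
    from allennlp.common import Params

    params = Params({"use_encoding_for_node": True, "ignore_edges": False})
    use_encoding_for_node = params.pop("use_encoding_for_node")   # required key
    ignore_edges = params.pop("ignore_edges", False)              # optional key with a default
    params.assert_empty("TreeAttention")                          # typical end-of-from_params check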
Example #4
    def from_params(cls, params: Params) -> 'Seq2Seq2VecEncoder':
        seq2seq_encoder_params = params.pop("seq2seq_encoder")
        seq2vec_encoder_params = params.pop("seq2vec_encoder")
        seq2seq_encoder = Seq2SeqEncoder.from_params(seq2seq_encoder_params)
        seq2vec_encoder = Seq2VecEncoder.from_params(seq2vec_encoder_params)

        return cls(seq2seq_encoder=seq2seq_encoder,
                   seq2vec_encoder=seq2vec_encoder)
    def __init__(self, vocab: Vocabulary, embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder) -> None:
        super().__init__(vocab)

        self._embedder = embedder
        self._encoder = encoder
        self._classifier = nn.Linear(in_features=2 * encoder.get_output_dim(),
                                     out_features=2)

        self._f1 = F1Measure(positive_label=1)
Example #6
    def __init__(self, vocab: Vocabulary, embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder) -> None:
        super().__init__(vocab)
        self._embedder = embedder
        self._encoder = encoder
        self._classifier = nn.Linear(
            in_features=encoder.get_output_dim() * 2,
            out_features=vocab.get_vocab_size('labels'))
        self._metric = F1Measure(positive_label=vocab.get_token_index(
            token="positive", namespace='labels'))
Example #7
    def __init__(self, vocab: Vocabulary, embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder) -> None:
        super().__init__(vocab)

        self._embedder = embedder
        self._encoder = encoder
        self._classifier = torch.nn.Linear(
            in_features=encoder.get_output_dim(),
            out_features=vocab.get_vocab_size('labels'))

        self._f1 = SpanBasedF1Measure(vocab, 'labels', 'IOB1')
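
The constructor only wires the modules together; a matching `forward` is not shown. A hedged sketch of the forward pass such a tagger typically pairs with (this is an assumption, not code from the original project):

    # Hedged sketch of a forward() that would fit the __init__ above.
    from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits

    def forward(self, tokens, labels=None):
        mask = get_text_field_mask(tokens)        # (batch, seq_len)
        embedded = self._embedder(tokens)         # (batch, seq_len, embed_dim)
        encoded = self._encoder(embedded, mask)   # (batch, seq_len, encoder_dim)
        logits = self._classifier(encoded)        # (batch, seq_len, num_labels)
        output = {"logits": logits}
        if labels is not None:
            self._f1(logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(logits, labels, mask)
        return output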
Example #8
    def __init__(self, vocab: Vocabulary, embedder: elmoembedder,
                 encoder: Seq2SeqEncoder) -> None:
        super().__init__(vocab)

        self._embedder = embedder
        self._encoder = encoder
        self._classifier = torch.nn.Linear(
            in_features=encoder.get_output_dim(),
            out_features=vocab.get_vocab_size("labels"),
        )

        self._f1 = SpanBasedF1Measure(vocab, "labels", "IOB1")
Example #9
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
    def __init__(self, vocab: Vocabulary, embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder) -> None:
        super().__init__(vocab)

        self._embedder = embedder
        self._encoder = encoder
        self._classifier = torch.nn.Linear(
            in_features=2 * encoder.get_output_dim(),
            out_features=vocab.get_vocab_size('labels'))

        #self._ffnn = torch.nn.Sequential(
        ##linear
        ##dropout
        ##tanh
        ##linear
        #)
        # define f1 here, use as plain F1 measure not spanBased
        self._metric = F1Measure(positive_label=vocab.get_token_index(
            token='positive', namespace='labels'))
Example #11
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 message_encoder: Seq2VecEncoder,
                 conversation_encoder: Seq2SeqEncoder,
                 dropout: float = 0.5,
                 pos_weight: float = None,
                 use_game_scores: bool = False) -> None:
        super().__init__(vocab)

        self._embedder = embedder
        self._message_encoder = message_encoder
        self._conversation_encoder = conversation_encoder
        self._use_game_scores = use_game_scores

        output_dim = conversation_encoder.get_output_dim() + int(self._use_game_scores)

        self._classifier = nn.Linear(in_features=output_dim,
                                     out_features=vocab.get_vocab_size('labels'))
        self._dropout = nn.Dropout(dropout)

        self._label_index_to_token = vocab.get_index_to_token_vocabulary(namespace="labels")
        self._num_labels = len(self._label_index_to_token)
        print(self._label_index_to_token)
        index_list = list(range(self._num_labels))
        print(index_list)
        self._f1 = FBetaMeasure(average=None, labels=index_list)
        self._f1_micro = FBetaMeasure(average='micro')
        self._f1_macro = FBetaMeasure(average='macro')

        if pos_weight is None or pos_weight <= 0:
            labels_counter = self.vocab._retained_counter['labels']
            self._pos_weight = 1. * labels_counter['True'] / labels_counter['False']
            # self._pos_weight = 15.886736214605067
            print('Computing Pos weight from labels:', self._pos_weight)
        else:
            self._pos_weight = float(pos_weight)
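
The snippet computes `pos_weight` from the retained label counts but never shows where it is used. A hedged guess at the typical destination, a weighted binary cross-entropy loss (the loss call itself is an assumption):

    # Hedged sketch: how a pos_weight like the one above is typically consumed.
    import torch
    import torch.nn.functional as F

    logits = torch.randn(8)                       # one score per message
    targets = torch.randint(0, 2, (8,)).float()   # 1 = positive class
    pos_weight = torch.tensor([15.9])             # e.g. the True/False count ratio computed above
    loss = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)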
Example #12
    def __init__(self, module_class: Type[torch.nn.modules.RNNBase]) -> None:
        self._module_class = module_class

    def __call__(self, **kwargs) -> PytorchSeq2SeqWrapper:
        return self.from_params(Params(kwargs))

    def from_params(self, params: Params) -> PytorchSeq2SeqWrapper:
        if not params.pop_bool('batch_first', True):
            raise ConfigurationError("Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        module = self._module_class(**params.as_dict())
        return PytorchSeq2SeqWrapper(module)

# pylint: disable=protected-access
Seq2SeqEncoder.register("gru")(_Seq2SeqWrapper(torch.nn.GRU))
Seq2SeqEncoder.register("lstm")(_Seq2SeqWrapper(torch.nn.LSTM))
Seq2SeqEncoder.register("rnn")(_Seq2SeqWrapper(torch.nn.RNN))
Seq2SeqEncoder.register("augmented_lstm")(_Seq2SeqWrapper(AugmentedLstm))
Seq2SeqEncoder.register("alternating_lstm")(_Seq2SeqWrapper(StackedAlternatingLstm))
if torch.cuda.is_available():
    try:
        # TODO(Mark): Remove this once we have a CPU wrapper for the kernel/switch to ATen.
        from allennlp.modules.alternating_highway_lstm import AlternatingHighwayLSTM
        Seq2SeqEncoder.register("alternating_highway_lstm_cuda")(_Seq2SeqWrapper(AlternatingHighwayLSTM))
    except (ModuleNotFoundError, FileNotFoundError, ImportError):
        logger.debug("allennlp could not register 'alternating_highway_lstm_cuda' - installation "
                     "needs to be completed manually if you have pip-installed the package. "
                     "Run ``bash make.sh`` in the 'custom_extensions' module on a machine with a "
                     "GPU.")
Example #13
File: lstm.py  Project: Chung-I/tsm-rnnt
        final_states: Tuple[torch.Tensor, torch.Tensor]
            The per-layer final (state, memory) states of the LSTM, each with shape
            (num_layers, batch_size, hidden_size).
        """
        def init_hidden(tensor, shape):
            return (tensor.new_zeros(shape),
                    tensor.new_zeros(shape))

        sequence_tensor, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
        batch_size, _, _ = sequence_tensor.size()
        hidden_shape = torch.Size((batch_size, self.hidden_channel * self.hidden_size)) if self.conv \
            else torch.Size((batch_size, self.hidden_size))

        if not initial_state:
            hidden_states = [[init_hidden(sequence_tensor, hidden_shape) for _ in range(2)] if self.bidirectional
                             else init_hidden(sequence_tensor, hidden_shape)
                             for _ in range(self.num_layers)]
        elif initial_state[0].size()[0] != self.num_layers:
            raise ConfigurationError("Initial states were passed to forward() but the number of "
                                     "initial states does not match the number of layers.")
        else:
            hidden_states = list(zip(initial_state[0].split(1, 0),
                                     initial_state[1].split(1, 0)))
        outputs, states = self.rnn(sequence_tensor.transpose(1, 0), hidden_states)
        outputs = pack_padded_sequence(outputs, batch_lengths)
        return outputs, states



Seq2SeqEncoder.register("stacked_custom_lstm")(_Seq2SeqWrapper(StackedCustomLstm))
Example #14
    def test_registry_has_builtin_seq2seq_encoders(self):
        # pylint: disable=protected-access
        assert Seq2SeqEncoder.by_name('gru')._module_class.__name__ == 'GRU'
        assert Seq2SeqEncoder.by_name('lstm')._module_class.__name__ == 'LSTM'
        assert Seq2SeqEncoder.by_name('rnn')._module_class.__name__ == 'RNN'
Example #15
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
            # do a proper weighted sum with `intra_sentence_attention`.
            shape = list(output_token_representation.size())
            new_shape = shape[:-1] + [self._num_attention_heads, -1]
            # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
            output_token_representation = output_token_representation.view(
                *new_shape)
            # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
            output_token_representation = output_token_representation.permute(
                0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(output_token_representation,
                                              intra_sentence_attention)

        if self._num_attention_heads > 1:
            # Here we concatenate the weighted representation for each head.  We'll accomplish this
            # just with a resize.
            # Shape: (batch_size, sequence_length, projection_dim)
            attended_sentence = attended_sentence.view(batch_size,
                                                       sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined_tensors = util.combine_tensors(self._combination,
                                                [tokens, attended_sentence])
        return self._output_projection(combined_tensors)


IntraSentenceAttentionEncoder = Seq2SeqEncoder.register(
    u"intra_sentence_attention")(IntraSentenceAttentionEncoder)
Example #16
    to ``self``.  Then when called (as if we were instantiating an actual encoder with
    ``Encoder(**params)``, or with ``Encoder.from_params(params)``), we pass those parameters
    through to the ``RNNBase`` constructor, then pass the instantiated pytorch RNN to the
    ``PytorchSeq2SeqWrapper``.  This lets us use this class in the registry and have everything just
    work.
    """

    PYTORCH_MODELS = [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]

    def __init__(self, module_class: Type[torch.nn.modules.RNNBase]) -> None:
        self._module_class = module_class

    def __call__(self, **kwargs) -> StackedPytorchSeq2SeqWrapper:
        return self.from_params(Params(kwargs))

    # Logic requires custom from_params
    def from_params(self, params: Params) -> StackedPytorchSeq2SeqWrapper:
        if not params.pop_bool("batch_first", True):
            raise ConfigurationError(
                "Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params["batch_first"] = True
        stateful = params.pop_bool("stateful", False)
        module = self._module_class(**params.as_dict(infer_type_and_cast=True))
        return StackedPytorchSeq2SeqWrapper(module, stateful=stateful)


Seq2SeqEncoder.register("stacked_lstm")(_Seq2SeqWrapper(torch.nn.LSTM))
Seq2SeqEncoder.register("stacked_gru")(_Seq2SeqWrapper(torch.nn.GRU))
Seq2SeqEncoder.register("stacked_rnn")(_Seq2SeqWrapper(torch.nn.RNN))
Example #17
    through to the ``RNNBase`` constructor, then pass the instantiated pytorch RNN to the
    ``PytorchSeq2SeqWrapper``.  This lets us use this class in the registry and have everything just
    work.
    """

    PYTORCH_MODELS = [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]

    def __init__(self, module_class: Type[torch.nn.modules.RNNBase]) -> None:
        self._module_class = module_class

    def __call__(self, **kwargs) -> PytorchSeq2SeqWrapper:
        return self.from_params(Params(kwargs))

    # Logic requires custom from_params
    def from_params(self, params: Params) -> PytorchSeq2SeqWrapper:
        if not params.pop_bool("batch_first", True):
            raise ConfigurationError("Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params["batch_first"] = True
        stateful = params.pop_bool("stateful", False)
        module = self._module_class(**params.as_dict(infer_type_and_cast=True))
        return PytorchSeq2SeqWrapper(module, stateful=stateful)


Seq2SeqEncoder.register("gru")(_Seq2SeqWrapper(torch.nn.GRU))
Seq2SeqEncoder.register("lstm")(_Seq2SeqWrapper(torch.nn.LSTM))
Seq2SeqEncoder.register("rnn")(_Seq2SeqWrapper(torch.nn.RNN))
Seq2SeqEncoder.register("augmented_lstm")(_Seq2SeqWrapper(AugmentedLstm))
Seq2SeqEncoder.register("alternating_lstm")(_Seq2SeqWrapper(StackedAlternatingLstm))
Seq2SeqEncoder.register("stacked_bidirectional_lstm")(_Seq2SeqWrapper(StackedBidirectionalLstm))
Example #18
File: __init__.py  Project: Shuailong/SPM
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Author: Shuailong
# @Email: [email protected]
# @Date: 2019-04-16 22:10:41
# @Last Modified by: Shuailong
# @Last Modified time: 2019-05-05 20:27:54

from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder

from spm.modules.slstm import SLSTMEncoder

Seq2SeqEncoder.register("slstm")(SLSTMEncoder)
    work.
    """
    PYTORCH_MODELS = [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]

    def __init__(self, module_class: Type[torch.nn.modules.RNNBase]) -> None:
        self._module_class = module_class

    def __call__(self, **kwargs) -> PytorchSeq2SeqWrapper:
        return self.from_params(Params(kwargs))

    # Logic requires custom from_params
    def from_params(self, params: Params) -> PytorchSeq2SeqWrapper:
        if not params.pop_bool('batch_first', True):
            raise ConfigurationError("Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        stateful = params.pop_bool('stateful', False)
        module = self._module_class(**params.as_dict())
        return PytorchSeq2SeqWrapper(module, stateful=stateful)

# pylint: disable=protected-access
Seq2SeqEncoder.register("gru")(_Seq2SeqWrapper(torch.nn.GRU))
Seq2SeqEncoder.register("lstm")(_Seq2SeqWrapper(torch.nn.LSTM))
Seq2SeqEncoder.register("rnn")(_Seq2SeqWrapper(torch.nn.RNN))
Seq2SeqEncoder.register("augmented_lstm")(_Seq2SeqWrapper(AugmentedLstm))
Seq2SeqEncoder.register("alternating_lstm")(_Seq2SeqWrapper(StackedAlternatingLstm))
Seq2SeqEncoder.register("stacked_bidirectional_lstm")(_Seq2SeqWrapper(StackedBidirectionalLstm))
Seq2SeqEncoder.register("bidirectional_language_model_transformer")(
        BidirectionalLanguageModelTransformer
)
Example #20
    def __call__(self, **kwargs) -> PytorchSeq2SeqWrapper:
        return self.from_params(Params(kwargs))

    def from_params(self, params: Params) -> PytorchSeq2SeqWrapper:
        if not params.pop_bool('batch_first', True):
            raise ConfigurationError(
                "Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        module = self._module_class(**params.as_dict())
        return PytorchSeq2SeqWrapper(module)


# pylint: disable=protected-access
Seq2SeqEncoder.register("gru")(_Seq2SeqWrapper(torch.nn.GRU))
Seq2SeqEncoder.register("lstm")(_Seq2SeqWrapper(torch.nn.LSTM))
Seq2SeqEncoder.register("rnn")(_Seq2SeqWrapper(torch.nn.RNN))
Seq2SeqEncoder.register("augmented_lstm")(_Seq2SeqWrapper(AugmentedLstm))
Seq2SeqEncoder.register("alternating_lstm")(
    _Seq2SeqWrapper(StackedAlternatingLstm))
if torch.cuda.is_available():
    try:
        # TODO(Mark): Remove this once we have a CPU wrapper for the kernel/switch to ATen.
        from allennlp.modules.alternating_highway_lstm import AlternatingHighwayLSTM
        Seq2SeqEncoder.register("alternating_highway_lstm_cuda")(
            _Seq2SeqWrapper(AlternatingHighwayLSTM))
    except (ModuleNotFoundError, FileNotFoundError):
        logger.debug(
            "allennlp could not register 'alternating_highway_lstm_cuda' - installation "
            "needs to be completed manually if you have pip-installed the package. "
Example #21
    def test_registry_has_builtin_seq2seq_encoders(self):

        assert Seq2SeqEncoder.by_name("gru")._module_class.__name__ == "GRU"
        assert Seq2SeqEncoder.by_name("lstm")._module_class.__name__ == "LSTM"
        assert Seq2SeqEncoder.by_name("rnn")._module_class.__name__ == "RNN"
Example #22
            The encoded sequence of shape (batch_size, sequence_length, hidden_size)
        final_states: Tuple[torch.Tensor, torch.Tensor]
            The per-layer final (state, memory) states of the LSTM, each with shape
            (num_layers, batch_size, hidden_size).
        """
        if not initial_state:
            hidden_states = [None] * len(self.lstm_layers)
        elif initial_state[0].size()[0] != len(self.lstm_layers):
            raise ConfigurationError(
                "Initial states were passed to forward() but the number of "
                "initial states does not match the number of layers.")
        else:
            hidden_states = list(
                zip(initial_state[0].split(1, 0), initial_state[1].split(1,
                                                                         0)))

        output_sequence = inputs
        final_states = []
        for i, state in enumerate(hidden_states):
            layer = getattr(self, 'layer_{}'.format(i))
            # The state is duplicated to mirror the Pytorch API for LSTMs.
            output_sequence, final_state = layer(output_sequence, state)
            final_states.append(final_state)

        final_hidden_state, final_cell_state = tuple(
            torch.cat(state_list, 0) for state_list in zip(*final_states))
        return output_sequence, (final_hidden_state, final_cell_state)


Seq2SeqEncoder.register("stacked_lstm")(_Seq2SeqWrapper(StackedLstm))
Example #23
        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = last_dim_softmax(
            scaled_similarities,
            mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps,
                               int(self._values_dim / num_heads))
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs


MultiHeadSelfAttention = Seq2SeqEncoder.register(u"multi_head_self_attention")(
    MultiHeadSelfAttention)
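
The tail of this forward pass undoes the multi-head reshaping: per-head outputs are viewed back to `(batch, heads, time, dim/heads)`, transposed, and flattened to `(batch, time, dim)`. A tiny stand-alone check of that round trip (sizes are arbitrary):

    # Hedged shape sketch of the head-merging reshape used above.
    import torch

    batch_size, num_heads, timesteps, values_dim = 2, 4, 5, 16
    x = torch.randn(batch_size * num_heads, timesteps, values_dim // num_heads)
    x = x.view(batch_size, num_heads, timesteps, values_dim // num_heads)
    x = x.transpose(1, 2).contiguous()
    x = x.view(batch_size, timesteps, values_dim)
    assert x.shape == (batch_size, timesteps, values_dim)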
Example #24
    #overrides
    def forward(self, inputs, mask):  # pylint: disable=arguments-differ
        if self._use_positional_encoding:
            output = add_positional_features(inputs)
        else:
            output = inputs
        for (attention,
             feedforward,
             feedforward_layer_norm,
             layer_norm) in izip(self._attention_layers,
                                self._feedfoward_layers,
                                self._feed_forward_layer_norm_layers,
                                self._layer_norm_layers):
            cached_input = output
            # Project output of attention encoder through a feedforward
            # network and back to the input size for the next layer.
            # shape (batch_size, timesteps, input_size)
            feedforward_output = feedforward(output)
            feedforward_output = self.dropout(feedforward_output)
            if feedforward_output.size() == cached_input.size():
                # First layer might have the wrong size for highway
                # layers, so we exclude it here.
                feedforward_output = feedforward_layer_norm(feedforward_output + cached_input)
            # shape (batch_size, sequence_length, hidden_dim)
            attention_output = attention(feedforward_output, mask)
            output = layer_norm(self.dropout(attention_output) + feedforward_output)

        return output

StackedSelfAttentionEncoder = Seq2SeqEncoder.register(u"stacked_self_attention")(StackedSelfAttentionEncoder)
Example #26
    def __call__(self, **kwargs) -> PytorchSeq2SeqWrapper:
        return self.from_params(Params(kwargs))

    # Logic requires custom from_params
    def from_params(self, params: Params) -> PytorchSeq2SeqWrapper:
        if not params.pop_bool('batch_first', True):
            raise ConfigurationError(
                "Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        module = self._module_class(**params.as_dict())
        return PytorchSeq2SeqWrapper(module, self._stateful)


# pylint: disable=protected-access
Seq2SeqEncoder.register("gru")(_Seq2SeqWrapper(torch.nn.GRU))
Seq2SeqEncoder.register("lstm")(_Seq2SeqWrapper(torch.nn.LSTM))
Seq2SeqEncoder.register("rnn")(_Seq2SeqWrapper(torch.nn.RNN))
Seq2SeqEncoder.register("stateful_gru")(_Seq2SeqWrapper(torch.nn.GRU, True))
Seq2SeqEncoder.register("stateful_lstm")(_Seq2SeqWrapper(torch.nn.LSTM, True))
Seq2SeqEncoder.register("stateful_rnn")(_Seq2SeqWrapper(torch.nn.RNN, True))
Seq2SeqEncoder.register("augmented_lstm")(_Seq2SeqWrapper(AugmentedLstm))
Seq2SeqEncoder.register("alternating_lstm")(
    _Seq2SeqWrapper(StackedAlternatingLstm))
Seq2SeqEncoder.register("stacked_bidirectional_lstm")(
    _Seq2SeqWrapper(StackedBidirectionalLstm))
if torch.cuda.is_available():
    try:
        # TODO(Mark): Remove this once we have a CPU wrapper for the kernel/switch to ATen.
        from allennlp.modules.alternating_highway_lstm import AlternatingHighwayLSTM
        Seq2SeqEncoder.register("alternating_highway_lstm_cuda")(
Example #27
            output, final_state = layer(output_sequence, state)

            output_sequence, lengths = pad_packed_sequence(output,
                                                           batch_first=True)

            # Apply layer wise dropout on each output sequence apart from the
            # first (input) and last
            if i < (self.num_layers - 1):
                output_sequence = self.layer_dropout(output_sequence)
                if self.use_residual:
                    res_proj = getattr(self, 'res_proj_{}'.format(i))
                    tmp = output_sequence
                    output_sequence = output_sequence + res_proj(prev_sequence)
                    prev_sequence = tmp

            output_sequence = pack_padded_sequence(output_sequence,
                                                   lengths,
                                                   batch_first=True)

            final_h.append(final_state[0])
            final_c.append(final_state[1])

        final_h = torch.cat(final_h, dim=0)
        final_c = torch.cat(final_c, dim=0)
        final_state_tuple = (final_h, final_c)
        return output_sequence, final_state_tuple


Seq2SeqEncoder.register("residual_bidirectional_lstm")(
    _Seq2SeqWrapper(ResidualBidirectionalLstm))