Example #1
    def forward(self, tokens, mask=None):  #pylint: disable=arguments-differ
        if mask is not None:
            tokens = tokens * mask.unsqueeze(-1).float()

        # Our input has shape `(batch_size, num_tokens, embedding_dim)`, so we sum out the `num_tokens`
        # dimension.
        summed = tokens.sum(1)

        if self._averaged:
            if mask is not None:
                lengths = get_lengths_from_binary_sequence_mask(mask)
                length_mask = (lengths > 0)

                # Set any length 0 to 1, to avoid dividing by zero.
                lengths = torch.max(lengths, lengths.new_ones(1))
            else:
                lengths = tokens.new_full((1, ), fill_value=tokens.size(1))
                length_mask = None

            summed = summed / lengths.unsqueeze(-1).float()

            if length_mask is not None:
                summed = summed * (length_mask > 0).float().unsqueeze(-1)

        return summed


BagOfEmbeddingsEncoder = Seq2VecEncoder.register("boe")(
    BagOfEmbeddingsEncoder)
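
The pooling math above is easy to sanity-check outside the class. The following is a minimal standalone sketch of the same masked sum/average in plain PyTorch; the tensors and shapes are made up for illustration and are not part of the original module.

import torch

batch_size, num_tokens, embedding_dim = 2, 4, 3
tokens = torch.randn(batch_size, num_tokens, embedding_dim)
# 1 marks a real token, 0 marks padding; the second sequence has length 2.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])

# Zero out padded positions, then sum over the num_tokens dimension.
summed = (tokens * mask.unsqueeze(-1).float()).sum(1)

# Averaged variant: divide by the sequence lengths, clamped to avoid
# dividing by zero for all-padding rows.
lengths = mask.sum(-1).clamp(min=1)
averaged = summed / lengths.unsqueeze(-1).float()

print(summed.shape, averaged.shape)  # both torch.Size([2, 3])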
Example #2
    work.
    """
    PYTORCH_MODELS = [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]

    def __init__(self, module_class: Type[torch.nn.modules.RNNBase]) -> None:
        self._module_class = module_class

    def __call__(self, **kwargs) -> PytorchSeq2VecWrapper:
        return self.from_params(Params(kwargs))

    # Logic requires custom from_params
    def from_params(self, params: Params) -> PytorchSeq2VecWrapper:
        if not params.pop('batch_first', True):
            raise ConfigurationError(
                "Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        module = self._module_class(**params.as_dict())
        return PytorchSeq2VecWrapper(module)


# pylint: disable=protected-access
Seq2VecEncoder.register("gru")(_Seq2VecWrapper(torch.nn.GRU))
Seq2VecEncoder.register("lstm")(_Seq2VecWrapper(torch.nn.LSTM))
Seq2VecEncoder.register("rnn")(_Seq2VecWrapper(torch.nn.RNN))
Seq2VecEncoder.register("augmented_lstm")(_Seq2VecWrapper(AugmentedLstm))
Seq2VecEncoder.register("alternating_lstm")(
    _Seq2VecWrapper(StackedAlternatingLstm))
Seq2VecEncoder.register("stacked_bidirectional_lstm")(
    _Seq2VecWrapper(StackedBidirectionalLstm))
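
The wrapper only builds the underlying pytorch RNN and hands it to PytorchSeq2VecWrapper. As a rough, hedged usage sketch (the PytorchSeq2VecWrapper import path and forward signature are as in AllenNLP 0.x; check your version), the registered "lstm" entry produces something equivalent to:

import torch
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper

# Roughly what _Seq2VecWrapper(torch.nn.LSTM)(input_size=100, hidden_size=200)
# returns: batch_first is forced to True before constructing the module.
lstm = torch.nn.LSTM(input_size=100, hidden_size=200, batch_first=True)
encoder = PytorchSeq2VecWrapper(lstm)

tokens = torch.randn(8, 20, 100)   # (batch_size, num_tokens, input_size)
vector = encoder(tokens, mask=None)
print(vector.shape)                # (8, 200): one vector per sequence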
Example #3
        # Our input is expected to have shape `(batch_size, num_tokens, embedding_dim)`.  The
        # convolution layers expect input of shape `(batch_size, in_channels, sequence_length)`,
        # where the conv layer `in_channels` is our `embedding_dim`.  We thus need to transpose the
        # tensor first.
        tokens = torch.transpose(tokens, 1, 2)
        # Each convolution layer returns output of size `(batch_size, num_filters, pool_length)`,
        # where `pool_length = num_tokens - ngram_size + 1`.  We then do an activation function,
        # then do max pooling over each filter for the whole input sequence.  Because our max
        # pooling is simple, we just use `torch.max`.  The resultant tensor has shape
        # `(batch_size, num_conv_layers * num_filters)`, which then gets projected using the
        # projection layer, if requested.

        filter_outputs = []
        for i in range(len(self._convolution_layers)):
            convolution_layer = getattr(self, 'conv_layer_{}'.format(i))
            filter_outputs.append(
                    self._activation(convolution_layer(tokens)).max(dim=2)[0]
            )

        # Now we have a list of `num_conv_layers` tensors of shape `(batch_size, num_filters)`.
        # Concatenating them gives us a tensor of shape `(batch_size, num_filters * num_conv_layers)`.
        maxpool_output = torch.cat(filter_outputs, dim=1) if len(filter_outputs) > 1 else filter_outputs[0]

        if self.projection_layer:
            result = self.projection_layer(maxpool_output)
        else:
            result = maxpool_output
        return result

CnnEncoder = Seq2VecEncoder.register("cnn")(CnnEncoder)
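
The shape bookkeeping in the comments can be reproduced with plain torch modules. A minimal sketch (the filter sizes and dimensions are illustrative, not taken from the original encoder's configuration):

import torch

batch_size, num_tokens, embedding_dim = 2, 7, 16
num_filters, ngram_sizes = 4, (2, 3)

tokens = torch.randn(batch_size, num_tokens, embedding_dim)
# Conv1d expects (batch_size, in_channels, sequence_length).
tokens = torch.transpose(tokens, 1, 2)

filter_outputs = []
for ngram_size in ngram_sizes:
    conv = torch.nn.Conv1d(embedding_dim, num_filters, kernel_size=ngram_size)
    # (batch_size, num_filters, num_tokens - ngram_size + 1), then max over time.
    filter_outputs.append(torch.relu(conv(tokens)).max(dim=2)[0])

maxpool_output = torch.cat(filter_outputs, dim=1)
print(maxpool_output.shape)  # (batch_size, num_filters * len(ngram_sizes)) = (2, 8)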
Example #4
    When you instantiate a ``_Wrapper`` object, you give it an ``RNNBase`` subclass, which we save
    to ``self``.  Then when called (as if we were instantiating an actual encoder with
    ``Encoder(**params)``, or with ``Encoder.from_params(params)``), we pass those parameters
    through to the ``RNNBase`` constructor, then pass the instantiated pytorch RNN to the
    ``PytorchSeq2VecWrapper``.  This lets us use this class in the registry and have everything just
    work.
    """
    PYTORCH_MODELS = [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]

    def __init__(self, module_class: Type[torch.nn.modules.RNNBase]) -> None:
        self._module_class = module_class

    def __call__(self, **kwargs) -> PytorchSeq2VecWrapper:
        return self.from_params(Params(kwargs))

    # Logic requires custom from_params
    def from_params(self, params: Params) -> PytorchSeq2VecWrapper:
        if not params.pop('batch_first', True):
            raise ConfigurationError("Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        module = self._module_class(**params.as_dict())
        return PytorchSeq2VecWrapper(module)

# pylint: disable=protected-access
Seq2VecEncoder.register("gru")(_Seq2VecWrapper(torch.nn.GRU))
Seq2VecEncoder.register("lstm")(_Seq2VecWrapper(torch.nn.LSTM))
Seq2VecEncoder.register("rnn")(_Seq2VecWrapper(torch.nn.RNN))
Seq2VecEncoder.register("augmented_lstm")(_Seq2VecWrapper(AugmentedLstm))
Seq2VecEncoder.register("alternating_lstm")(_Seq2VecWrapper(StackedAlternatingLstm))
Example #5
    through to the ``RNNBase`` constructor, then pass the instantiated pytorch RNN to the
    ``PytorchSeq2VecWrapper``.  This lets us use this class in the registry and have everything just
    work.
    """
    PYTORCH_MODELS = [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]

    def __init__(self, module_class):
        self._module_class = module_class

    def __call__(self, **kwargs):
        return self.from_params(Params(kwargs))

    # Logic requires custom from_params
    def from_params(self, params):
        if not params.pop('batch_first', True):
            raise ConfigurationError(
                "Our encoder semantics assumes batch is always first!")
        if self._module_class in self.PYTORCH_MODELS:
            params['batch_first'] = True
        module = self._module_class(**params.as_dict())
        return PytorchSeq2VecWrapper(module)


# pylint: disable=protected-access
Seq2VecEncoder.register(u"gru")(_Seq2VecWrapper(torch.nn.GRU))
Seq2VecEncoder.register(u"lstm")(_Seq2VecWrapper(torch.nn.LSTM))
Seq2VecEncoder.register(u"rnn")(_Seq2VecWrapper(torch.nn.RNN))
Seq2VecEncoder.register(u"augmented_lstm")(_Seq2VecWrapper(AugmentedLstm))
Seq2VecEncoder.register(u"alternating_lstm")(
    _Seq2VecWrapper(StackedAlternatingLstm))
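
Because instances of the wrapper (rather than classes) are registered, name-based construction from a config still routes through the custom `from_params` above. A hedged sketch of that path (the exact `Registrable` dispatch and import location depend on the AllenNLP version):

from allennlp.common import Params
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder

# "type" picks the registered wrapper; the remaining keys are forwarded to
# torch.nn.GRU with batch_first forced to True.
encoder = Seq2VecEncoder.from_params(Params({
    "type": "gru",
    "input_size": 100,
    "hidden_size": 200,
    "num_layers": 2,
}))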