Example #1
    def __init__(self,
                 encoder_output_dim: int,
                 action_embedding_dim: int,
                 attention_function: SimilarityFunction,
                 dropout: float = 0.0,
                 use_coverage: bool = False) -> None:
        super(NlvrDecoderStep, self).__init__()
        self._input_attention = LegacyAttention(attention_function)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        output_dim = encoder_output_dim
        input_dim = output_dim
        # Our decoder input will be the concatenation of the decoder hidden state and the previous
        # action embedding, and we'll project that down to the decoder's `input_dim`, which we
        # arbitrarily set to be the same as `output_dim`.
        self._input_projection_layer = Linear(
            output_dim + action_embedding_dim, input_dim)
        # Before making a prediction, we'll compute an attention over the encoder outputs given
        # our updated hidden state. Then we concatenate the attended input with the decoder state
        # and project to `action_embedding_dim` to make a prediction.
        self._output_projection_layer = Linear(output_dim + encoder_output_dim,
                                               action_embedding_dim)
        if use_coverage:
            # This is a learned multiplicative factor applied to the embeddings of yet-to-be-produced
            # actions before they are added to the predicted embedding, biasing predictions toward
            # unused actions.
            self._checklist_embedding_multiplier = Parameter(
                torch.FloatTensor([1.0]))
        # TODO(pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(input_dim, output_dim)
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
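The comments above fully determine the data flow, so a self-contained sketch may help: project the concatenated [hidden state; previous action embedding] down to the LSTM input, step the cell, attend over the encoder outputs with the new hidden state, then project the concatenated [hidden state; attended input] to an action embedding. All names and dimensions below are illustrative assumptions, and a plain dot-product attention stands in for LegacyAttention:

import torch
from torch.nn import Linear, LSTMCell

# Hypothetical dimensions, for illustration only.
batch, seq_len, output_dim, action_dim = 2, 5, 4, 3
input_projection = Linear(output_dim + action_dim, output_dim)
decoder_cell = LSTMCell(output_dim, output_dim)
output_projection = Linear(output_dim + output_dim, action_dim)

hidden = torch.zeros(batch, output_dim)
memory = torch.zeros(batch, output_dim)
previous_action = torch.zeros(batch, action_dim)
encoder_outputs = torch.randn(batch, seq_len, output_dim)

# Project [hidden; previous action embedding] down to the LSTM's input size.
decoder_input = input_projection(torch.cat([hidden, previous_action], dim=-1))
hidden, memory = decoder_cell(decoder_input, (hidden, memory))
# Attend over the encoder outputs with the updated hidden state
# (plain dot-product attention here, not the actual LegacyAttention).
weights = torch.softmax(
    torch.bmm(encoder_outputs, hidden.unsqueeze(-1)).squeeze(-1), dim=-1)
attended = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
# Concatenate with the decoder state and project to action_embedding_dim.
predicted_action_embedding = output_projection(
    torch.cat([hidden, attended], dim=-1))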
    def test_non_normalized_attention_works(self):
        attention = LegacyAttention(normalize=False)
        sentence_tensor = torch.FloatTensor([[[-1, 0, 4], [1, 1, 1],
                                              [-1, 0, 4], [-1, 0, -1]]])
        query_tensor = torch.FloatTensor([[.1, .8, .5]])
        result = attention(query_tensor, sentence_tensor).data.numpy()
        assert_almost_equal(result, [[1.9, 1.4, 1.9, -.6]])
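With `normalize=False`, the expected values above are simply the raw dot products between the query and each sentence vector, which is easy to verify independently of LegacyAttention:

import torch

query = torch.FloatTensor([[.1, .8, .5]])
sentence = torch.FloatTensor([[[-1, 0, 4], [1, 1, 1],
                               [-1, 0, 4], [-1, 0, -1]]])
# (1, 4, 3) x (1, 3, 1) -> (1, 4, 1): one dot product per sentence vector.
scores = torch.bmm(sentence, query.unsqueeze(-1)).squeeze(-1)
print(scores)  # tensor([[ 1.9000,  1.4000,  1.9000, -0.6000]])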
    def test_batched_masked(self):
        attention = LegacyAttention()

        # Testing general masked batched case.
        vector = torch.FloatTensor([[0.3, 0.1, 0.5], [0.3, 0.1, 0.5]])
        matrix = torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2],
                                     [0.5, 0.3, 0.2]],
                                    [[0.6, 0.8, 0.1], [0.15, 0.5, 0.2],
                                     [0.5, 0.3, 0.2]]])
        mask = torch.FloatTensor([[1.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
        result = attention(vector, matrix, mask).data.numpy()
        assert_almost_equal(
            result,
            numpy.array([[0.52871835, 0.47128162, 0.0],
                         [0.50749944, 0.0, 0.49250056]]))

        # Test the case where one input vector is all 0s (first row) and one
        # mask is all 0s (second row).
        vector = torch.FloatTensor([[0.0, 0.0, 0.0], [0.3, 0.1, 0.5]])
        matrix = torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2],
                                     [0.5, 0.3, 0.2]],
                                    [[0.6, 0.8, 0.1], [0.15, 0.5, 0.2],
                                     [0.5, 0.3, 0.2]]])
        mask = torch.FloatTensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
        result = attention(vector, matrix, mask).data.numpy()
        assert_almost_equal(result,
                            numpy.array([[0.5, 0.5, 0.0], [0.0, 0.0, 0.0]]))
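The expected values here are consistent with a masked softmax that zeroes out masked positions and renormalizes, and that collapses to all zeros when the entire mask is zero. The scores 0.31, 0.195, and 0.28 are the dot products of the query vector with the three matrix rows; the sketch below is an assumption about the behavior, not AllenNLP's actual implementation:

import torch

def masked_softmax_sketch(scores, mask, eps=1e-13):
    # Zero out masked positions after the softmax, then renormalize.
    probs = torch.softmax(scores, dim=-1) * mask
    return probs / (probs.sum(dim=-1, keepdim=True) + eps)

scores = torch.FloatTensor([[0.31, 0.195, 0.28]])
print(masked_softmax_sketch(scores, torch.FloatTensor([[1., 1., 0.]])))
# tensor([[0.5287, 0.4713, 0.0000]])
print(masked_softmax_sketch(scores, torch.FloatTensor([[0., 0., 0.]])))
# tensor([[0., 0., 0.]])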
Example #4
    def test_masked(self):
        attention = LegacyAttention()
        # Testing general masked non-batched case.
        vector = torch.FloatTensor([[0.3, 0.1, 0.5]])
        matrix = torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2], [0.1, 0.4, 0.3]]])
        mask = torch.FloatTensor([[1.0, 0.0, 1.0]])
        result = attention(vector, matrix, mask).data.numpy()
        assert_almost_equal(result, numpy.array([[0.52248482, 0.0, 0.47751518]]))
    def test_can_build_from_params(self):
        params = Params({
            'similarity_function': {
                'type': 'cosine'
            },
            'normalize': False
        })
        attention = LegacyAttention.from_params(params)
        # pylint: disable=protected-access
        assert attention._similarity_function.__class__.__name__ == 'CosineSimilarity'
        assert attention._normalize is False
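For comparison, the `Params`-driven construction above should be equivalent to building the module directly, presumably along these lines (the import paths match the AllenNLP release this code appears to come from and may differ in other versions):

from allennlp.modules.attention import LegacyAttention
from allennlp.modules.similarity_functions import CosineSimilarity

attention = LegacyAttention(similarity_function=CosineSimilarity(),
                            normalize=False)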
Example #6
    def test_batched_no_mask(self):
        attention = LegacyAttention()

        # Testing general batched case.
        vector = torch.FloatTensor([[0.3, 0.1, 0.5], [0.3, 0.1, 0.5]])
        matrix = torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]],
                                    [[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]]])

        result = attention(vector, matrix).data.numpy()
        assert_almost_equal(
            result,
            numpy.array([[0.52871835, 0.47128162], [0.52871835, 0.47128162]]))
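Without a mask, the expected weights are just a softmax over the dot-product scores, so the batched computation can be checked end to end:

import torch

vector = torch.FloatTensor([[0.3, 0.1, 0.5], [0.3, 0.1, 0.5]])
matrix = torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]],
                            [[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]]])
scores = torch.bmm(matrix, vector.unsqueeze(-1)).squeeze(-1)  # shape (2, 2)
print(torch.softmax(scores, dim=-1))
# tensor([[0.5287, 0.4713],
#         [0.5287, 0.4713]])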
Example #8
    def test_no_mask(self):
        attention = LegacyAttention()

        # Testing general non-batched case.
        vector = torch.FloatTensor([[0.3, 0.1, 0.5]])
        matrix = torch.FloatTensor([[[0.6, 0.8, 0.1], [0.15, 0.5, 0.2]]])

        result = attention(vector, matrix).data.numpy()
        assert_almost_equal(result, numpy.array([[0.52871835, 0.47128162]]))

        # Testing non-batched case where inputs are all 0s.
        vector = torch.FloatTensor([[0, 0, 0]])
        matrix = torch.FloatTensor([[[0, 0, 0], [0, 0, 0]]])

        result = attention(vector, matrix).data.numpy()
        assert_almost_equal(result, numpy.array([[0.5, 0.5]]))
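The all-zeros case falls out of the same computation: every dot product is 0, and a softmax over identical scores is uniform, hence [0.5, 0.5]:

import torch
print(torch.softmax(torch.zeros(1, 2), dim=-1))  # tensor([[0.5000, 0.5000]])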