def forward(self,  # pylint: disable=arguments-differ
            matrix_1: torch.Tensor,
            matrix_2: torch.Tensor) -> torch.Tensor:
    """
    Combine ``matrix_1`` and ``matrix_2`` pairwise and return activated scores.

    ``matrix_1`` gets a singleton axis inserted at position 2 and ``matrix_2``
    one at position 1 so the two can broadcast against each other inside
    ``util.combine_tensors_and_multiply``. The combined result is offset by
    ``self._bias`` and passed through ``self._activation``.

    NOTE(review): assumes both inputs are ``(batch, rows, dim)``, which would
    give a ``(batch, rows_1, rows_2)`` result -- confirm against callers.
    """
    broadcastable = [matrix_1.unsqueeze(2), matrix_2.unsqueeze(1)]
    scores = util.combine_tensors_and_multiply(self._combination,
                                               broadcastable,
                                               self._weight_vector)
    return self._activation(scores + self._bias)
示例#2
0
    def test_combine_tensors_and_multiply_with_same_batch_size_and_embedding_dim(self):
        """Check that the weight vector is applied along the last dimension.

        Every dimension in these inputs has length 2, so this guards against an
        edge case where multiplying along the wrong axis could go unnoticed by
        shape alone; only the values below are correct for the last axis.
        """
        # Single tensor, combination "x": each length-2 row is dotted with [4, 5].
        single = [torch.Tensor([[[5, 5], [4, 4]], [[2, 3], [1, 1]]])]  # (2, 2, 2)
        single_weight = torch.Tensor([4, 5])  # (2,)
        expected_single = [[20 + 25, 16 + 20], [8 + 15, 4 + 5]]
        assert_almost_equal(util.combine_tensors_and_multiply("x", single, single_weight),
                            expected_single)

        # Two tensors, combination "x*y": per the expected values, the single row of
        # each batch entry of `right` multiplies every row of the matching `left` entry.
        left = torch.Tensor([[[5, 5], [2, 2]], [[4, 4], [3, 3]]])
        right = torch.Tensor([[[2, 3]], [[1, 1]]])
        product_weight = torch.Tensor([4, 5])
        expected_product = [[5 * 2 * 4 + 5 * 3 * 5, 2 * 2 * 4 + 2 * 3 * 5],
                            [4 * 1 * 4 + 4 * 1 * 5, 3 * 1 * 4 + 3 * 1 * 5]]
        assert_almost_equal(util.combine_tensors_and_multiply("x*y", [left, right], product_weight),
                            expected_product)
示例#3
0
    def test_combine_tensors_and_multiply_with_same_batch_size_and_embedding_dim(self):
        # Every dimension below has length 2, which is the situation where applying
        # the weight vector along the wrong axis is easy to miss. The expected values
        # only hold when the weights multiply the last dimension.
        cases = [
            (
                "x",
                [torch.Tensor([[[5, 5], [4, 4]], [[2, 3], [1, 1]]])],  # (2, 2, 2)
                [[20 + 25, 16 + 20], [8 + 15, 4 + 5]],
            ),
            (
                "x*y",
                [torch.Tensor([[[5, 5], [2, 2]], [[4, 4], [3, 3]]]),
                 torch.Tensor([[[2, 3]], [[1, 1]]])],
                [[5 * 2 * 4 + 5 * 3 * 5, 2 * 2 * 4 + 2 * 3 * 5],
                 [4 * 1 * 4 + 4 * 1 * 5, 3 * 1 * 4 + 3 * 1 * 5]],
            ),
        ]
        for combination, operands, expected in cases:
            weight = torch.Tensor([4, 5])  # (2,)
            actual = util.combine_tensors_and_multiply(combination, operands, weight)
            assert_almost_equal(actual, expected)
示例#4
0
    def test_combine_tensors_and_multiply_with_batch_size_one(self):
        """Pairwise combination of two batch-size-one sequences has shape (1, len_1, len_2)."""
        seq_len_1 = 10
        seq_len_2 = 5
        embedding_dim = 8

        combination = "x,y,x*y"
        t1 = torch.randn(1, seq_len_1, embedding_dim)
        t2 = torch.randn(1, seq_len_2, embedding_dim)
        combined_dim = util.get_combined_dim(combination,
                                             [embedding_dim, embedding_dim])
        # torch.Tensor(n) allocates an uninitialised tensor whose contents are
        # arbitrary; use a factory function so the weights are well defined.
        weight = torch.randn(combined_dim)

        # Unsqueeze so the two sequences broadcast against each other pairwise.
        result = util.combine_tensors_and_multiply(
            combination, [t1.unsqueeze(2), t2.unsqueeze(1)], weight)

        # Sizes are integers, so compare them exactly rather than approximately.
        assert tuple(result.size()) == (1, seq_len_1, seq_len_2)
示例#5
0
    def test_combine_tensors_and_multiply_with_batch_size_one_and_seq_len_one(self):
        """A length-one second sequence must still yield a (1, len_1, 1) result.

        The size-1 trailing axis has to survive the combination rather than being
        dropped.
        """
        seq_len_1 = 10
        seq_len_2 = 1
        embedding_dim = 8

        combination = "x,y,x*y"
        t1 = torch.randn(1, seq_len_1, embedding_dim)
        t2 = torch.randn(1, seq_len_2, embedding_dim)
        combined_dim = util.get_combined_dim(combination, [embedding_dim, embedding_dim])
        # torch.Tensor(n) allocates an uninitialised tensor whose contents are
        # arbitrary; use a factory function so the weights are well defined.
        weight = torch.randn(combined_dim)

        result = util.combine_tensors_and_multiply(combination, [t1.unsqueeze(2), t2.unsqueeze(1)], weight)

        # Sizes are integers, so compare them exactly rather than approximately.
        assert tuple(result.size()) == (1, seq_len_1, seq_len_2)
示例#6
0
    def test_combine_tensors_and_multiply(self):
        """Exercise each combination operator on the vectors x = [2, 3] and y = [5, 5].

        Each expected value is the element-wise result of the combination dotted
        with the weights. Malformed combination strings must raise
        ``ConfigurationError``.
        """
        tensors = [torch.Tensor([[[2, 3]]]), torch.Tensor([[[5, 5]]])]
        weight = torch.Tensor([4, 5])
        # "x,y" concatenates x and y, so it needs one weight per concatenated element.
        weight2 = torch.Tensor([4, 5, 4, 5])

        # (combination, weights, expected, extra keyword arguments for the comparison)
        cases = [
            ("x", weight, [[8 + 15]], {}),
            ("y", weight, [[20 + 25]], {}),
            ("x,y", weight2, [[8 + 20 + 15 + 25]], {}),
            ("x-y", weight, [[-3 * 4 + -2 * 5]], {}),
            ("y-x", weight, [[3 * 4 + 2 * 5]], {}),
            ("y+x", weight, [[7 * 4 + 8 * 5]], {}),
            ("y*x", weight, [[10 * 4 + 15 * 5]], {}),
            # Quotients such as 5/3 are not exactly representable; compare loosely.
            ("y/x", weight, [[(5 / 2) * 4 + (5 / 3) * 5]], {"decimal": 4}),
            ("x/y", weight, [[(2 / 5) * 4 + (3 / 5) * 5]], {"decimal": 4}),
        ]
        for combination, weights, expected, comparison_kwargs in cases:
            assert_almost_equal(util.combine_tensors_and_multiply(combination, tensors, weights),
                                expected, **comparison_kwargs)

        # Both of these combination strings must be rejected.
        for bad_combination in ("x+y+y", "x%y"):
            with pytest.raises(ConfigurationError):
                util.combine_tensors_and_multiply(bad_combination, tensors, weight)
 def forward(self, matrix_1: torch.Tensor,
             matrix_2: torch.Tensor) -> torch.Tensor:
     """
     Combine ``matrix_1`` and ``matrix_2`` pairwise, add the bias, and activate.

     A singleton axis is inserted at position 2 of ``matrix_1`` and at position 1
     of ``matrix_2`` before both are handed to ``util.combine_tensors_and_multiply``
     with ``self._combination`` and ``self._weight_vector``.
     NOTE(review): presumably inputs are ``(batch, rows, dim)`` -- confirm.
     """
     left = matrix_1.unsqueeze(2)
     right = matrix_2.unsqueeze(1)
     logits = util.combine_tensors_and_multiply(self._combination, [left, right], self._weight_vector)
     shifted = logits + self._bias
     return self._activation(shifted)
 def _forward_internal(self, vector: torch.Tensor,
                       matrix: torch.Tensor) -> torch.Tensor:
     """
     Combine ``vector`` with every entry of ``matrix`` and return activated scores.

     ``vector`` is given a singleton axis at position 1 before the combination.
     Axis 1 of the combined result is then squeezed away before the bias and
     activation are applied.
     NOTE(review): assumes ``vector`` is ``(batch, dim)`` and ``matrix`` is
     ``(batch, rows, dim)`` -- confirm against callers.
     """
     operands = [vector.unsqueeze(1), matrix]
     combined = util.combine_tensors_and_multiply(self._combination, operands, self._weight_vector)
     # Presumably removes the singleton axis introduced by unsqueeze(1) above.
     squeezed = combined.squeeze(1)
     return self._activation(squeezed + self._bias)
示例#9
0
 def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
     """
     Score ``vector`` against ``matrix`` via ``util.combine_tensors_and_multiply``.

     The result has axis 1 squeezed out, is offset by ``self._bias``, and is
     passed through ``self._activation``.
     NOTE(review): shapes are not visible here; presumably ``vector`` is
     ``(batch, dim)`` and ``matrix`` is ``(batch, rows, dim)`` -- confirm.
     """
     expanded_vector = vector.unsqueeze(1)
     return self._activation(
         util.combine_tensors_and_multiply(self._combination,
                                           [expanded_vector, matrix],
                                           self._weight_vector).squeeze(1)
         + self._bias
     )
示例#10
0
    def test_combine_tensors_and_multiply(self):
        # x = [2, 3] and y = [5, 5]; every expected value below is the element-wise
        # combination dotted with the weights.
        x_and_y = [torch.Tensor([[[2, 3]]]), torch.Tensor([[[5, 5]]])]
        two_weights = torch.Tensor([4, 5])

        def check(combination, expected, weights=None, **comparison_kwargs):
            # Default to the two-element weights used by most cases.
            chosen = two_weights if weights is None else weights
            combined = util.combine_tensors_and_multiply(combination, x_and_y, chosen)
            assert_almost_equal(combined, expected, **comparison_kwargs)

        check("x", [[8 + 15]])
        check("y", [[20 + 25]])
        # "x,y" concatenates the two vectors, hence four weights.
        check("x,y", [[8 + 20 + 15 + 25]], weights=torch.Tensor([4, 5, 4, 5]))
        check("x-y", [[-3 * 4 + -2 * 5]])
        check("y-x", [[3 * 4 + 2 * 5]])
        check("y+x", [[7 * 4 + 8 * 5]])
        check("y*x", [[10 * 4 + 15 * 5]])
        # Inexact quotients are compared at a looser tolerance.
        check("y/x", [[(5 / 2) * 4 + (5 / 3) * 5]], decimal=4)
        check("x/y", [[(2 / 5) * 4 + (3 / 5) * 5]], decimal=4)

        with pytest.raises(ConfigurationError):
            util.combine_tensors_and_multiply("x+y+y", x_and_y, two_weights)

        with pytest.raises(ConfigurationError):
            util.combine_tensors_and_multiply("x%y", x_and_y, two_weights)
示例#11
0
文件: tmp.py 项目: PYART0/PyART-demo
 def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
     # Presumably intended to combine matrix_1 and matrix_2 via
     # util.combine_tensors_and_multiply; this looks like an unfinished stub -- confirm.
     # NOTE(review): the lines visible here contain no return statement even though the
     # annotation promises a torch.Tensor; the definition may continue or be truncated.
     # NOTE(review): combine_tensors_and_multiply is called with no arguments, whereas
     # other calls to it pass (combination, list_of_tensors, weight_vector). Confirm
     # whether the empty call is an intentional placeholder.
     combined_tensors = util.combine_tensors_and_multiply()
     # NOTE(review): reveal_type is not defined in this snippet. It is normally a
     # static type-checker directive; confirm it is imported (typing.reveal_type
     # exists at runtime on Python 3.11+) or remove it before running.
     reveal_type(matrix_1)