def forward(self,  # pylint: disable=arguments-differ
            matrix_1: torch.Tensor,
            matrix_2: torch.Tensor) -> torch.Tensor:
    """
    Produce activated pairwise scores between ``matrix_1`` and ``matrix_2``.

    ``matrix_1`` gains a singleton axis at dim 2 and ``matrix_2`` one at dim 1, so the
    two broadcast against each other inside ``util.combine_tensors_and_multiply``.
    The combined features are projected with ``self._weight_vector``, shifted by
    ``self._bias`` and passed through ``self._activation``.

    NOTE(review): inputs are presumably (batch, rows_1, dim) and (batch, rows_2, dim),
    giving a (batch, rows_1, rows_2) result — confirm against callers.
    """
    # Broadcast axes: every row of matrix_1 is paired with every row of matrix_2.
    left = matrix_1.unsqueeze(2)
    right = matrix_2.unsqueeze(1)
    scores = util.combine_tensors_and_multiply(self._combination,
                                               [left, right],
                                               self._weight_vector)
    return self._activation(scores + self._bias)
def test_combine_tensors_and_multiply_with_same_batch_size_and_embedding_dim(self):
    # Every dimension has length 2 here, so a multiplication with the weight vector
    # along the wrong axis would not show up in the output shape. The explicit
    # expected values pin it to the last dimension.

    # Single tensor, combination "x": each length-2 row is dotted with [4, 5].
    single = [torch.Tensor([[[5, 5], [4, 4]], [[2, 3], [1, 1]]])]  # (2, 2, 2)
    expected_single = [[20 + 25, 16 + 20], [8 + 15, 4 + 5]]
    assert_almost_equal(
        util.combine_tensors_and_multiply("x", single, torch.Tensor([4, 5])),  # weight: (2,)
        expected_single)

    # Two tensors, combination "x*y": y's singleton middle axis broadcasts over x.
    pair = [torch.Tensor([[[5, 5], [2, 2]], [[4, 4], [3, 3]]]),
            torch.Tensor([[[2, 3]], [[1, 1]]])]
    expected_pair = [[5 * 2 * 4 + 5 * 3 * 5, 2 * 2 * 4 + 2 * 3 * 5],
                     [4 * 1 * 4 + 4 * 1 * 5, 3 * 1 * 4 + 3 * 1 * 5]]
    assert_almost_equal(
        util.combine_tensors_and_multiply("x*y", pair, torch.Tensor([4, 5])),
        expected_pair)
def test_combine_tensors_and_multiply_with_same_batch_size_and_embedding_dim(self):
    """
    Edge case where batch size, sequence length and embedding size are all 2.

    Only the numeric results reveal whether the weight vector is applied along the
    last dimension, which is what these cases check.
    """
    cases = [
        # (combination, input tensors, expected output)
        ("x",
         [torch.Tensor([[[5, 5], [4, 4]], [[2, 3], [1, 1]]])],  # (2, 2, 2)
         [[20 + 25, 16 + 20], [8 + 15, 4 + 5]]),
        ("x*y",
         [torch.Tensor([[[5, 5], [2, 2]], [[4, 4], [3, 3]]]),
          torch.Tensor([[[2, 3]], [[1, 1]]])],
         [[5 * 2 * 4 + 5 * 3 * 5, 2 * 2 * 4 + 2 * 3 * 5],
          [4 * 1 * 4 + 4 * 1 * 5, 3 * 1 * 4 + 3 * 1 * 5]]),
    ]
    for combination, tensors, expected in cases:
        weight = torch.Tensor([4, 5])  # (2,)
        assert_almost_equal(util.combine_tensors_and_multiply(combination, tensors, weight),
                            expected)
def test_combine_tensors_and_multiply_with_batch_size_one(self):
    """Check the output shape of a pairwise combination when the batch size is one."""
    seq_len_1 = 10
    seq_len_2 = 5
    embedding_dim = 8
    combination = "x,y,x*y"
    t1 = torch.randn(1, seq_len_1, embedding_dim)
    t2 = torch.randn(1, seq_len_2, embedding_dim)
    combined_dim = util.get_combined_dim(combination, [embedding_dim, embedding_dim])
    # torch.Tensor(n) returns uninitialised memory (possibly NaN/inf); use real values.
    weight = torch.randn(combined_dim)
    result = util.combine_tensors_and_multiply(
        combination, [t1.unsqueeze(2), t2.unsqueeze(1)], weight)
    # Shapes are integers, so compare them exactly rather than "almost".
    assert tuple(result.size()) == (1, seq_len_1, seq_len_2)
def test_combine_tensors_and_multiply_with_batch_size_one_and_seq_len_one(self):
    """Check the output shape when both the batch size and the second length are one."""
    seq_len_1 = 10
    seq_len_2 = 1
    embedding_dim = 8
    combination = "x,y,x*y"
    t1 = torch.randn(1, seq_len_1, embedding_dim)
    t2 = torch.randn(1, seq_len_2, embedding_dim)
    combined_dim = util.get_combined_dim(combination, [embedding_dim, embedding_dim])
    # torch.Tensor(n) returns uninitialised memory (possibly NaN/inf); use real values.
    weight = torch.randn(combined_dim)
    result = util.combine_tensors_and_multiply(
        combination, [t1.unsqueeze(2), t2.unsqueeze(1)], weight)
    # Shapes are integers, so compare them exactly; this also catches a size-1 axis
    # being squeezed away.
    assert tuple(result.size()) == (1, seq_len_1, seq_len_2)
def test_combine_tensors_and_multiply(self):
    """Exercise each supported combination form on a single (1, 1, 2) tensor pair."""
    tensors = [torch.Tensor([[[2, 3]]]), torch.Tensor([[[5, 5]]])]
    weight = torch.Tensor([4, 5])
    # "x,y" concatenates the two tensors, so it needs a weight twice as long.
    weight2 = torch.Tensor([4, 5, 4, 5])
    loose = {"decimal": 4}  # division results are only checked to 4 decimals

    # (combination, weights, expected, extra keyword arguments for the comparison)
    cases = [
        ("x", weight, [[8 + 15]], {}),
        ("y", weight, [[20 + 25]], {}),
        ("x,y", weight2, [[8 + 20 + 15 + 25]], {}),
        ("x-y", weight, [[-3 * 4 + -2 * 5]], {}),
        ("y-x", weight, [[3 * 4 + 2 * 5]], {}),
        ("y+x", weight, [[7 * 4 + 8 * 5]], {}),
        ("y*x", weight, [[10 * 4 + 15 * 5]], {}),
        ("y/x", weight, [[(5 / 2) * 4 + (5 / 3) * 5]], loose),
        ("x/y", weight, [[(2 / 5) * 4 + (3 / 5) * 5]], loose),
    ]
    for combination, weights, expected, extra in cases:
        actual = util.combine_tensors_and_multiply(combination, tensors, weights)
        assert_almost_equal(actual, expected, **extra)

    # Malformed combinations must be rejected with a ConfigurationError.
    for bad_combination in ("x+y+y", "x%y"):
        with pytest.raises(ConfigurationError):
            util.combine_tensors_and_multiply(bad_combination, tensors, weight)
def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
    """
    Return ``activation(combine(matrix_1, matrix_2) . weight + bias)`` for every pairing.

    Singleton axes are inserted (dim 2 of ``matrix_1``, dim 1 of ``matrix_2``) so the two
    inputs broadcast against each other in ``util.combine_tensors_and_multiply``.

    NOTE(review): presumably (batch, rows_1, dim) and (batch, rows_2, dim) inputs giving
    a (batch, rows_1, rows_2) output — confirm against callers.
    """
    expanded = [matrix_1.unsqueeze(2), matrix_2.unsqueeze(1)]
    projected = util.combine_tensors_and_multiply(self._combination, expanded, self._weight_vector)
    shifted = projected + self._bias
    return self._activation(shifted)
def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    """
    Score ``vector`` against ``matrix`` and apply bias and activation.

    ``vector`` gets a singleton axis at dim 1 before being combined with ``matrix`` by
    ``util.combine_tensors_and_multiply``. That axis is squeezed back out of the result
    before ``self._bias`` is added and ``self._activation`` applied.

    NOTE(review): presumably ``vector`` is (batch, dim) and ``matrix`` is
    (batch, rows, dim) — confirm against callers.
    """
    pair = [vector.unsqueeze(1), matrix]
    raw = util.combine_tensors_and_multiply(self._combination, pair, self._weight_vector)
    return self._activation(raw.squeeze(1) + self._bias)
def _forward_internal(self, vector: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    """
    Combine ``vector`` with each row of ``matrix``, project, bias and activate.

    NOTE(review): shapes presumed to be (batch, dim) for ``vector`` and
    (batch, rows, dim) for ``matrix`` — confirm.
    """
    # Add a broadcast axis so the single vector lines up with the matrix rows.
    lifted_vector = vector.unsqueeze(1)
    weighted = util.combine_tensors_and_multiply(self._combination,
                                                 [lifted_vector, matrix],
                                                 self._weight_vector)
    # Drop the broadcast axis again before the bias is added.
    flattened = weighted.squeeze(1)
    return self._activation(flattened + self._bias)
def test_combine_tensors_and_multiply(self):
    # x = [2, 3] and y = [5, 5]; most cases use the weight [4, 5].
    x_and_y = [torch.Tensor([[[2, 3]]]), torch.Tensor([[[5, 5]]])]
    w = torch.Tensor([4, 5])

    def check(combination, expected, weights=w, **kwargs):
        # Run one combination and compare against the hand-computed expectation.
        result = util.combine_tensors_and_multiply(combination, x_and_y, weights)
        assert_almost_equal(result, expected, **kwargs)

    check("x", [[8 + 15]])
    check("y", [[20 + 25]])
    # Concatenation doubles the feature size, hence the longer weight.
    check("x,y", [[8 + 20 + 15 + 25]], weights=torch.Tensor([4, 5, 4, 5]))
    check("x-y", [[-3 * 4 + -2 * 5]])
    check("y-x", [[3 * 4 + 2 * 5]])
    check("y+x", [[7 * 4 + 8 * 5]])
    check("y*x", [[10 * 4 + 15 * 5]])
    check("y/x", [[(5 / 2) * 4 + (5 / 3) * 5]], decimal=4)
    check("x/y", [[(2 / 5) * 4 + (3 / 5) * 5]], decimal=4)

    # Malformed combination strings must raise ConfigurationError.
    with pytest.raises(ConfigurationError):
        util.combine_tensors_and_multiply("x+y+y", x_and_y, w)
    with pytest.raises(ConfigurationError):
        util.combine_tensors_and_multiply("x%y", x_and_y, w)
def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tensor:
    """
    Return activated pairwise scores between ``matrix_1`` and ``matrix_2``.

    ``matrix_1`` is unsqueezed at dim 2 and ``matrix_2`` at dim 1 so that they broadcast
    against each other in ``util.combine_tensors_and_multiply``. That call combines them
    according to ``self._combination`` and projects with ``self._weight_vector``. The
    result is shifted by ``self._bias`` and passed through ``self._activation``.

    NOTE(review): presumably (batch, rows_1, dim) and (batch, rows_2, dim) inputs giving
    a (batch, rows_1, rows_2) output — confirm against callers.
    """
    combined_tensors = util.combine_tensors_and_multiply(
        self._combination,
        [matrix_1.unsqueeze(2), matrix_2.unsqueeze(1)],
        self._weight_vector)
    return self._activation(combined_tensors + self._bias)