コード例 #1
0
    def test_batch_matmul(self):
        """Compare batched ``InterpolatedLazyTensor.matmul`` with a dense product.

        The lazy result is checked against an explicit
        ``left @ base @ right^T @ test_matrix`` built from independent clones
        of the inputs.  After back-propagating the sum of each result, the
        gradients w.r.t. the base matrix and the left interpolation values
        are compared between the two paths.
        """
        left_interp_indices = torch.tensor(
            [[2, 3], [3, 4], [4, 5]], dtype=torch.long
        ).repeat(5, 3, 1)
        left_interp_values = torch.tensor(
            [[1, 2], [0.5, 1], [1, 3]], dtype=torch.float
        ).repeat(5, 3, 1)
        right_interp_indices = torch.tensor(
            [[0, 1], [1, 2], [2, 3]], dtype=torch.long
        ).repeat(5, 3, 1)
        right_interp_values = torch.tensor(
            [[1, 2], [2, 0.5], [1, 3]], dtype=torch.float
        ).repeat(5, 3, 1)

        # Clone *before* enabling gradients so the reference tensors are
        # separate leaves that accumulate their own ``.grad``.
        left_values_ref = left_interp_values.clone().requires_grad_(True)
        right_values_ref = right_interp_values.clone().requires_grad_(True)
        left_interp_values.requires_grad_(True)
        right_interp_values.requires_grad_(True)

        # Random batch of symmetric matrices of the form A^T A.
        factor = torch.randn(5, 6, 6)
        base_tensor = factor.transpose(-1, -2).matmul(factor)
        base_ref = base_tensor.clone().requires_grad_(True)
        base_tensor.requires_grad_(True)

        test_matrix = torch.randn(5, 9, 4)

        interp_lazy_tensor = InterpolatedLazyTensor(
            NonLazyTensor(base_tensor),
            left_interp_indices,
            left_interp_values,
            right_interp_indices,
            right_interp_values,
        )
        res = interp_lazy_tensor.matmul(test_matrix)

        def dense_interp(indices, values):
            # Scatter each batch element's values into a 9 x 6 zero matrix
            # and stack the results along a new leading batch dimension.
            return torch.stack([
                torch.zeros(9, 6).scatter_(1, indices[b], values[b])
                for b in range(5)
            ])

        left_dense = dense_interp(left_interp_indices, left_values_ref)
        right_dense = dense_interp(right_interp_indices, right_values_ref)

        actual = left_dense.matmul(base_ref)
        actual = actual.matmul(right_dense.transpose(-1, -2))
        actual = actual.matmul(test_matrix)
        self.assertTrue(approx_equal(res, actual))

        res.sum().backward()
        actual.sum().backward()

        self.assertTrue(approx_equal(base_tensor.grad, base_ref.grad))
        self.assertTrue(
            approx_equal(left_interp_values.grad, left_values_ref.grad))
コード例 #2
0
    def test_matmul_batch(self):
        """Check batched ``matmul`` against hand-written dense matrices.

        The interpolation matrices are written out explicitly (a 3-row
        pattern tiled to 9 rows, repeated over a batch of 5) and the lazy
        result is compared with ``left @ base @ right^T @ test_matrix``.
        ``test_matrix`` has a leading batch dimension of size 1.
        """
        def tiled(rows, dtype):
            # 3 x k pattern -> (5, 9, k): tile rows 3x, then batch 5x.
            return torch.tensor(rows, dtype=dtype).repeat(5, 3, 1)

        left_interp_indices = tiled([[2, 3], [3, 4], [4, 5]], torch.long)
        left_interp_values = tiled([[1, 2], [0.5, 1], [1, 3]], torch.float)
        right_interp_indices = tiled([[0, 1], [1, 2], [2, 3]], torch.long)
        right_interp_values = tiled([[1, 2], [2, 0.5], [1, 3]], torch.float)

        # Random batch of symmetric matrices of the form A^T A.
        factor = torch.randn(5, 6, 6)
        base_lazy_tensor_mat = factor.transpose(1, 2).matmul(factor)
        base_lazy_tensor_mat.requires_grad_(True)
        test_matrix = torch.randn(1, 9, 4)

        interp_lazy_tensor = InterpolatedLazyTensor(
            NonLazyTensor(base_lazy_tensor_mat),
            left_interp_indices,
            left_interp_values,
            right_interp_indices,
            right_interp_values,
        )
        res = interp_lazy_tensor.matmul(test_matrix)

        # Dense equivalents of the index/value pairs above, one row per
        # interpolation point.
        left_matrix = tiled(
            [
                [0, 0, 1, 2, 0, 0],
                [0, 0, 0, 0.5, 1, 0],
                [0, 0, 0, 0, 1, 3],
            ],
            torch.float,
        )
        right_matrix = tiled(
            [
                [1, 2, 0, 0, 0, 0],
                [0, 2, 0.5, 0, 0, 0],
                [0, 0, 1, 3, 0, 0],
            ],
            torch.float,
        )

        actual = left_matrix.matmul(base_lazy_tensor_mat)
        actual = actual.matmul(right_matrix.transpose(-1, -2))
        actual = actual.matmul(test_matrix)

        self.assertTrue(approx_equal(res, actual))
コード例 #3
0
    def test_matmul(self):
        """Compare non-batched ``InterpolatedLazyTensor.matmul`` with a dense product.

        The lazy result is checked against an explicit
        ``left @ base @ right^T @ test_matrix`` built from independent clones
        of the inputs.  After back-propagating the sum of each result, the
        gradients w.r.t. the base matrix and the left interpolation values
        are compared between the two paths.
        """
        left_interp_indices = torch.tensor([[2, 3], [3, 4], [4, 5]],
                                           dtype=torch.long).repeat(3, 1)
        left_interp_values = torch.tensor([[1, 2], [0.5, 1], [1, 3]],
                                          dtype=torch.float).repeat(3, 1)
        left_interp_values_copy = left_interp_values.clone()
        left_interp_values.requires_grad = True
        left_interp_values_copy.requires_grad = True
        right_interp_indices = torch.tensor([[0, 1], [1, 2], [2, 3]],
                                            dtype=torch.long).repeat(3, 1)
        right_interp_values = torch.tensor([[1, 2], [2, 0.5], [1, 3]],
                                           dtype=torch.float).repeat(3, 1)
        right_interp_values_copy = right_interp_values.clone()
        right_interp_values.requires_grad = True
        right_interp_values_copy.requires_grad = True

        # Random symmetric matrix of the form A^T A.
        base_lazy_tensor_mat = torch.randn(6, 6)
        base_lazy_tensor_mat = base_lazy_tensor_mat.t().matmul(
            base_lazy_tensor_mat)
        base_tensor = base_lazy_tensor_mat
        # The reference must be a separate leaf tensor.  Aliasing the same
        # object would make both backward passes accumulate into one
        # ``.grad`` and turn the gradient comparison below into a tautology.
        base_tensor_copy = base_tensor.clone()
        base_tensor.requires_grad = True
        base_tensor_copy.requires_grad = True
        base_lazy_tensor = NonLazyTensor(base_tensor)

        test_matrix = torch.randn(9, 4)

        interp_lazy_tensor = InterpolatedLazyTensor(base_lazy_tensor,
                                                    left_interp_indices,
                                                    left_interp_values,
                                                    right_interp_indices,
                                                    right_interp_values)
        res = interp_lazy_tensor.matmul(test_matrix)

        # Dense 9 x 6 interpolation matrices built from the reference values.
        left_matrix = torch.zeros(9, 6)
        right_matrix = torch.zeros(9, 6)
        left_matrix.scatter_(1, left_interp_indices, left_interp_values_copy)
        right_matrix.scatter_(1, right_interp_indices,
                              right_interp_values_copy)

        actual = left_matrix.matmul(base_tensor_copy).matmul(
            right_matrix.t()).matmul(test_matrix)
        self.assertTrue(approx_equal(res, actual))

        res.sum().backward()
        actual.sum().backward()

        # NOTE(review): gradients w.r.t. the right interpolation values are
        # not compared here.
        self.assertTrue(approx_equal(base_tensor.grad, base_tensor_copy.grad))
        self.assertTrue(
            approx_equal(left_interp_values.grad,
                         left_interp_values_copy.grad))