Beispiel #1
0
    def test_backward(self):
        """Backpropagate through ``MatrixMultiplier.forward`` and check both gradients.

        The expected values encode ``result = tensor @ weights``. Calling
        ``sum()`` on the result feeds an all-ones upstream gradient, so
        d(sum)/d(weights) = tensor^T @ ones and d(sum)/d(tensor) = ones @ weights^T.
        """
        multiplier = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, requires_grad=True)
        multiplier.forward(weights).sum().backward()
        # NOTE(review): tensor.grad is only populated if the extension creates
        # this matrix with requires_grad=True — confirm in the C++ source.
        tensor = multiplier.get()

        # Upstream gradient of sum(): all ones, shaped like the 4x4 product.
        upstream = torch.ones([4, 4])

        self.assertEqual(weights.grad, tensor.t().mm(upstream))
        self.assertEqual(tensor.grad, upstream.mm(weights.t()))
Beispiel #2
0
 def test_extension_module(self):
     """Check that ``MatrixMultiplier.forward`` equals ``get().mm(weights)``."""
     multiplier = cpp_extension.MatrixMultiplier(4, 8)
     weights = torch.rand(8, 4)
     # Reference product computed in Python from the extension's stored matrix.
     reference = multiplier.get().mm(weights)
     self.assertEqual(reference, multiplier.forward(weights))