def test_gradient_multiinput(self):
    """Check per-input gradients for a model that takes two tensor inputs.

    Two 1x2 tensors are passed together as a tuple to ``compute_gradients``.
    The gradient for each of them is expected to be approximately
    ``[0.0, 1.0]``, within a tolerance of 0.01.
    """
    net = BasicModel6_MultiTensor()
    first = torch.tensor([[-3.0, -5.0]], requires_grad=True)
    second = torch.tensor([[-5.0, 2.0]], requires_grad=True)
    per_input_grads = compute_gradients(net, (first, second))
    # Check the first input's gradient, then the second's, in that order.
    for grad in (per_input_grads[0], per_input_grads[1]):
        assertArraysAlmostEqual(grad.squeeze(0).tolist(), [0.0, 1.0], delta=0.01)
def _get_multiargs_basic_config():
    """Build a shared fixture for tests of a model with extra forward args.

    Returns:
        A 4-tuple ``(model, inputs, grads, additional_forward_args)``:
        ``model`` is a fresh ``BasicModel5_MultiArgs``; ``inputs`` is a pair
        of 2x3 float tensors created with ``requires_grad=True``; ``grads``
        is what ``compute_gradients`` returns for those inputs; and
        ``additional_forward_args`` is ``([2, 3], 1)``, which is forwarded
        to the model alongside the inputs.
    """
    net = BasicModel5_MultiArgs()
    extra_args = ([2, 3], 1)
    first_input = torch.tensor(
        [[1.5, 2.0, 34.3], [3.4, 1.2, 2.0]], requires_grad=True
    )
    second_input = torch.tensor(
        [[3.0, 3.5, 23.2], [2.3, 1.2, 0.3]], requires_grad=True
    )
    model_inputs = (first_input, second_input)
    gradients = compute_gradients(
        net, model_inputs, additional_forward_args=extra_args
    )
    return net, model_inputs, gradients, extra_args
def test_gradient_basic_2(self):
    """Check the gradient of ``BasicModel`` for a 1x1 input holding -3.0.

    The first element returned by ``compute_gradients``, squeezed along
    dim 0, is expected to be approximately ``[1.0]`` (tolerance 0.01).
    """
    model = BasicModel()
    # Named ``inp`` rather than ``input`` so the builtin is not shadowed.
    inp = torch.tensor([[-3.0]], requires_grad=True)
    grads = compute_gradients(model, inp)[0]
    assertArraysAlmostEqual(grads.squeeze(0).tolist(), [1.0], delta=0.01)