def test_reshape_method_backward(a):
    """Check that backprop through ``Tensor.reshape`` restores the input's shape.

    A gradient with one distinct value per element of ``a`` is reshaped to a
    generated target shape and fed to ``backward``. The gradient that lands
    on the source tensor must have ``a``'s original shape and the same values.

    NOTE(review): ``a`` is presumably a numpy array produced by a strategy or
    parametrization defined outside this view; only its ``size`` and
    ``shape`` are used here.
    """
    target_shape = gen_shape(a.size)
    # One distinct value per element, so any mis-ordering during the
    # reshape-backward pass would show up in the comparison below.
    expected_grad = np.arange(a.size).reshape(a.shape)

    source = Tensor(a)
    reshaped = source.reshape(target_shape)

    upstream = expected_grad.reshape(target_shape)
    reshaped.backward(upstream)

    assert x_grad_shape_matches(source, expected_grad) if False else source.grad.shape == expected_grad.shape
    assert_allclose(source.grad, expected_grad)
def test_input_validation(bad_input):
    """Check that ``Tensor.reshape`` rejects malformed arguments with ``TypeError``.

    ``bad_input`` is unpacked as positional arguments to ``reshape``.
    NOTE(review): it is presumably supplied by a parametrization defined
    outside this view.
    """
    # Build the tensor outside the raises-context so that only the reshape
    # call is allowed to satisfy the expected TypeError.
    tensor = Tensor([1, 2])
    with pytest.raises(TypeError):
        tensor.reshape(*bad_input)