def test_input_validation():
    """Check that ``reshape`` raises TypeError for malformed call signatures."""
    arr = np.array([1, 2])

    # No shape argument at all.
    with pytest.raises(TypeError):
        reshape(arr)

    # A tuple shape followed by a further positional argument.
    with pytest.raises(TypeError):
        reshape(arr, (2,), 2)
def test_reshape_fwd(a):
    """Check that the forward pass of ``reshape`` matches ``numpy.ndarray.reshape``.

    Parameters
    ----------
    a
        Input array; presumably a numpy array supplied by a fixture or
        hypothesis strategy outside this view -- TODO confirm.
    """
    # Presumably gen_shape yields a shape with exactly `a.size` elements.
    new_shape = gen_shape(a.size)
    x = Tensor(a)
    x = reshape(x, new_shape, constant=True)
    expected = a.reshape(new_shape)

    assert x.shape == expected.shape, "Tensor.reshape failed"
    # Bug fix: the original wrote `assert_allclose(...), "msg"`, which only
    # builds a tuple and discards the message. Pass it via err_msg so it is
    # actually reported on failure.
    assert_allclose(expected, x.data, err_msg="Tensor.reshape failed")
def test_reshape_backward(a):
    """Check that backprop through ``reshape`` restores the input's shape.

    An ``arange`` gradient laid out in the input's shape is reshaped to the
    output shape and fed to ``backward``. The gradient stored on the input
    tensor must equal the original ``arange`` array, both in shape and in
    values.
    """
    target_shape = gen_shape(a.size)
    upstream = np.arange(a.size).reshape(a.shape)
    tensor = Tensor(a)

    out = reshape(tensor, target_shape, constant=False)
    out.backward(upstream.reshape(target_shape))

    assert tensor.grad.shape == upstream.shape
    assert_allclose(tensor.grad, upstream)
def test_input_validation_matches_numpy():
    """Check that ``reshape`` fails whenever ``np.reshape`` fails on the same call.

    If numpy accepts the call, nothing is asserted.
    """
    extra_args = (1, 1)
    try:
        np.reshape(np.array(1), *extra_args)
    except Exception:
        # numpy rejected these arguments, so our reshape must reject them too.
        with pytest.raises(Exception):
            reshape(Tensor(1), *extra_args)