def test_tensors_handles_constant_strat(constant):
    """Check that ``tensors`` yields tensors with differing ``.constant`` values.

    ``constant`` is forwarded to ``tensors`` only when it is not ``None``. In
    both cases, the ``.constant`` attributes collected across all generated
    examples must contain more than one distinct value.
    """
    seen_constant_flags = []
    strategy_kwargs = {"dtype": np.int8, "shape": (2, 3)}
    if constant is not None:
        strategy_kwargs["constant"] = constant

    @given(x=tensors(**strategy_kwargs))
    def record_constant_flag(x):
        # Runs once per hypothesis example; just record the flag for later.
        seen_constant_flags.append(x.constant)

    record_constant_flag()
    assert len(set(seen_constant_flags)) > 1
def test_tensors_with_grad(dtype, data: st.DataObject, shape, grad_dtype, grad_elements_bounds):
    """Check that ``tensors(include_grad=True, ...)`` attaches a well-formed gradient.

    The drawn tensor must have the requested ``dtype``. Its ``.grad`` must be a
    NumPy array with the tensor's shape, whose dtype is ``grad_dtype`` (or the
    tensor's own dtype when ``grad_dtype`` is None). Every element must lie
    within ``grad_elements_bounds``, or within [-10, 10] when no bounds are given.
    """
    tensor = data.draw(
        tensors(
            dtype=dtype,
            shape=shape,
            include_grad=True,
            grad_dtype=grad_dtype,
            grad_elements_bounds=grad_elements_bounds,
        ),
        label="tensor",
    )
    assert isinstance(tensor, Tensor)
    assert tensor.dtype == dtype
    assert isinstance(tensor.grad, np.ndarray)
    assert tensor.grad.shape == tensor.shape

    expected_grad_dtype = grad_dtype if grad_dtype is not None else tensor.dtype
    assert tensor.grad.dtype == expected_grad_dtype

    # Derive the expected range from the parameter instead of hard-coding it,
    # so the check stays in sync with the parametrization.
    # NOTE(review): assumes grad_elements_bounds is a (low, high) pair (the
    # previous version hard-coded (100, 200) here) -- confirm against the
    # parametrize decorator.
    low, high = (-10, 10) if grad_elements_bounds is None else grad_elements_bounds
    assert np.all((low <= tensor.grad) & (tensor.grad <= high))
# Backward-pass test for ``power``, generated by ``backprop_test_factory``.
# ``index_to_bnds`` presumably bounds the values drawn for input 0 to (1, 10)
# and input 1 to (-3, 3); ``true_func=np.power`` is presumably the reference
# implementation -- confirm against backprop_test_factory.
# ``deadline=None`` disables hypothesis's per-example time limit.
@settings(deadline=None)
@backprop_test_factory(
    mygrad_func=power,
    true_func=np.power,
    index_to_bnds={0: (1, 10), 1: (-3, 3)},
    num_arrays=2,
)
def test_power_bkwd():
    # Intentionally empty: the decorator presumably supplies the test body.
    pass
# Special case ``t ** 0``. Shapes include 0-d and empty arrays
# (min_dims=0, min_side=0), and the element range includes 0.
@given(
    t=tensors(
        dtype=np.float64,
        shape=hnp.array_shapes(min_dims=0, min_side=0),
        elements=st.floats(-1e6, 1e6),
        constant=False,
    )
)
def test_x_pow_0_special_case(t):
    """``t ** 0`` must be all ones with an all-zero gradient w.r.t. ``t``.

    Presumably exists because the generic power-rule gradient involves
    ``t ** -1``, which is undefined at ``t == 0``.
    """
    y = t ** 0
    y.backward()
    assert_allclose(y.data, np.ones_like(t))
    assert_allclose(t.grad, np.zeros_like(t))
# NOTE(review): the following decorator continues beyond the visible source;
# its elements are restricted to strictly positive floats (>= 1e-10).
@given(
    t=tensors(
        dtype=np.float64,
        shape=hnp.array_shapes(min_dims=0, min_side=0),
        elements=st.floats(1e-10, 1e6),
def test_tensors_dtype(dtype, data: st.DataObject):
    """A tensor drawn with an explicit ``dtype`` has that dtype and no gradient."""
    strategy = tensors(dtype=dtype, shape=(2, 3))
    drawn = data.draw(strategy, label="tensor")
    assert isinstance(drawn, Tensor)
    assert drawn.dtype == dtype
    # ``include_grad`` is not passed here, so no gradient is expected.
    assert drawn.grad is None
def test_tensors_shape(shape, data: st.DataObject):
    """A tensor drawn with an explicit ``shape`` has that shape and no gradient."""
    strategy = tensors(np.int8, shape=shape)
    drawn = data.draw(strategy, label="tensor")
    assert isinstance(drawn, Tensor)
    assert drawn.shape == shape
    # ``include_grad`` is not passed here, so no gradient is expected.
    assert drawn.grad is None
def test_tensors_static_constant(constant: bool, data: st.DataObject):
    """A fixed ``constant`` flag passed to ``tensors`` is reflected on the draw."""
    strategy = tensors(np.int8, (2, 3), constant=constant)
    drawn = data.draw(strategy, label="tensor")
    assert isinstance(drawn, Tensor)
    # Identity check: the flag must be exactly the object that was requested.
    assert drawn.constant is constant
    assert drawn.grad is None