예제 #1
0
def test_tensors_handles_constant_strat(constant):
    """Check that ``tensors`` yields both constant and non-constant tensors.

    ``constant`` is forwarded to ``tensors`` only when it is not ``None``;
    otherwise the strategy's own ``constant`` default is exercised. After the
    inner Hypothesis test has run, more than one distinct ``.constant`` value
    must have been observed across the drawn tensors.
    """
    strategy_kwargs = {"dtype": np.int8, "shape": (2, 3)}
    if constant is not None:
        strategy_kwargs["constant"] = constant

    # Collect every distinct ``.constant`` flag seen by the inner test.
    seen_constants = set()

    @given(x=tensors(**strategy_kwargs))
    def record_constant(x):
        seen_constants.add(x.constant)

    record_constant()

    assert len(seen_constants) > 1
예제 #2
0
def test_tensors_with_grad(dtype, data: st.DataObject, shape, grad_dtype,
                           grad_elements_bounds):
    """Check that ``tensors(..., include_grad=True)`` attaches a valid gradient.

    Draws a tensor with a gradient and verifies that:

    - the drawn object is a ``Tensor`` with the requested ``dtype``;
    - ``tensor.grad`` is an ``np.ndarray`` with the tensor's shape;
    - ``tensor.grad.dtype`` is ``grad_dtype``, or the tensor's own dtype when
      ``grad_dtype`` is None;
    - every gradient element lies within the requested bounds.

    Parameters
    ----------
    grad_elements_bounds : (low, high) pair, or None
        Inclusive bounds forwarded to ``tensors``. When None, the gradient
        elements are expected to lie in [-10, 10]. That is presumably the
        strategy's default bounds; TODO confirm against ``tensors``.
    """
    tensor = data.draw(
        tensors(
            dtype=dtype,
            shape=shape,
            include_grad=True,
            grad_dtype=grad_dtype,
            grad_elements_bounds=grad_elements_bounds,
        ),
        label="tensor",
    )
    assert isinstance(tensor, Tensor)
    assert tensor.dtype == dtype
    assert isinstance(tensor.grad, np.ndarray)
    assert tensor.grad.shape == tensor.shape

    expected_grad_dtype = tensor.dtype if grad_dtype is None else grad_dtype
    assert tensor.grad.dtype == expected_grad_dtype

    # Check against the bounds that were actually requested instead of
    # hard-coded literals, so the assertion stays correct if the
    # parametrization of ``grad_elements_bounds`` changes.
    if grad_elements_bounds is not None:
        low, high = grad_elements_bounds
    else:
        low, high = -10, 10
    assert np.all((low <= tensor.grad) & (tensor.grad <= high))
예제 #3
0
@settings(deadline=None)
@backprop_test_factory(
    num_arrays=2,
    index_to_bnds={0: (1, 10), 1: (-3, 3)},
    mygrad_func=power,
    true_func=np.power,
)
def test_power_bkwd():
    """Back-propagation test comparing ``power`` against ``np.power``.

    The body is intentionally empty. The ``backprop_test_factory`` decorator
    is expected to supply the test logic from its arguments. Two arrays are
    used; array 0 is drawn within (1, 10) and array 1 within (-3, 3). The
    positive base range is presumably there to keep non-integer exponents
    well-defined; TODO confirm.
    """


@given(
    t=tensors(
        dtype=np.float64,
        shape=hnp.array_shapes(min_dims=0, min_side=0),
        elements=st.floats(-1e6, 1e6),
        constant=False,
    )
)
def test_x_pow_0_special_case(t):
    """Check the forward and backward results of raising a tensor to power 0.

    ``t`` is a non-constant float64 tensor of any shape, 0-d and empty arrays
    included. ``t ** 0`` must be all ones. Back-propagating from it must leave
    ``t.grad`` all zeros, including where ``t`` is 0.
    """
    powered = t ** 0
    powered.backward()

    assert_allclose(powered.data, np.ones_like(t))

    grad = t.grad
    assert_allclose(grad, np.zeros_like(t))


@given(
    t=tensors(
        dtype=np.float64,
        shape=hnp.array_shapes(min_dims=0, min_side=0),
        elements=st.floats(1e-10, 1e6),
예제 #4
0
def test_tensors_dtype(dtype, data: st.DataObject):
    """Check that ``tensors`` honours an explicitly requested ``dtype``.

    A (2, 3) tensor is drawn with the given dtype. It must be a ``Tensor`` of
    exactly that dtype, and its ``grad`` must be None.
    """
    strategy = tensors(dtype=dtype, shape=(2, 3))
    drawn = data.draw(strategy, label="tensor")

    assert isinstance(drawn, Tensor)
    assert drawn.dtype == dtype
    assert drawn.grad is None
예제 #5
0
def test_tensors_shape(shape, data: st.DataObject):
    """Check that ``tensors`` honours an explicitly requested ``shape``.

    An int8 tensor is drawn with the given shape. It must be a ``Tensor``
    whose ``shape`` equals the request, and its ``grad`` must be None.
    """
    strategy = tensors(np.int8, shape=shape)
    drawn = data.draw(strategy, label="tensor")

    assert isinstance(drawn, Tensor)
    assert drawn.shape == shape
    assert drawn.grad is None
예제 #6
0
def test_tensors_static_constant(constant: bool, data: st.DataObject):
    """Check that a fixed boolean ``constant`` is propagated to the tensor.

    A (2, 3) int8 tensor is drawn with ``constant`` passed as a plain bool. Its
    ``.constant`` attribute must be that very bool object, and its ``grad``
    must be None.
    """
    strategy = tensors(np.int8, (2, 3), constant=constant)
    drawn = data.draw(strategy, label="tensor")

    assert isinstance(drawn, Tensor)
    assert drawn.constant is constant
    assert drawn.grad is None