Example #1
0
def test_seq_mult(shape_1: Tuple[int, ...], num_arrays: int,
                  data: st.DataObject):
    """Check ``multiply_sequence`` against a left-fold of ``*`` on copies.

    A chain of mutually-broadcastable shapes is drawn, and one float32 tensor
    is created per shape. ``multiply_sequence`` on the originals is compared
    with ``reduce(*)`` on copies, for both the forward output and the
    gradients after ``.sum().backward()``. The test then checks that
    ``f.null_gradients()`` clears ``grad`` and ``_ops`` on every input.

    Parameters
    ----------
    shape_1 : Tuple[int, ...]
        Shape of the first tensor. Presumably supplied by a hypothesis
        ``@given`` strategy not visible here -- TODO confirm.
    num_arrays : int
        Number of extra shapes drawn beyond the first two, so
        ``2 + num_arrays`` tensors are multiplied in total.
    data : st.DataObject
        Hypothesis data object used for the interactive draws.
    """
    shape_2 = data.draw(hnp.broadcastable_shapes(shape_1), label="shape_2")
    shapes = [shape_1, shape_2]

    # `pair` must NOT alias `shapes`: the loop appends to `shapes`, and the
    # broadcast should only ever combine the running result with the newest
    # shape.
    pair = (shape_1, shape_2)

    for i in range(num_arrays):
        # Ensure the sequence of shapes stays mutually broadcastable: each new
        # shape is drawn against the broadcast of everything drawn so far.
        broadcasted = _broadcast_shapes(*pair)
        new_shape = data.draw(hnp.broadcastable_shapes(broadcasted),
                              label="shape_{}".format(i + 3))
        shapes.append(new_shape)
        pair = (broadcasted, new_shape)

    tensors = [
        Tensor(
            data.draw(
                hnp.arrays(shape=shape,
                           dtype=np.float32,
                           elements=st.floats(-10, 10, width=32))))
        for shape in shapes
    ]
    note("tensors: {}".format(tensors))
    # Independent copies so the reference computation builds its own graph
    # and accumulates its own gradients.
    tensors_copy = [x.copy() for x in tensors]

    f = multiply_sequence(*tensors)
    f1 = reduce(lambda x, y: x * y, tensors_copy)

    assert_allclose(f.data, f1.data)

    f.sum().backward()
    f1.sum().backward()

    # Re-check the forward output after backprop, with a looser tolerance.
    # NOTE(review): presumably guards against backward mutating `.data`;
    # confirm the intent.
    assert_allclose(f.data, f1.data, rtol=1e-4, atol=1e-4)

    for n, (expected, actual) in enumerate(zip(tensors_copy, tensors)):
        assert_allclose(
            expected.grad,
            actual.grad,
            rtol=1e-3,
            atol=1e-3,
            err_msg="tensor-{}".format(n),
        )

    f.null_gradients()
    assert all(x.grad is None for x in tensors)
    assert all(not x._ops for x in tensors)
Example #2
0
def test_seq_mult():
    """Compare ``multiply_sequence`` with chained ``*`` on fixed inputs.

    The inputs are a 0-d, a 1-d and a 2-d tensor that broadcast together.
    Both results are summed and backpropagated. The test asserts that the
    outputs match, and that the gradients of the result and of each input
    match between the two computations.
    """
    def make_inputs():
        # Fresh tensors on every call so each computation owns its own graph.
        return (Tensor(3.),
                Tensor([1., 2., 3.]),
                Tensor([[1., 2., 3.], [2., 3., 4.]]))

    a, b, c = make_inputs()
    f = multiply_sequence(a, b, c, constant=False)
    f.sum().backward()

    a1, b1, c1 = make_inputs()
    f1 = a1 * b1 * c1
    f1.sum().backward()

    assert_allclose(f.data, f1.data)
    for actual, expected in ((f, f1), (a, a1), (b, b1), (c, c1)):
        assert_allclose(actual.grad, expected.grad)
def test_input_validation(arrays):
    """Assert that ``add_sequence`` and ``multiply_sequence`` raise ValueError.

    ``arrays`` is expected to be an invalid set of operands. It is presumably
    supplied by a fixture or strategy not visible here -- TODO confirm.
    """
    for sequence_op in (add_sequence, multiply_sequence):
        with pytest.raises(ValueError):
            sequence_op(*arrays)