Beispiel #1
0
def test_backprop_multiple_graphs_double_backprop(method0, method1):
    """Double backprop across two graphs.

    ``fprop`` backpropagates ``x0 * (x0 + x1)`` on ``bp_x0`` (using
    ``method0``) and returns ``x0 * gx0``; ``_check_backprop`` then
    differentiates that result on ``bp_x1`` (using ``method1``).
    """
    dtype = chainerx.float32
    shape = (1,)

    # bp_x1 is entered first and bp_x0 second; the nesting order is kept
    # exactly as the test was written.
    with chainerx.backprop_scope('bp_x1') as bp_x1, \
            chainerx.backprop_scope('bp_x0') as bp_x0:
        x0_init = chainerx.full(shape, 2, dtype).require_grad(bp_x0)
        x1_init = chainerx.full(shape, 3, dtype).require_grad(bp_x1)
        xs = (x0_init, x1_init)

        # x0 * (2 * x0 + x1) differentiated w.r.t. x1 is x0 == 2; x0 is not
        # on bp_x1, so it has no expected gradient.
        expected_gxs = (None, chainerx.full(shape, 2, dtype))

        def fprop(x0, x1):
            assert x0.is_grad_required(bp_x0)

            h = x0 * (x0 + x1)
            if method0 == 'grad':
                gx0, = chainerx.grad([h], [x0], backprop_id=bp_x0)
            elif method0 == 'backward':
                chainerx.backward(h, backprop_id=bp_x0)
                gx0 = x0.get_grad(bp_x0)
            else:
                assert False

            # The first-order gradient must not require backprop on bp_x0
            # but must still require it on bp_x1.
            assert not gx0.is_backprop_required(bp_x0)
            assert gx0.is_backprop_required(bp_x1)

            out = x0 * gx0
            return (out,)

        _check_backprop(method1, fprop, xs, expected_gxs, backprop_id=bp_x1)
Beispiel #2
0
def test_array_require_grad_multiple_graphs_forward():
    """Inputs registered on separate graphs; their product is on both."""
    a = chainerx.array([1, 1, 1], chainerx.float32)
    b = chainerx.array([1, 1, 1], chainerx.float32)

    with chainerx.backprop_scope('bp1') as bp1, \
            chainerx.backprop_scope('bp2') as bp2, \
            chainerx.backprop_scope('bp3') as bp3:

        a.require_grad(bp1)
        b.require_grad(bp2)

        # Each input is registered only on its own graph.
        for arr, own, other in ((a, bp1, bp2), (b, bp2, bp1)):
            assert arr.is_grad_required(own)
            assert arr.is_backprop_required(own)
            assert not arr.is_grad_required(other)
            assert not arr.is_backprop_required(other)

        product = a * b

        # The product does not itself require grad, but is connected to
        # both graphs through its inputs.
        for bp in (bp1, bp2):
            assert not product.is_grad_required(bp)
            assert product.is_backprop_required(bp)

        # No unspecified graphs are generated
        assert not product.is_backprop_required(None)
        assert not product.is_backprop_required(bp3)
Beispiel #3
0
def test_array_require_grad_with_backprop_id():
    """require_grad and its queries with an explicit backprop id.

    Exercised once with a positional id and once with ``backprop_id=``;
    calling require_grad a second time must succeed and change nothing.
    """
    arr = chainerx.array([1, 1, 1], chainerx.float32)

    # Positional backprop id.
    with chainerx.backprop_scope('bp1') as bp1:
        assert not arr.is_backprop_required(bp1)
        # The second iteration checks that a repeated call is a no-op.
        for _ in range(2):
            arr.require_grad(bp1)
            assert arr.is_grad_required(bp1)
            assert arr.is_backprop_required(bp1)

    # keyword arguments
    with chainerx.backprop_scope('bp2') as bp2:
        assert not arr.is_backprop_required(backprop_id=bp2)
        arr.require_grad(backprop_id=bp2)
        assert arr.is_grad_required(bp2)
        assert arr.is_grad_required(backprop_id=bp2)
        assert arr.is_backprop_required(bp2)
        assert arr.is_backprop_required(backprop_id=bp2)

        # A repeated call is accepted and leaves the state unchanged.
        arr.require_grad(backprop_id=bp2)
        assert arr.is_grad_required(backprop_id=bp2)
        assert arr.is_backprop_required(backprop_id=bp2)
Beispiel #4
0
def test_as_grad_stopped_view(shape, float_dtype):
    """Check ``as_grad_stopped(copy=False)``.

    The returned array must equal the source, be on the same device, and
    not require grad or backprop on the stopped graphs (all graphs when no
    ids are given). The source array must keep all of its own grad and
    backprop requirements.
    """
    dtype = float_dtype

    # Stop gradients on all graphs
    with chainerx.backprop_scope('bp1') as bp1, \
            chainerx.backprop_scope('bp2') as bp2, \
            chainerx.backprop_scope('bp3') as bp3:

        a = array_utils.create_dummy_ndarray(chainerx, shape, dtype)
        a.require_grad(bp1)
        a.require_grad(bp2)
        assert a.is_grad_required(bp1)
        assert a.is_grad_required(bp2)
        assert a.is_backprop_required(bp1)
        assert a.is_backprop_required(bp2)
        b = a.as_grad_stopped(copy=False)

        chainerx.testing.assert_array_equal_ex(a, b)
        assert b.device is a.device
        assert not b.is_grad_required(bp1)
        assert not b.is_grad_required(bp2)
        assert not b.is_backprop_required(bp1)
        assert not b.is_backprop_required(bp2)

        # The source keeps its requirements; grad-required is checked as
        # well, matching the copy=True test and the partial case below.
        assert a.is_grad_required(bp1)
        assert a.is_grad_required(bp2)
        assert a.is_backprop_required(bp1)
        assert a.is_backprop_required(bp2)

        # Stop gradients on some graphs
        a = array_utils.create_dummy_ndarray(chainerx, shape, dtype)
        a.require_grad(bp1)
        a.require_grad(bp2)
        a.require_grad(bp3)
        assert a.is_grad_required(bp1)
        assert a.is_grad_required(bp2)
        assert a.is_grad_required(bp3)
        assert a.is_backprop_required(bp1)
        assert a.is_backprop_required(bp2)
        assert a.is_backprop_required(bp3)
        b = a.as_grad_stopped([bp1, bp2], copy=False)

        chainerx.testing.assert_array_equal_ex(a, b)
        assert b.device is a.device
        assert not b.is_grad_required(bp1)
        assert not b.is_grad_required(bp2)
        assert not b.is_grad_required(bp3)
        assert not b.is_backprop_required(bp1)
        assert not b.is_backprop_required(bp2)
        # bp3 was not stopped, so b is still connected to it.
        assert b.is_backprop_required(bp3)

        # The source keeps all of its requirements.
        assert a.is_grad_required(bp1)
        assert a.is_grad_required(bp2)
        assert a.is_grad_required(bp3)
        assert a.is_backprop_required(bp1)
        assert a.is_backprop_required(bp2)
        assert a.is_backprop_required(bp3)
Beispiel #5
0
def test_backprop_multiple_graphs_reuse(method0, method1, method2):
    """Reuse the same leaf arrays for backprop on different graphs.

    The inputs are first backpropagated on bp1 and then on bp2. Each is then
    registered on the other graph and backpropagated on bp2 again, which
    must leave the bp1 gradients empty.
    """
    dtype = chainerx.float32
    shape = (1,)

    def fprop(x0, x1):
        return x0 * x1,

    with chainerx.backprop_scope('bp2') as backprop_id2, \
            chainerx.backprop_scope('bp1') as backprop_id1:
        x1 = chainerx.full(shape, 2, dtype).require_grad(backprop_id1)
        x2 = chainerx.full(shape, 5, dtype).require_grad(backprop_id2)
        xs = (x1, x2)

        def clear_grads():
            x1.cleargrad(backprop_id1)
            x2.cleargrad(backprop_id2)

        # Only x1 is on bp1: d(x1 * x2)/dx1 == x2 == 5.
        _check_backprop(
            method0, fprop, xs, (chainerx.full(shape, 5, dtype), None),
            backprop_id=backprop_id1)

        clear_grads()
        assert x1.get_grad(backprop_id1) is None
        assert x2.get_grad(backprop_id2) is None

        # Only x2 is on bp2: d(x1 * x2)/dx2 == x1 == 2.
        _check_backprop(
            method1, fprop, xs, (None, chainerx.full(shape, 2, dtype)),
            backprop_id=backprop_id2)

        clear_grads()

        # Register each input on the other graph too, so both are on bp2.
        x1.require_grad(backprop_id2)
        x2.require_grad(backprop_id1)

        _check_backprop(
            method2, fprop, xs,
            (chainerx.full(shape, 5, dtype), chainerx.full(shape, 2, dtype)),
            backprop_id=backprop_id2)

        # Backprop on bp2 must not populate gradients on bp1.
        assert x1.get_grad(backprop_id1) is None
        assert x2.get_grad(backprop_id1) is None
Beispiel #6
0
def test_backward_multiple_graphs_non_existing():
    """backward() on a graph the output is not connected to must fail."""
    dtype = chainerx.float32
    shape = (1,)

    lhs = chainerx.full(shape, 2, dtype)
    rhs = chainerx.full(shape, 5, dtype)

    with chainerx.backprop_scope('bp1') as backprop_id1, \
            chainerx.backprop_scope('bp2') as backprop_id2:

        # Both inputs are registered only on bp1; bp2 exists but is unused.
        for leaf in (lhs, rhs):
            leaf.require_grad(backprop_id1)

        out = lhs * rhs
        with pytest.raises(chainerx.ChainerxError):
            chainerx.backward(out, backprop_id2)
Beispiel #7
0
def test_is_backprop_required():
    """is_backprop_required respects backprop ids and contexts.

    The check runs under no_backprop_mode with force_backprop_mode(bp1).
    """
    default_ctx = chainerx.get_default_context()
    other_ctx = chainerx.Context()

    with chainerx.backprop_scope('bp1') as bp1, \
            chainerx.backprop_scope('bp2') as bp2:
        with chainerx.no_backprop_mode(), chainerx.force_backprop_mode(bp1):
            # Only bp1 is forced back on in the default context.
            assert not chainerx.is_backprop_required()
            assert chainerx.is_backprop_required(bp1)
            assert not chainerx.is_backprop_required(bp2)
            assert not chainerx.is_backprop_required(context=default_ctx)
            # A different context is not affected by the modes above.
            assert chainerx.is_backprop_required(context=other_ctx)

        # A non-Context value for ``context`` is rejected.
        with pytest.raises(TypeError):
            chainerx.is_backprop_required(context='foo')
Beispiel #8
0
def test_backward_keyword_arguments():
    """backward accepts ``backprop_id`` by keyword but rejects ``body=``."""
    arr = chainerx.full((1,), 2, chainerx.float32)
    with chainerx.backprop_scope('bp1') as bp:
        arr.require_grad(backprop_id=bp)
        chainerx.backward(arr, backprop_id=bp)

        # The array argument cannot be passed by an arbitrary keyword.
        expected_msg = r'.*incompatible function arguments.*'
        with pytest.raises(TypeError, match=expected_msg):
            chainerx.backward(body=arr, backprop_id=bp)
Beispiel #9
0
def test_backprop_multiple_graphs_basic(method):
    """Backprop on one of two graphs only reaches the input on that graph."""
    dtype = chainerx.float32
    shape = (1,)

    def fprop(x0, x1):
        return x0 * x1,

    with chainerx.backprop_scope('bp1') as backprop_id1, \
            chainerx.backprop_scope('bp2') as backprop_id2:
        xs = tuple(
            chainerx.full(shape, value, dtype).require_grad(bp)
            for value, bp in ((2, backprop_id1), (5, backprop_id2)))

        # d(x0 * x1)/dx0 == x1 == 5; x1 is not on bp1, so it gets none.
        expected_gxs = (chainerx.full(shape, 5, dtype), None)

        _check_backprop(
            method, fprop, xs, expected_gxs, backprop_id=backprop_id1)
Beispiel #10
0
def test_backward_multiple_graphs_reuse():
    """Backprop repeatedly on different graphs using the same leaf arrays.

    A ``chainerx.ChainerxError`` entry in the expected gradients appears to
    mark an input that is not on the graph being backpropagated.

    NOTE(review): this test calls ``_check_backprop(xs, expected_gxs, fprop,
    extra_xs, backprop_id=...)``, whereas other tests in this collection
    pass ``method`` first; confirm which signature the helper currently has.
    """
    dtype = chainerx.float32
    shape = (1,)

    x1 = chainerx.full(shape, 2, dtype)
    x2 = chainerx.full(shape, 5, dtype)

    def fprop(xs_, extra_xs_):
        lhs, rhs = xs_
        return lhs * rhs,

    with chainerx.backprop_scope('bp2') as backprop_id2, \
            chainerx.backprop_scope('bp1') as backprop_id1:

        x1.require_grad(backprop_id1)
        x2.require_grad(backprop_id2)
        xs = (x1, x2)

        def check(expected_gxs, backprop_id):
            _check_backprop(
                xs, expected_gxs, fprop, (), backprop_id=backprop_id)

        def clear_grads():
            x1.cleargrad(backprop_id1)
            x2.cleargrad(backprop_id2)

        # Only x1 is on bp1: d(x1 * x2)/dx1 == x2 == 5.
        check((chainerx.full(shape, 5, dtype), chainerx.ChainerxError),
              backprop_id1)

        clear_grads()
        assert x1.get_grad(backprop_id1) is None
        assert x2.get_grad(backprop_id2) is None

        # Only x2 is on bp2: d(x1 * x2)/dx2 == x1 == 2.
        check((chainerx.ChainerxError, chainerx.full(shape, 2, dtype)),
              backprop_id2)

        clear_grads()

        # Register each input on the other graph too, so both are on bp2.
        x1.require_grad(backprop_id2)
        x2.require_grad(backprop_id1)

        check((chainerx.full(shape, 5, dtype),
               chainerx.full(shape, 2, dtype)),
              backprop_id2)

        # Backprop on bp2 leaves the bp1 gradients empty.
        assert x1.get_grad(backprop_id1) is None
        assert x2.get_grad(backprop_id1) is None
Beispiel #11
0
def test_force_backprop_mode():
    """force_backprop_mode re-enables backprop inside no_backprop_mode.

    Each argument form (none, a context, a single id, a tuple of ids) is
    checked against the expected (default, bp1, bp2) states. After the
    block exits, the no-backprop state must be restored.
    """
    with chainerx.backprop_scope('bp1') as bp1, \
            chainerx.backprop_scope('bp2') as bp2:

        def expect(default, on_bp1, on_bp2):
            # bool(...) is True/False matches a plain ``assert x`` /
            # ``assert not x`` on the query result.
            assert bool(chainerx.is_backprop_required()) is default
            assert bool(chainerx.is_backprop_required(bp1)) is on_bp1
            assert bool(chainerx.is_backprop_required(bp2)) is on_bp2

        with chainerx.no_backprop_mode():
            expect(False, False, False)

            cases = (
                ((), (True, True, True)),
                ((chainerx.get_default_context(),), (True, True, True)),
                ((bp1,), (False, True, False)),
                (((bp1, bp2),), (False, True, True)),
            )
            for args, expected in cases:
                with chainerx.force_backprop_mode(*args):
                    expect(*expected)
                expect(False, False, False)

        with chainerx.force_backprop_mode():
            expect(True, True, True)
Beispiel #12
0
def test_backprop_multiple_graphs_non_existing(method):
    """Backprop on a graph the output is not connected to must fail.

    Both inputs require grad only on ``bp1``. Backpropagating ``y`` on
    ``bp2`` with ``method`` ('backward' or 'grad') must raise
    ``chainerx.ChainerxError``.
    """
    shape = (1,)
    dtype = chainerx.float32

    with chainerx.backprop_scope('bp1') as backprop_id1, \
            chainerx.backprop_scope('bp2') as backprop_id2:
        xs = (
            chainerx.full(shape, 2, dtype).require_grad(backprop_id1),
            chainerx.full(shape, 5, dtype).require_grad(backprop_id1),)

        y = xs[0] * xs[1]

        # Keep each pytest.raises around only the call expected to raise.
        # An unknown ``method`` is reported explicitly with pytest.fail
        # rather than by an ``assert False`` inside the expected-exception
        # block, which ``python -O`` would strip.
        if method == 'backward':
            with pytest.raises(chainerx.ChainerxError):
                chainerx.backward(y, backprop_id2)
        elif method == 'grad':
            with pytest.raises(chainerx.ChainerxError):
                chainerx.grad([y], xs, backprop_id2)
        else:
            pytest.fail('unknown method: {!r}'.format(method))
Beispiel #13
0
def test_backward_multiple_graphs_basic():
    """Backprop on bp1 only yields a gradient for the input on bp1."""
    dtype = chainerx.float32
    shape = (1,)

    x1 = chainerx.full(shape, 2, dtype)
    x2 = chainerx.full(shape, 5, dtype)

    def fprop(xs_, extra_xs_):
        lhs, rhs = xs_
        return lhs * rhs,

    with chainerx.backprop_scope('bp1') as backprop_id1, \
            chainerx.backprop_scope('bp2') as backprop_id2:

        x1.require_grad(backprop_id1)
        x2.require_grad(backprop_id2)

        # d(x1 * x2)/dx1 == x2 == 5. x2 is not on bp1, so ChainerxError is
        # given as its expected result.
        expected_gxs = (chainerx.full(shape, 5, dtype), chainerx.ChainerxError)
        _check_backprop(
            (x1, x2), expected_gxs, fprop, (), backprop_id=backprop_id1)
Beispiel #14
0
def test_array_grad_with_backprop_id():
    """get_grad / set_grad / cleargrad with an explicit backprop id.

    Before require_grad, all three accessors raise ChainerxError. After
    require_grad, the stored gradient is returned until cleargrad resets it
    to None. Exercised once with positional and once with keyword ids.
    """
    arr = chainerx.array([1., 1., 1.], chainerx.float32)
    grad = chainerx.array([0.5, 0.5, 0.5], chainerx.float32)

    with chainerx.backprop_scope('bp1') as bp1:
        # Not yet registered on bp1: every accessor must fail.
        for access in (
                lambda: arr.get_grad(bp1),
                lambda: arr.set_grad(grad, bp1),
                lambda: arr.cleargrad(bp1)):
            with pytest.raises(chainerx.ChainerxError):
                access()

        arr.require_grad(bp1).set_grad(grad, bp1)
        assert arr.get_grad(bp1) is not None
        assert arr.get_grad(bp1)._debug_flat_data == grad._debug_flat_data

        arr.cleargrad(bp1)
        assert arr.get_grad(bp1) is None

    # keyword arguments
    with chainerx.backprop_scope('bp2') as bp2:
        for access in (
                lambda: arr.get_grad(backprop_id=bp2),
                lambda: arr.set_grad(grad, backprop_id=bp2),
                lambda: arr.cleargrad(backprop_id=bp2)):
            with pytest.raises(chainerx.ChainerxError):
                access()

        arr.require_grad(backprop_id=bp2).set_grad(grad, backprop_id=bp2)
        assert arr.get_grad(bp2) is not None
        assert arr.get_grad(backprop_id=bp2) is not None
        assert arr.get_grad(bp2)._debug_flat_data == grad._debug_flat_data
        assert (arr.get_grad(backprop_id=bp2)._debug_flat_data
                == grad._debug_flat_data)

        arr.cleargrad(backprop_id=bp2)
        assert arr.get_grad(bp2) is None
        assert arr.get_grad(backprop_id=bp2) is None
Beispiel #15
0
def test_multiple_graphs_double_backprop():
    """Differentiate a first-order gradient with respect to another graph.

    z = x * (x + y) is backpropagated on bp_x, giving gx = 2x + y. Then
    w = x * gx is backpropagated on bp_y, so y's gradient must equal x.
    """
    dtype = chainerx.float32

    with chainerx.backprop_scope('bp_y') as bp_y, \
            chainerx.backprop_scope('bp_x') as bp_x:

        x = chainerx.full((1,), 2, dtype)
        x.require_grad(backprop_id=bp_x)
        y = chainerx.full((1,), 3, dtype)
        y.require_grad(backprop_id=bp_y)

        z = x * (x + y)
        chainerx.backward(z, backprop_id=bp_x)

        # The gradient w.r.t. x must be off bp_x but still on bp_y.
        gx = x.get_grad(bp_x)
        assert not gx.is_backprop_required(backprop_id=bp_x)
        assert gx.is_backprop_required(backprop_id=bp_y)

        w = x * gx
        chainerx.backward(w, backprop_id=bp_y)

        expected_gy = chainerx.full((1,), 2, dtype)
        _assert_arrays_equal(y.get_grad(bp_y), expected_gy)
Beispiel #16
0
def test_array_backward():
    """ndarray.backward with double backprop on a single graph."""
    with chainerx.backprop_scope('bp1') as bp1:
        def make_leaf():
            return chainerx.array(
                [1, 1, 1], chainerx.float32).require_grad(backprop_id=bp1)

        x1 = make_leaf()
        x2 = make_leaf()
        y = x1 * x2

        # Enable double backprop so the gradient can itself be
        # backpropagated below.
        y.backward(backprop_id=bp1, enable_double_backprop=True)
        gx1 = x1.get_grad(backprop_id=bp1)
        x1.set_grad(None, backprop_id=bp1)

        gx1.backward(backprop_id=bp1)
        # gx1 was never passed to require_grad; reading its gradient fails.
        with pytest.raises(chainerx.ChainerxError):
            gx1.get_grad(backprop_id=bp1)
Beispiel #17
0
def test_array_backward():
    """ndarray.backward with double backprop on a single graph."""
    dtype = chainerx.float32
    with chainerx.backprop_scope('bp1') as bp1:
        x1, x2 = [
            chainerx.array([1, 1, 1], dtype).require_grad(backprop_id=bp1)
            for _ in range(2)]
        y = x1 * x2

        # Keep the graph of the gradient so it can be backpropagated.
        y.backward(backprop_id=bp1, enable_double_backprop=True)
        gx1 = x1.get_grad(backprop_id=bp1)
        x1.set_grad(None, backprop_id=bp1)

        gx1.backward(backprop_id=bp1)

        # gx1 never had require_grad called, so get_grad must raise.
        with pytest.raises(chainerx.ChainerxError):
            gx1.get_grad(backprop_id=bp1)
Beispiel #18
0
def test_array_repr_expired_backprop_id():
    """repr shows '<expired>' for a backprop id whose scope has exited."""
    with chainerx.backprop_scope('bp1') as bp1:
        arr = chainerx.array([3.0], chainerx.float32)
        arr.require_grad(bp1)

    # bp1's scope has exited at this point.
    expected = ("array([3.], shape=(1,), dtype=float32, device='native:0', "
                "backprop_ids=['<expired>'])")
    assert expected == str(arr)
Beispiel #19
0
def test_as_grad_stopped_copy(shape, float_dtype):
    """Check ``as_grad_stopped(copy=True)``.

    The result must equal the source (ignoring strides), be contiguous, use
    different memory when non-empty, and not require grad or backprop on
    the stopped graphs. The source keeps its own requirements.
    """
    dtype = float_dtype

    def assert_is_copy(src, dst):
        chainerx.testing.assert_array_equal_ex(src, dst, strides_check=False)
        assert dst.is_contiguous
        # Check memory addresses only if >0 bytes are allocated
        if src.size == 0:
            return
        assert (src._debug_data_memory_address
                != dst._debug_data_memory_address)

    def new_source(*bps):
        # Fresh dummy array registered on every id in ``bps``.
        arr = array_utils.create_dummy_ndarray(chainerx, shape, dtype)
        for bp in bps:
            arr.require_grad(bp)
        for bp in bps:
            assert arr.is_grad_required(bp)
        for bp in bps:
            assert arr.is_backprop_required(bp)
        return arr

    with chainerx.backprop_scope('bp1') as bp1, \
            chainerx.backprop_scope('bp2') as bp2, \
            chainerx.backprop_scope('bp3') as bp3:

        # Stop gradients on all graphs
        a = new_source(bp1, bp2)
        b = a.as_grad_stopped(copy=True)

        assert_is_copy(a, b)
        for bp in (bp1, bp2):
            assert not b.is_grad_required(bp)
            assert not b.is_backprop_required(bp)
        for bp in (bp1, bp2):
            assert a.is_grad_required(bp)
            assert a.is_backprop_required(bp)

        # Stop gradients on some graphs
        a = new_source(bp1, bp2, bp3)
        b = a.as_grad_stopped([bp1, bp2], copy=True)

        assert_is_copy(a, b)
        for bp in (bp1, bp2):
            assert not b.is_grad_required(bp)
            assert not b.is_backprop_required(bp)
        # bp3 was not stopped: b does not require grad there but is still
        # connected to it.
        assert not b.is_grad_required(bp3)
        assert b.is_backprop_required(bp3)
        for bp in (bp1, bp2, bp3):
            assert a.is_grad_required(bp)
            assert a.is_backprop_required(bp)
Beispiel #20
0
def test_backprop_multiple_graphs_reuse(method0, method1, method2):
    """Reuse the same leaf arrays for backprop on different graphs.

    Stage 1 backpropagates on bp1 and stage 2 on bp2. Stage 3 registers
    each input on the other graph and backpropagates on bp2, after which
    the bp1 gradients must still be empty.
    """
    dtype = chainerx.float32
    shape = (1, )

    def fprop(x0, x1):
        return x0 * x1,

    def full(value):
        return chainerx.full(shape, value, dtype)

    with chainerx.backprop_scope('bp2') as backprop_id2, \
            chainerx.backprop_scope('bp1') as backprop_id1:
        x1 = full(2).require_grad(backprop_id1)
        x2 = full(5).require_grad(backprop_id2)
        xs = (x1, x2)

        # Stage 1: only x1 is on bp1, gradient x2 == 5.
        _check_backprop(method0, fprop, xs, (full(5), None),
                        backprop_id=backprop_id1)

        x1.cleargrad(backprop_id1)
        x2.cleargrad(backprop_id2)
        assert x1.get_grad(backprop_id1) is None
        assert x2.get_grad(backprop_id2) is None

        # Stage 2: only x2 is on bp2, gradient x1 == 2.
        _check_backprop(method1, fprop, xs, (None, full(2)),
                        backprop_id=backprop_id2)

        x1.cleargrad(backprop_id1)
        x2.cleargrad(backprop_id2)

        # Stage 3: both inputs are now on bp2 (and on bp1).
        x1.require_grad(backprop_id2)
        x2.require_grad(backprop_id1)
        _check_backprop(method2, fprop, xs, (full(5), full(2)),
                        backprop_id=backprop_id2)

        # Backprop on bp2 must not populate gradients on bp1.
        assert x1.get_grad(backprop_id1) is None
        assert x2.get_grad(backprop_id1) is None
Beispiel #21
0
def test_as_grad_stopped_copy(shape, float_dtype):
    """Check ``as_grad_stopped(copy=True)``.

    The result must equal the source (ignoring strides), be contiguous, use
    different memory when non-empty, and not require grad or backprop on
    the stopped graphs. The source keeps its own requirements.
    """
    dtype = float_dtype

    def check(src, dst):
        chainerx.testing.assert_array_equal_ex(src, dst, strides_check=False)
        assert dst.is_contiguous
        # Check memory addresses only if >0 bytes are allocated
        if src.size > 0:
            src_addr = src._debug_data_memory_address
            dst_addr = dst._debug_data_memory_address
            assert src_addr != dst_addr

    with chainerx.backprop_scope('bp1') as bp1, \
            chainerx.backprop_scope('bp2') as bp2, \
            chainerx.backprop_scope('bp3') as bp3:

        # Stop gradients on all graphs
        ids = (bp1, bp2)
        a = array_utils.create_dummy_ndarray(chainerx, shape, dtype)
        for bp in ids:
            a.require_grad(bp)
        assert all(a.is_grad_required(bp) for bp in ids)
        assert all(a.is_backprop_required(bp) for bp in ids)
        b = a.as_grad_stopped(copy=True)

        check(a, b)
        assert not any(b.is_grad_required(bp) for bp in ids)
        assert not any(b.is_backprop_required(bp) for bp in ids)
        assert all(a.is_grad_required(bp) for bp in ids)
        assert all(a.is_backprop_required(bp) for bp in ids)

        # Stop gradients on some graphs
        ids = (bp1, bp2, bp3)
        stopped = (bp1, bp2)
        a = array_utils.create_dummy_ndarray(chainerx, shape, dtype)
        for bp in ids:
            a.require_grad(bp)
        assert all(a.is_grad_required(bp) for bp in ids)
        assert all(a.is_backprop_required(bp) for bp in ids)
        b = a.as_grad_stopped([bp1, bp2], copy=True)

        check(a, b)
        assert not any(b.is_grad_required(bp) for bp in ids)
        assert not any(b.is_backprop_required(bp) for bp in stopped)
        # bp3 was not stopped, so b is still connected to it.
        assert b.is_backprop_required(bp3)
        assert all(a.is_grad_required(bp) for bp in ids)
        assert all(a.is_backprop_required(bp) for bp in ids)
Beispiel #22
0
def test_array_repr_expired_backprop_id():
    """An id whose backprop scope has exited is printed as '<expired>'."""
    with chainerx.backprop_scope('bp1') as bp1:
        arr = chainerx.array([3.0], chainerx.float32)
        arr.require_grad(bp1)

    text = str(arr)
    assert text == (
        "array([3.], shape=(1,), dtype=float32, device='native:0', "
        "backprop_ids=['<expired>'])")