Пример #1
0
    def test_sequential_with_statement(self, f1, f2):
        """
        Test for sequential use of with statement.

        Two ``nn.recompute`` blocks are entered one after the other. Each
        block must tag the Variables created inside it with its own flag,
        a Variable created inside a block keeps that flag after the block
        exits, and Variables created outside any block get
        ``recompute == False``.
        """
        x = nn.Variable((2, 3))
        assert x.recompute is False

        # First `with` block
        with nn.recompute(f1):
            y = F.relu(x)
            assert y.recompute == f1
            y = F.sin(y)
            assert y.recompute == f1

        # The flag set inside the block persists on the Variable after exit.
        assert y.recompute == f1

        # Outside any block, newly created outputs get the default flag.
        y = F.relu(y)
        assert y.recompute is False

        # Second `with` block
        with nn.recompute(f2):
            y = F.relu(x)
            assert y.recompute == f2
            y = F.sin(y)
            assert y.recompute == f2

        assert y.recompute == f2

        y = F.relu(y)
        assert y.recompute is False

        # Using `x` as an input inside both blocks must not have changed its
        # own flag.
        assert x.recompute is False
Пример #2
0
    def test_with_statement_variable_creation(self, recompute_flag):
        """
        Test for setting recompute flags with Python `with` statement.

        Every way of creating a new Variable inside
        ``nn.recompute(recompute_flag)`` must tag it with ``recompute_flag``.
        A name that merely references an existing Variable must see that
        Variable's original flag. Variables created outside the block get
        ``recompute == False``. The test also checks that ``nn.recompute()``
        called without arguments enables recomputation.
        """

        # Create a new Variable
        x1 = nn.Variable((2, 3))
        assert x1.recompute is False

        with nn.recompute(recompute_flag):
            # Create Variable by `__cinit__()`
            y1 = nn.Variable((2, 3))
            assert y1.recompute == recompute_flag

            # Create Variable by `create_from_cvariable()`
            y2 = x1.reshape((3, 2), unlink=True)
            assert y2.recompute == recompute_flag

            # Create Variable by `create_from_cg_variable()`
            y3 = F.relu(x1)
            assert y3.recompute == recompute_flag

            # Create Variable by `from_numpy_array()`.
            # np.zeros gives a (2, 3) array like the other Variables here;
            # np.array((2, 3)) would build the 1-D array [2, 3] instead.
            data = np.zeros((2, 3))
            y4 = nn.Variable.from_numpy_array(data)
            assert y4.recompute == recompute_flag

            # Create Variable by `get_unlinked_variable()`
            y5 = x1.get_unlinked_variable()
            assert y5.recompute == recompute_flag

            # The recompute flag of a referenced Variable must not be
            # overwritten, even though x1 was used as an input above.
            # More detailed tests are performed by `test_nested_with_statement`.
            y6 = x1
            assert y6.recompute is False

            # Direct function connection: the output of the inner F.relu is
            # never bound to a name.
            y7 = F.relu(F.relu(x1))

        # Create a new Variable after with statement
        x2 = nn.Variable((2, 3))
        assert x2.recompute is False

        # Check the recompute flag of the unnamed intermediate Variable (the
        # inner F.relu output). It is retrieved here through the graph, after
        # the `with` block has exited, and must still carry the block's flag.
        assert y7.parent.inputs[0].recompute == recompute_flag

        # Check default recompute flag for nn.recompute()
        with nn.recompute():
            x = nn.Variable((2, 3))
            assert x.recompute is True
Пример #3
0
    def test_nested_with_statement(self, f1, f2, f3):
        """
        Test for nested Python `with` statement of recomputation.

        Variables created at each nesting level must take that level's flag.
        Names that merely reference Variables created at outer levels must
        keep the outer flag. After an inner block exits, newly created
        Variables must get the enclosing block's flag again (``False`` at
        the top level).
        """

        x0 = nn.Variable((2, 3))
        assert x0.recompute is False

        # Nest 1
        with nn.recompute(f1):
            x1 = nn.Variable((2, 3))
            x0_1 = x0
            assert x1.recompute == f1
            assert x0_1.recompute is False

            # Nest 2
            with nn.recompute(f2):
                x2 = nn.Variable((2, 3))
                x0_2 = x0
                x1_2 = x1
                assert x2.recompute == f2
                assert x0_2.recompute is False
                assert x1_2.recompute == f1

                # Nest 3
                with nn.recompute(f3):
                    x3 = nn.Variable((2, 3))
                    x0_3 = x0
                    x1_3 = x1
                    x2_3 = x2
                    assert x3.recompute == f3
                    assert x0_3.recompute is False
                    assert x1_3.recompute == f1
                    assert x2_3.recompute == f2

                # Back in Nest 2: new Variables get f2 again.
                x2 = nn.Variable((2, 3))
                x0_2 = x0
                x1_2 = x1
                assert x2.recompute == f2
                assert x0_2.recompute is False
                assert x1_2.recompute == f1

            # Back in Nest 1: new Variables get f1 again.
            x1 = nn.Variable((2, 3))
            x0_1 = x0
            assert x1.recompute == f1
            assert x0_1.recompute is False

        # Outside every block: back to the default flag.
        x0 = nn.Variable((2, 3))
        assert x0.recompute is False