def test_sequential_with_statement(self, f1, f2):
    """
    Test for sequential use of with statement.

    Two consecutive (non-nested) ``nn.recompute`` blocks each apply their own
    flag to Variables created inside them. Those Variables keep the flag after
    the block exits, and Variables created afterwards get ``False`` again.
    """
    src = nn.Variable((2, 3))
    assert src.recompute == False

    # Both blocks start from the same source Variable `src`.
    for flag in (f1, f2):
        with nn.recompute(flag):
            h = F.relu(src)
            assert h.recompute == flag
            h = F.sin(h)
            assert h.recompute == flag
        # Exiting the block must not alter Variables created inside it...
        assert h.recompute == flag
        # ...while Variables created after it fall back to False.
        h = F.relu(h)
        assert h.recompute == False
def test_with_statement_variable_creation(self, recompute_flag):
    """
    Test for setting recompute flags with Python `with` statement.

    Every Variable-creation path exercised below (``__cinit__``,
    ``create_from_cvariable``, ``create_from_cg_variable``,
    ``from_numpy_array``, ``get_unlinked_variable``) must pick up
    ``recompute_flag`` inside ``nn.recompute(recompute_flag)``. Variables
    that merely reference an existing Variable, or that are created after the
    block exits, must keep ``False``. ``nn.recompute()`` with no argument is
    expected to set the flag to ``True``.
    """
    # Create a new Variable
    x1 = nn.Variable((2, 3))
    assert x1.recompute == False

    with nn.recompute(recompute_flag):
        # Create Variable by `__cinit__()`
        y1 = nn.Variable((2, 3))
        assert y1.recompute == recompute_flag

        # Create Variable by `create_from_cvariable()`
        y2 = x1.reshape((3, 2), unlink=True)
        assert y2.recompute == recompute_flag

        # Create Variable by `create_from_cg_variable()`
        y3 = F.relu(x1)
        assert y3.recompute == recompute_flag

        # Create Variable by `from_numpy_array()`.
        # Use a (2, 3)-shaped array like the other Variables in this test;
        # `np.array((2, 3))` would build a 1-D array holding the values 2, 3.
        data = np.zeros((2, 3), dtype=np.float32)
        y4 = nn.Variable.from_numpy_array(data)
        assert y4.recompute == recompute_flag

        # Create Variable by `get_unlinked_variable()`
        y5 = x1.get_unlinked_variable()
        assert y5.recompute == recompute_flag

        # Recompute flag for a referenced Variable must not be overwritten.
        # More detailed tests are performed by `test_nested_with_statement`.
        y6 = x1
        assert y6.recompute == False

        # Direct function connection: the intermediate relu output gets no
        # Python handle here; it is fetched via `y7.parent.inputs[0]` below.
        y7 = F.relu(F.relu(x1))

    # Create a new Variable after with statement
    x2 = nn.Variable((2, 3))
    assert x2.recompute == False

    # Check recompute flag of forcibly got Python Variable.
    assert y7.parent.inputs[0].recompute == recompute_flag

    # Check default recompute flag for nn.recompute()
    with nn.recompute():
        x = nn.Variable((2, 3))
        assert x.recompute == True
def test_nested_with_statement(self, f1, f2, f3):
    """
    Test for nested Python `with` statements of recomputation.

    A Variable created at a given nesting depth takes the flag of the
    innermost active ``nn.recompute`` block. References to Variables created
    at outer depths keep their original flag. Leaving a block makes newly
    created Variables take the enclosing block's flag again, or ``False``
    outside all blocks.
    """
    shape = (2, 3)

    outer = nn.Variable(shape)
    assert outer.recompute == False

    with nn.recompute(f1):  # depth 1
        lvl1 = nn.Variable(shape)
        ref_outer = outer
        assert lvl1.recompute == f1
        assert ref_outer.recompute == False

        with nn.recompute(f2):  # depth 2
            lvl2 = nn.Variable(shape)
            ref_outer, ref_lvl1 = outer, lvl1
            assert lvl2.recompute == f2
            assert ref_outer.recompute == False
            assert ref_lvl1.recompute == f1

            with nn.recompute(f3):  # depth 3
                lvl3 = nn.Variable(shape)
                ref_outer, ref_lvl1, ref_lvl2 = outer, lvl1, lvl2
                assert lvl3.recompute == f3
                assert ref_outer.recompute == False
                assert ref_lvl1.recompute == f1
                assert ref_lvl2.recompute == f2

            # Back at depth 2: new Variables take f2 again.
            lvl2 = nn.Variable(shape)
            ref_outer, ref_lvl1 = outer, lvl1
            assert lvl2.recompute == f2
            assert ref_outer.recompute == False
            assert ref_lvl1.recompute == f1

        # Back at depth 1: new Variables take f1 again.
        lvl1 = nn.Variable(shape)
        ref_outer = outer
        assert lvl1.recompute == f1
        assert ref_outer.recompute == False

    # Outside every block: new Variables are created with False.
    outer = nn.Variable(shape)
    assert outer.recompute == False