def test_with_persistent_flag(self, seed):
    """Check data clearing when `recompute` and `persistent` are combined.

    A chain of four `F.sin` calls is built, the clear-called flags of the
    function inputs are recorded during forward propagation, and the graph
    builder is then handed to `self.check_recomputation` for the grad-value
    check.
    """
    root = nn.Variable((2, 3), need_grad=True)
    inputs = (root,)

    def build_graph(x0):
        x1 = F.sin(x0).apply(recompute=True)
        # `recompute` and `persistent` are both set on this variable.
        x2 = F.sin(x1).apply(recompute=True, persistent=True)
        x3 = F.sin(x2).apply(recompute=True)
        return F.sin(x3)

    out = build_graph(root)

    # Trace data clearing during forward propagation.
    recorder = clear_called_flag_recorder
    recorder.activate_clear_called_flag_recorder()
    out.forward(clear_no_need_grad=True)

    cleared = (
        False,  # x0: graph input
        True,   # x1: cleared because `recompute=True`
        False,  # x2: not cleared because `persistent=True`
        True,   # x3: cleared because `recompute=True`
    )
    self.check_input_data_clear_called_flags([[flag] for flag in cleared])

    recorder.deactivate_clear_called_flag_recorder()

    # Check grad value.
    self.check_recomputation(seed, build_graph, inputs)
def test_with_inplacing(self, seed):
    """Check data clearing when `recompute` is set on an in-placed variable.

    The graph is `sin -> reshape(inplace=True) -> sin -> sin`. Clear-called
    flags of the function inputs are recorded during forward propagation,
    and the graph builder is then handed to `self.check_recomputation` for
    the grad-value check.
    """
    root = nn.Variable((2, 3), need_grad=True)
    inputs = (root,)

    def build_graph(x0):
        x1 = F.sin(x0).apply(recompute=True)
        # The `recompute` flag is set on the in-placed variable.
        x2 = F.reshape(x1, (3, 2), inplace=True).apply(recompute=True)
        x3 = F.sin(x2).apply(recompute=True)
        return F.sin(x3)

    out = build_graph(root)

    # Trace data clearing during forward propagation.
    recorder = clear_called_flag_recorder
    recorder.activate_clear_called_flag_recorder()
    out.forward(clear_no_need_grad=True)

    cleared = (
        False,  # x0: graph input
        False,  # x1: not cleared because inplaced data
        False,  # x2: not cleared because inplaced data
        True,   # x3: cleared because `recompute=True`
    )
    self.check_input_data_clear_called_flags([[flag] for flag in cleared])

    recorder.deactivate_clear_called_flag_recorder()

    # Check grad value.
    self.check_recomputation(seed, build_graph, inputs)
def test_clear_data_on_not_bwd_path(self):
    """Check clearing of recomputed data that is off the backward path.

    Two branches are joined by `F.add2`: the `a*` branch starts from a
    variable with `need_grad=True`, the `b*` branch from one with
    `need_grad=False`. Clear-called flags of the function inputs are
    checked once after forward and once after backward.
    """
    a0 = nn.Variable((2, 3), need_grad=True)
    a1 = F.identity(a0).apply(recompute=True)
    a2 = F.sin(a1).apply(recompute=True)

    # The `b*` variables are not back-propagated.
    b0 = nn.Variable((2, 3), need_grad=False)
    b1 = F.identity(b0).apply(recompute=True)
    b2 = F.sin(b1).apply(recompute=True)

    c1 = F.add2(a2, b2).apply(recompute=True)
    c2 = F.sin(c1)

    recorder = clear_called_flag_recorder

    # ---- Forward ----
    recorder.activate_clear_called_flag_recorder()
    c2.forward(clear_no_need_grad=True)
    # Data which will be recomputed must be cleared during forward
    # propagation.
    forward_flags = [
        [False],       # a0
        [True],        # a1
        [False],       # b0
        [True],        # b1
        [True, True],  # a2, b2
        [True],        # c1
    ]
    self.check_input_data_clear_called_flags(forward_flags)
    recorder.deactivate_clear_called_flag_recorder()

    # ---- Backward ----
    recorder.activate_clear_called_flag_recorder()
    c2.backward(clear_buffer=True)
    # b1 is not on the backward path and must be cleared during
    # recomputation.
    recomputation_flags = [
        [False],       # a0
        [False],       # a1
        [False],       # b0
        [True],        # b1 (not on backward path) must be cleared
        [True, True],  # a2, b2
        [False],       # c1
    ]
    backward_flags = [
        [True, True],  # a2, b2
        [False],       # a1
        [False],       # a0
    ]
    self.check_input_data_clear_called_flags(
        recomputation_flags + backward_flags)
    recorder.deactivate_clear_called_flag_recorder()
def test_clear_output_grad_inplace(self):
    """Check grad-cleared flags for a graph with an in-place `add_scalar`.

    The graph is `identity -> add_scalar(inplace=True) -> add_scalar`;
    three recorded entries, each holding a single `True`, are expected
    after backward.
    """
    x1 = nn.Variable([1], need_grad=True)

    xx1 = F.identity(x1)
    y1 = F.add_scalar(xx1, inplace=True)
    y2 = F.add_scalar(y1)

    # Three entries, one `True` flag each.
    expected_flags = [[True] for _ in range(3)]

    y2.forward(clear_no_need_grad=True)

    # Deactivate then re-activate the recorder before backward (presumably
    # so only backward is recorded -- confirm against the recorder
    # implementation).
    recorder = clear_called_flag_recorder
    recorder.deactivate_clear_called_flag_recorder()
    recorder.activate_clear_called_flag_recorder()
    y2.backward(clear_buffer=True)

    self.check_grad_cleared_flags(expected_flags)
def test_clear_input_data(self):
    """Check input-data clearing for a `sin` chain with mixed `recompute`.

    `x1` is created with `recompute=True` and `x2` with `recompute=False`;
    the clear-called flags of the function inputs are checked after the
    forward pass.
    """
    x0 = nn.Variable((1, 1), need_grad=True)
    # `F.sin` input data is always needed for grad calculation.
    x1 = F.sin(x0).apply(recompute=True)
    x2 = F.sin(x1).apply(recompute=False)
    x3 = F.sin(x2)

    expected_flags = [
        [False],  # x0
        [True],   # x1
        [False],  # x2
    ]

    recorder = clear_called_flag_recorder
    recorder.activate_clear_called_flag_recorder()
    x3.forward(clear_no_need_grad=True)
    self.check_input_data_clear_called_flags(expected_flags)
    recorder.deactivate_clear_called_flag_recorder()
def test_clear_output_grad_prohibit_clear_input(self):
    """Check grad-cleared flags when one variable feeds two functions.

    `xx1` is consumed by two separate `add_scalar` calls whose outputs are
    joined with `F.sink`; the recorded flags are checked after backward.
    """
    x1 = nn.Variable([1], need_grad=True)

    xx1 = F.identity(x1)
    y1 = F.add_scalar(xx1)
    y2 = F.add_scalar(xx1)
    y3 = F.sink(y1, y2)

    expected_flags = [
        [True],   # y3
        [False],  # y2
        [False],  # y1
        [True],   # xx1
    ]

    y3.forward(clear_no_need_grad=True)

    # Deactivate then re-activate the recorder before backward (presumably
    # so only backward is recorded -- confirm against the recorder
    # implementation).
    recorder = clear_called_flag_recorder
    recorder.deactivate_clear_called_flag_recorder()
    recorder.activate_clear_called_flag_recorder()
    y3.backward(clear_buffer=True)

    self.check_grad_cleared_flags(expected_flags)
def test_clear_output_grad_argument(self, grad):
    """Check grad-cleared flags for different `grad` arguments to backward.

    Args:
        grad: Value passed straight through as `y1.backward(grad=...)`.
            When it is `None` or an `nn.NdArray`, the first expected flag
            (y1) is `False`; for any other value it is `True`.
    """
    x1 = nn.Variable([1], need_grad=True)

    xx1 = F.identity(x1)
    y1 = F.add_scalar(xx1)

    # The expected y1 flag depends only on the kind of `grad` supplied.
    y1_flag = not (grad is None or isinstance(grad, nn.NdArray))
    answer_grad = [
        [y1_flag],  # y1
        [True],     # xx1
    ]

    y1.forward(clear_no_need_grad=True)

    # Deactivate then re-activate the recorder before backward (presumably
    # so only backward is recorded -- confirm against the recorder
    # implementation).
    clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
    clear_called_flag_recorder.activate_clear_called_flag_recorder()
    y1.backward(clear_buffer=True, grad=grad)

    self.check_grad_cleared_flags(answer_grad)
    # Plain truthiness test rather than `== False` (PEP 8).
    assert not y1.grad.clear_called
def test_clear_output_grad_persistent(self):
    """Check grad-cleared flags when some variables are persistent.

    `xx1` and `y2` get `persistent = True`; the expected flags for those
    two are `False`, while the flag for `y1` is `True`.
    """
    x1 = nn.Variable([1], need_grad=True)

    xx1 = F.identity(x1)
    y1 = F.add_scalar(xx1)
    y2 = F.add_scalar(y1)

    xx1.persistent = True
    y2.persistent = True

    expected_flags = [
        [False],  # y2
        [True],   # y1
        [False],  # xx1
    ]

    y2.forward(clear_no_need_grad=True)

    # Deactivate then re-activate the recorder before backward (presumably
    # so only backward is recorded -- confirm against the recorder
    # implementation).
    recorder = clear_called_flag_recorder
    recorder.deactivate_clear_called_flag_recorder()
    recorder.activate_clear_called_flag_recorder()
    y2.backward(clear_buffer=True)

    self.check_grad_cleared_flags(expected_flags)
def test_clear_no_need_grad_during_recomputation(self):
    """Check that unneeded data is cleared during recomputation.

    The graph is `identity -> sin -> identity -> sin`, with every
    intermediate variable created with `recompute=True`. Clear-called flags
    of the function inputs are checked once after forward and once after
    backward.
    """
    x0 = nn.Variable((2, 3), need_grad=True)

    x1 = F.identity(x0).apply(recompute=True)
    # x2.data must be cleared just after recomputation because it is not
    # needed for backward propagation.
    x2 = F.sin(x1).apply(recompute=True)
    x3 = F.identity(x2).apply(recompute=True)
    x4 = F.sin(x3)

    recorder = clear_called_flag_recorder

    # ---- Forward ----
    recorder.activate_clear_called_flag_recorder()
    x4.forward(clear_no_need_grad=True)
    # All intermediate data must be cleared.
    forward_flags = [
        [False],  # x0
        [True],   # x1
        [True],   # x2
        [True],   # x3
    ]
    self.check_input_data_clear_called_flags(forward_flags)
    recorder.deactivate_clear_called_flag_recorder()

    # ---- Backward ----
    recorder.activate_clear_called_flag_recorder()
    x4.backward(clear_buffer=True)
    recomputation_flags = [
        [False],  # x0
        [False],  # x1
        [True],   # x2: not needed for grad calculation
    ]
    backward_flags = [
        [False],  # x3
        [True],   # x2
        [False],  # x1
        [False],  # x0
    ]
    self.check_input_data_clear_called_flags(
        recomputation_flags + backward_flags)
    recorder.deactivate_clear_called_flag_recorder()
def teardown_method(self):
    """Deactivate the clear-called flag recorder."""
    stop_recording = (
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder)
    stop_recording()