def test_clear_output_grad_argument(self, grad):
    """Check which gradient buffers are cleared by ``backward(grad=...)``.

    Builds ``x1 -> identity -> xx1 -> add_scalar -> y1``, runs forward with
    ``clear_no_need_grad=True``, then runs backward with
    ``clear_buffer=True`` and the given ``grad``.

    Args:
        grad: value passed as ``grad=`` to ``y1.backward``. y1's gradient
            clear flag is expected to be ``False`` when ``grad`` is ``None``
            or an ``nn.NdArray``, and ``True`` for anything else.
            NOTE(review): presumably supplied by a pytest parametrization
            outside this view; confirm.
    """
    x1 = nn.Variable([1], need_grad=True)

    xx1 = F.identity(x1)
    y1 = F.add_scalar(xx1)

    # Expected gradient clear flags, one entry per variable: y1, then xx1.
    y1_grad_cleared = not (grad is None or isinstance(grad, nn.NdArray))
    answer_grad = [
        [y1_grad_cleared],  # y1
        [True],  # xx1
    ]

    y1.forward(clear_no_need_grad=True)
    # Restart the recorder between forward and backward, so that the check
    # below sees only what backward records.
    # NOTE(review): assumes deactivate/activate resets the recorded flags;
    # confirm.
    clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
    clear_called_flag_recorder.activate_clear_called_flag_recorder()
    y1.backward(clear_buffer=True, grad=grad)

    self.check_grad_cleared_flags(answer_grad)
    # The output's own gradient array must not have been cleared.
    # Use truthiness rather than `== False` (PEP 8, E712).
    assert not y1.grad.clear_called
def test_clear_no_need_grad_during_recomputation(self):
    """Check clear-called flags for a graph with recomputed intermediates.

    Graph: ``src -> identity -> h1 -> sin -> h2 -> identity -> h3 -> sin -> out``.
    ``h1``, ``h2`` and ``h3`` are produced with ``recompute=True``.
    ``check_input_data_clear_called_flags`` checks the flags twice:

    - after ``out.forward(clear_no_need_grad=True)``;
    - after ``out.backward(clear_buffer=True)``, which is expected to record a
      recomputation pass followed by the backward-propagation pass.
    """
    recorder = clear_called_flag_recorder

    src = nn.Variable((2, 3), need_grad=True)
    h1 = F.identity(src).apply(recompute=True)
    # h2.data must be cleared right after recomputation, because it is not
    # needed for backward propagation.
    h2 = F.sin(h1).apply(recompute=True)
    h3 = F.identity(h2).apply(recompute=True)
    out = F.sin(h3)

    # Forward: every intermediate buffer is expected to be cleared.
    forward_expected = [
        [False],  # src
        [True],  # h1
        [True],  # h2
        [True],  # h3
    ]
    recorder.activate_clear_called_flag_recorder()
    out.forward(clear_no_need_grad=True)
    self.check_input_data_clear_called_flags(forward_expected)
    recorder.deactivate_clear_called_flag_recorder()

    # Backward: recomputation first, then gradient propagation.
    backward_expected = [
        # Recomputation
        [False],  # src
        [False],  # h1
        [True],  # h2: not needed for the gradient calculation
        # Backward propagation
        [False],  # h3
        [True],  # h2
        [False],  # h1
        [False],  # src
    ]
    recorder.activate_clear_called_flag_recorder()
    out.backward(clear_buffer=True)
    self.check_input_data_clear_called_flags(backward_expected)
    recorder.deactivate_clear_called_flag_recorder()
def setup_method(self):
    """Per-test pytest setup hook: switch on clear-called flag recording."""
    recorder = clear_called_flag_recorder
    recorder.activate_clear_called_flag_recorder()