Exemplo n.º 1
0
    def test_clear_output_grad_argument(self, grad):
        """Check grad-clear flags when ``backward`` receives an explicit ``grad``.

        Builds ``x1 -> identity -> xx1 -> add_scalar -> y1`` and runs
        ``y1.backward(clear_buffer=True, grad=grad)``.

        Expected recorded grad-clear flags, in order ``[y1, xx1]``:

        * ``y1``: not cleared when ``grad`` is ``None`` or an ``nn.NdArray``,
          cleared for any other ``grad`` value.
        * ``xx1``: always cleared.

        Additionally ``y1.grad`` itself must report ``clear_called`` as false.

        Args:
            grad: Value forwarded as the ``grad`` argument of ``y1.backward``
                (presumably supplied by a pytest parametrization outside this
                view — TODO confirm).
        """
        x1 = nn.Variable([1], need_grad=True)

        xx1 = F.identity(x1)
        y1 = F.add_scalar(xx1)

        # y1's grad is expected to be cleared unless the caller handed in
        # None or an nn.NdArray as ``grad``.
        y1_grad_cleared = not (grad is None or isinstance(grad, nn.NdArray))
        answer_grad = [
            [y1_grad_cleared],  # y1
            [True],  # xx1
        ]

        y1.forward(clear_no_need_grad=True)
        # Cycle the recorder between forward and backward; presumably this
        # discards flags recorded during forward so only backward's flags are
        # checked below — NOTE(review): confirm recorder reset semantics.
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y1.backward(clear_buffer=True, grad=grad)

        self.check_grad_cleared_flags(answer_grad)
        # The output's own grad buffer must not have been cleared.
        assert not y1.grad.clear_called
Exemplo n.º 2
0
    def test_clear_no_need_grad_during_recomputation(self):
        """Check input-data clear flags for a graph using recomputation.

        Graph: ``x0 -> identity -> x1 -> sin -> x2 -> identity -> x3 -> sin -> x4``,
        where x1, x2 and x3 are built with ``apply(recompute=True)``.

        * Forward with ``clear_no_need_grad=True``: every intermediate (x1, x2,
          x3) is expected to have its data cleared; x0 is not.
        * Backward with ``clear_buffer=True``: during recomputation x2 is
          expected to be cleared because it is not needed for gradient
          calculation, and the flags of the subsequent backward propagation
          are checked as well.
        """
        x0 = nn.Variable((2, 3), need_grad=True)

        x1 = F.identity(x0).apply(recompute=True)
        # x2.data must be cleared just after recomputation because it is not
        # needed for backward propagation.
        x2 = F.sin(x1).apply(recompute=True)
        x3 = F.identity(x2).apply(recompute=True)
        x4 = F.sin(x3)

        def run_and_check(run, expected):
            # Record clear-called flags only while ``run`` executes, then
            # compare them with ``expected``. ``finally`` guarantees the
            # recorder is deactivated even when the check raises, instead of
            # being left active after a failed assertion.
            clear_called_flag_recorder.activate_clear_called_flag_recorder()
            try:
                run()
                self.check_input_data_clear_called_flags(expected)
            finally:
                clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

        # Forward: all intermediate data must be cleared.
        run_and_check(
            lambda: x4.forward(clear_no_need_grad=True),
            [
                [False],  # x0
                [True],  # x1
                [True],  # x2
                [True],  # x3
            ],
        )

        # Backward
        run_and_check(
            lambda: x4.backward(clear_buffer=True),
            [
                # Recomputation
                [False],  # x0
                [False],  # x1
                [True],  # x2: not needed for grad calculation
                # Backward propagation
                [False],  # x3
                [True],  # x2
                [False],  # x1
                [False],  # x0
            ],
        )
Exemplo n.º 3
0
 def setup_method(self):
     """Turn on the clear-called flag recorder for the upcoming test.

     ``setup_method`` is pytest's xunit-style hook run before each test method
     of the class.
     """
     recorder = clear_called_flag_recorder
     recorder.activate_clear_called_flag_recorder()