Exemple #1
0
    def test_with_persistent_flag(self, seed):
        """Check data clearing when `recompute` and `persistent` are both set.

        A chain of four `F.sin` calls is built in which the second
        intermediate variable carries `recompute=True` together with
        `persistent=True`. The forward pass is traced and that variable is
        expected not to be cleared, while the other recompute variables are.
        Gradients are then verified through `self.check_recomputation`.

        Args:
            seed: Forwarded unchanged to `self.check_recomputation`.
        """
        def graph(x0):
            h1 = F.sin(x0).apply(recompute=True)
            # This variable gets `recompute` and `persistent` simultaneously.
            h2 = F.sin(h1).apply(recompute=True, persistent=True)
            h3 = F.sin(h2).apply(recompute=True)
            return F.sin(h3)

        x = nn.Variable((2, 3), need_grad=True)
        inputs = (x,)
        y = graph(x)

        # Expected clear-called flags of the input data of each function.
        expected_flags = [
            [False],  # x0: graph input
            [True],  # x1: cleared because `recompute=True`
            [False],  # x2: kept because `persistent=True`
            [True],  # x3: cleared because `recompute=True`
        ]

        # Record which data get cleared while running forward propagation.
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(expected_flags)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

        # Verify the gradient values.
        self.check_recomputation(seed, graph, inputs)
Exemple #2
0
    def test_with_inplacing(self, seed):
        """Check data clearing when `recompute` is set on an in-placed variable.

        The graph is sin -> in-place reshape -> sin -> sin. The forward pass
        is traced; the input and output of the in-place reshape are expected
        not to be cleared. Gradients are then verified through
        `self.check_recomputation`.

        Args:
            seed: Forwarded unchanged to `self.check_recomputation`.
        """
        def graph(x0):
            h1 = F.sin(x0).apply(recompute=True)
            # The `recompute` flag is put on the output of an in-place reshape.
            h2 = F.reshape(h1, (3, 2), inplace=True).apply(recompute=True)
            h3 = F.sin(h2).apply(recompute=True)
            return F.sin(h3)

        x = nn.Variable((2, 3), need_grad=True)
        inputs = (x,)
        y = graph(x)

        # Expected clear-called flags of the input data of each function.
        expected_flags = [
            [False],  # x0: graph input
            [False],  # x1: kept because the data is in-placed
            [False],  # x2: kept because the data is in-placed
            [True],  # x3: cleared because `recompute=True`
        ]

        # Record which data get cleared while running forward propagation.
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(expected_flags)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

        # Verify the gradient values.
        self.check_recomputation(seed, graph, inputs)
Exemple #3
0
    def test_clear_data_on_not_bwd_path(self):
        """Check clearing of recomputed data lying off the backward path.

        Two branches are joined by `F.add2`: branch `a` starts from a
        variable with `need_grad=True`, branch `b` from one with
        `need_grad=False`. Clear-called flags are checked once after the
        forward pass and once after the backward pass.
        """
        # Branch that takes part in back-propagation.
        a0 = nn.Variable((2, 3), need_grad=True)
        a1 = F.identity(a0).apply(recompute=True)
        a2 = F.sin(a1).apply(recompute=True)

        # Branch whose three variables are not back-propagated.
        b0 = nn.Variable((2, 3), need_grad=False)
        b1 = F.identity(b0).apply(recompute=True)
        b2 = F.sin(b1).apply(recompute=True)

        # Join both branches.
        c1 = F.add2(a2, b2).apply(recompute=True)
        c2 = F.sin(c1)

        # --- Forward ---
        # Data that will be recomputed must be cleared during forward.
        forward_expected = [
            [False],  # a0
            [True],  # a1
            [False],  # b0
            [True],  # b1
            [True, True],  # a2, b2
            [True],  # c1
        ]
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        c2.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(forward_expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

        # --- Backward ---
        # b1 is off the backward path and must be cleared during recomputation.
        recomputation_expected = [
            [False],  # a0
            [False],  # a1
            [False],  # b0
            [True],  # b1 (not on backward path) must be cleared
            [True, True],  # a2, b2
            [False],  # c1
        ]
        backprop_expected = [
            [True, True],  # a2, b2
            [False],  # a1
            [False],  # a0
        ]
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        c2.backward(clear_buffer=True)
        self.check_input_data_clear_called_flags(
            recomputation_expected + backprop_expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
Exemple #4
0
    def test_clear_output_grad_inplace(self):
        """Check grad clearing during backward with an in-place `add_scalar`.

        The graph is identity -> in-place add_scalar -> add_scalar. After
        backward with `clear_buffer=True`, all three recorded grad flags are
        expected to be `True`.
        """
        x1 = nn.Variable([1], need_grad=True)

        identity_out = F.identity(x1)
        inplace_out = F.add_scalar(identity_out, inplace=True)
        final_out = F.add_scalar(inplace_out)

        # Three single-flag rows, each expecting the grad to be cleared.
        answer_grad = [[True] for _ in range(3)]

        final_out.forward(clear_no_need_grad=True)
        # Restart the recorder right before backward (presumably so that only
        # flags from the backward pass are captured -- confirm against the
        # recorder implementation).
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        final_out.backward(clear_buffer=True)

        self.check_grad_cleared_flags(answer_grad)
Exemple #5
0
    def test_clear_input_data(self):
        """Check which input data are cleared in forward per `recompute` flag.

        Three `F.sin` calls are chained; the first output has
        `recompute=True` and the second `recompute=False`. Only the first
        output is expected to be cleared during the forward pass.
        """
        x0 = nn.Variable((1, 1), need_grad=True)
        # The input data of `F.sin` is always required to compute its grad.
        recomputed = F.sin(x0).apply(recompute=True)
        retained = F.sin(recomputed).apply(recompute=False)
        out = F.sin(retained)

        answer = [
            [False],  # x0
            [True],  # x1
            [False],  # x2
        ]

        clear_called_flag_recorder.activate_clear_called_flag_recorder()

        out.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(answer)

        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
Exemple #6
0
    def test_clear_output_grad_prohibit_clear_input(self):
        """Check grad clearing when two branches feed a `F.sink`.

        One identity output is consumed by two `add_scalar` calls whose
        results are gathered by `F.sink`. After backward with
        `clear_buffer=True`, the grads of y3 and xx1 are expected to be
        cleared, while those of y2 and y1 are not.
        """
        x1 = nn.Variable([1], need_grad=True)

        shared = F.identity(x1)
        branch_a = F.add_scalar(shared)
        branch_b = F.add_scalar(shared)
        sink = F.sink(branch_a, branch_b)

        answer_grad = [
            [True],  # y3
            [False],  # y2
            [False],  # y1
            [True],  # xx1
        ]

        sink.forward(clear_no_need_grad=True)
        # Restart the recorder right before backward (presumably so that only
        # flags from the backward pass are captured -- confirm against the
        # recorder implementation).
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        sink.backward(clear_buffer=True)

        self.check_grad_cleared_flags(answer_grad)
Exemple #7
0
    def test_clear_output_grad_argument(self, grad):
        """Check grad clearing as a function of the `grad` argument of backward.

        When `grad` is `None` or an `nn.NdArray`, the grad of y1 is expected
        not to be cleared; for any other value it is expected to be cleared.
        The grad of xx1 is expected to be cleared in every case.

        Args:
            grad: Value passed as `grad` to `y1.backward`.
        """
        x1 = nn.Variable([1], need_grad=True)

        xx1 = F.identity(x1)
        y1 = F.add_scalar(xx1)

        # The expectation for y1 depends on the kind of `grad` supplied.
        y1_grad_kept = grad is None or isinstance(grad, nn.NdArray)
        answer_grad = [
            [not y1_grad_kept],  # y1
            [True],  # xx1
        ]

        y1.forward(clear_no_need_grad=True)
        # Restart the recorder right before backward (presumably so that only
        # flags from the backward pass are captured -- confirm against the
        # recorder implementation).
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y1.backward(clear_buffer=True, grad=grad)

        self.check_grad_cleared_flags(answer_grad)
        # Whatever the recorder saw, `clear_called` on y1.grad must be False
        # once backward has finished.
        assert y1.grad.clear_called == False
Exemple #8
0
    def test_clear_output_grad_persistent(self):
        """Check that grads of persistent variables are not cleared in backward.

        The graph is identity -> add_scalar -> add_scalar, with the first and
        last outputs marked persistent. After backward with
        `clear_buffer=True`, only the grad of the middle variable is expected
        to be cleared.
        """
        x1 = nn.Variable([1], need_grad=True)

        first = F.identity(x1)
        middle = F.add_scalar(first)
        last = F.add_scalar(middle)

        # Mark both ends of the chain as persistent; `middle` stays default.
        first.persistent = True
        last.persistent = True

        answer_grad = [
            [False],  # y2
            [True],  # y1
            [False],  # xx1
        ]

        last.forward(clear_no_need_grad=True)
        # Restart the recorder right before backward (presumably so that only
        # flags from the backward pass are captured -- confirm against the
        # recorder implementation).
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        last.backward(clear_buffer=True)

        self.check_grad_cleared_flags(answer_grad)
Exemple #9
0
    def test_clear_no_need_grad_during_recomputation(self):
        """Check clearing of data not needed for grads during recomputation.

        The graph is identity -> sin -> identity -> sin with `recompute=True`
        on the three intermediate variables. Clear-called flags are checked
        once after the forward pass and once after the backward pass, where
        x2 is expected to be cleared right after it is recomputed.
        """
        x0 = nn.Variable((2, 3), need_grad=True)

        x1 = F.identity(x0).apply(recompute=True)
        # x2.data must be cleared immediately after recomputation since it is
        # not needed for backward propagation.
        x2 = F.sin(x1).apply(recompute=True)
        x3 = F.identity(x2).apply(recompute=True)
        x4 = F.sin(x3)

        # --- Forward ---
        # Every intermediate data must be cleared.
        forward_expected = [
            [False],  # x0
            [True],  # x1
            [True],  # x2
            [True],  # x3
        ]
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        x4.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(forward_expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

        # --- Backward ---
        recomputation_expected = [
            [False],  # x0
            [False],  # x1
            [True],  # x2: not needed for grad calculation
        ]
        backprop_expected = [
            [False],  # x3
            [True],  # x2
            [False],  # x1
            [False],  # x0
        ]
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        x4.backward(clear_buffer=True)
        self.check_input_data_clear_called_flags(
            recomputation_expected + backprop_expected)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
Exemple #10
0
 def teardown_method(self):
     """Deactivate the clear-called flag recorder after a test method."""
     # Presumably invoked by pytest after every test in the class, so the
     # recorder does not stay active when a test fails before its own
     # deactivation call -- confirm against the enclosing class.
     clear_called_flag_recorder.deactivate_clear_called_flag_recorder()