def testUseOutput(self):
    """Backward pass through a chain of UseOutput operators.

    Forward graph:
        in -> hidden   (UseOutput)
        hidden -> out  (UseOutput)
        out -> sink    (Direct)

    Each expected UseOutputGradient op takes the forward output of its op
    together with that output's gradient, and produces the gradient of the
    op's input.
    """
    forward_ops = [
        CreateOperator('UseOutput', 'in', 'hidden'),
        CreateOperator('UseOutput', 'hidden', 'out'),
        CreateOperator('Direct', 'out', 'sink'),
    ]
    expected_ops = [
        CreateOperator('DirectGradient', 'sink_grad', 'out_grad'),
        # Inputs are [forward output, gradient of that output].
        CreateOperator(
            'UseOutputGradient', ['out', 'out_grad'], 'hidden_grad'),
        CreateOperator(
            'UseOutputGradient', ['hidden', 'hidden_grad'], 'in_grad'),
    ]
    seed_grads = {'sink': 'sink_grad'}
    grad_ops, _unused_grad_map = GradientRegistry.GetBackwardPass(
        forward_ops, seed_grads)
    self.assertEqual(grad_ops, expected_ops)
示例#2
0
 def testStopGradientWithMultiUseOperators(self):
     """A StopGradient branch contributes no gradient to its input.

     Forward graph:
         in -> hidden              (Direct)
         hidden -> hidden2         (Direct)
         hidden -> hidden3         (StopGradient)
         hidden2, hidden3 -> out   (Direct)

     Although ``hidden`` has two consumers, the expected backward pass has no
     Sum op for it: ``hidden_grad`` comes only from ``hidden2_grad``. The
     returned gradient map must still list ``hidden3_grad``.
     """
     forward_ops = [
         CreateOperator('Direct', 'in', 'hidden'),
         CreateOperator('Direct', 'hidden', 'hidden2'),
         CreateOperator('StopGradient', 'hidden', 'hidden3'),
         CreateOperator('Direct', ['hidden2', 'hidden3'], 'out'),
     ]
     expected_ops = [
         CreateOperator(
             'DirectGradient', 'out_grad', ['hidden2_grad', 'hidden3_grad']),
         CreateOperator('DirectGradient', 'hidden2_grad', 'hidden_grad'),
         CreateOperator('DirectGradient', 'hidden_grad', 'in_grad'),
     ]
     expected_grad_map = {
         'out': 'out_grad',
         'hidden2': 'hidden2_grad',
         'hidden3': 'hidden3_grad',
         'hidden': 'hidden_grad',
         'in': 'in_grad',
     }
     grad_ops, grad_map = GradientRegistry.GetBackwardPass(
         forward_ops, {'out': 'out_grad'})
     self.assertEqual(grad_ops, expected_ops)
     self.assertEqual(grad_map, expected_grad_map)
示例#3
0
    def testMultiUseInput(self, device_option):
        """Gradient of a blob consumed by two operators.

        Forward graph:

            in -> hidden1
            in -> hidden2
            hidden1, hidden2 -> out

        Because ``in`` feeds two operators, the expected backward pass writes
        each branch's contribution to its own ``_in_grad_autosplit_<k>`` blob
        and then sums them into ``in_grad``.

        Args:
            device_option: when truthy, copied (via ``CopyFrom``) onto every
                forward op and every expected gradient op; when falsy, the
                ops are left untouched.
        """
        def _with_device(ops):
            # Stamp device_option onto each op in place; skipped when falsy.
            if device_option:
                for each_op in ops:
                    each_op.device_option.CopyFrom(device_option)
            return ops

        forward_ops = _with_device([
            CreateOperator('Direct', 'in', 'hidden1'),
            CreateOperator('Direct', 'in', 'hidden2'),
            CreateOperator('Direct', ['hidden1', 'hidden2'], 'out'),
        ])
        expected_ops = _with_device([
            CreateOperator(
                'DirectGradient', 'out_grad',
                ['hidden1_grad', 'hidden2_grad']),
            CreateOperator(
                'DirectGradient', 'hidden2_grad', '_in_grad_autosplit_0'),
            CreateOperator(
                'DirectGradient', 'hidden1_grad', '_in_grad_autosplit_1'),
            CreateOperator(
                'Sum', ['_in_grad_autosplit_0', '_in_grad_autosplit_1'],
                'in_grad'),
        ])
        grad_ops, _unused_grad_map = GradientRegistry.GetBackwardPass(
            forward_ops, {"out": "out_grad"})
        self.assertEqual(grad_ops, expected_ops)
示例#4
0
def AddUseInputGradient(op, g_output):
    """Build the 'UseInputGradient' operator for forward operator ``op``.

    The gradient op's inputs are every forward input of ``op`` followed by
    the output gradients returned by ``NeedAll(op, g_output)``. Its outputs
    are the names returned by ``GIS(op)``. The op is passed through
    ``CopyDeviceOption(..., op)`` (presumably copying ``op``'s device
    option onto it; confirm against that helper).

    Returns:
        A pair ``(gradient_op, GIS(op))``.
    """
    grad_inputs = list(op.input) + NeedAll(op, g_output)
    grad_op = CreateOperator('UseInputGradient', grad_inputs, GIS(op))
    return CopyDeviceOption(grad_op, op), GIS(op)
示例#5
0
def AddDirectGradient(op, g_output):
    """Build the 'DirectGradient' operator for forward operator ``op``.

    The gradient op reads only the output gradients returned by
    ``NeedAll(op, g_output)`` and writes the names returned by ``GIS(op)``.
    The op is passed through ``CopyDeviceOption(..., op)`` (presumably
    copying ``op``'s device option onto it; confirm against that helper).

    Returns:
        A pair ``(gradient_op, GIS(op))``.
    """
    grad_inputs = NeedAll(op, g_output)
    grad_op = CreateOperator('DirectGradient', grad_inputs, GIS(op))
    return CopyDeviceOption(grad_op, op), GIS(op)
示例#6
0
    def testMultiUseInputAndMultipleVersionsBig(self):
        """Test gradient for the following case:

        in -> in
        in -> hidden1, hidden2
        hidden1, hidden2 -> in
        in -> hidden3, hidden4, hidden5
        hidden3, hidden4, hidden5 -> out

        ``in`` is overwritten in place twice, so it exists in several
        versions, and two of those versions have multiple consumers. The
        expected backward pass handles each fan-out version in turn. It
        writes one contribution to ``in_grad`` and the rest to
        ``_in_grad_autosplit_<k>``, then sums them back into ``in_grad``.
        The same autosplit names are reused for each version.
        """
        operators = [
            CreateOperator('Direct', 'in', 'in'),
            CreateOperator('Direct', 'in', 'hidden1'),
            CreateOperator('Direct', 'in', 'hidden2'),
            CreateOperator('Direct', ['hidden1', 'hidden2'], 'in'),
            CreateOperator('Direct', 'in', 'hidden3'),
            CreateOperator('Direct', 'in', 'hidden4'),
            CreateOperator('Direct', 'in', 'hidden5'),
            CreateOperator('Direct', ['hidden3', 'hidden4', 'hidden5'], 'out'),
        ]
        desired_grad_operators = [
            # Version of `in` read by hidden3/4/5: three contributions summed.
            CreateOperator('DirectGradient', 'out_grad',
                           ['hidden3_grad', 'hidden4_grad', 'hidden5_grad']),
            CreateOperator('DirectGradient', 'hidden5_grad', 'in_grad'),
            CreateOperator('DirectGradient', 'hidden4_grad',
                           '_in_grad_autosplit_0'),
            CreateOperator('DirectGradient', 'hidden3_grad',
                           '_in_grad_autosplit_1'),
            CreateOperator(
                'Sum',
                ['in_grad', '_in_grad_autosplit_0', '_in_grad_autosplit_1'],
                'in_grad'),
            # Version of `in` read by hidden1/2: two contributions summed.
            CreateOperator('DirectGradient', 'in_grad',
                           ['hidden1_grad', 'hidden2_grad']),
            CreateOperator('DirectGradient', 'hidden2_grad', 'in_grad'),
            CreateOperator('DirectGradient', 'hidden1_grad',
                           '_in_grad_autosplit_0'),
            CreateOperator('Sum', ['in_grad', '_in_grad_autosplit_0'],
                           'in_grad'),
            # First op (in -> in) has a single consumer: no split needed.
            CreateOperator('DirectGradient', 'in_grad', 'in_grad'),
        ]
        gradients, _ = GradientRegistry.GetBackwardPass(
            operators, {'out': 'out_grad'})
        self.assertOperatorListEqual(gradients, desired_grad_operators)