def testUseOutput(self):
    """Check the backward pass for operators whose gradient reads their output.

    Forward graph:
        in     -> hidden   (UseOutput)
        hidden -> out      (UseOutput)
        out    -> sink     (Direct)
    """
    forward_ops = [
        CreateOperator('UseOutput', 'in', 'hidden'),
        CreateOperator('UseOutput', 'hidden', 'out'),
        CreateOperator('Direct', 'out', 'sink'),
    ]
    # Each expected UseOutputGradient takes the forward op's output blob
    # followed by that blob's gradient.
    expected_grad_ops = [
        CreateOperator('DirectGradient', 'sink_grad', 'out_grad'),
        CreateOperator(
            'UseOutputGradient', ['out', 'out_grad'], 'hidden_grad'),
        CreateOperator(
            'UseOutputGradient', ['hidden', 'hidden_grad'], 'in_grad'),
    ]
    seed_grads = {'sink': 'sink_grad'}
    backward_ops, _ = GradientRegistry.GetBackwardPass(
        forward_ops, seed_grads)
    self.assertEqual(backward_ops, expected_grad_ops)
def testStopGradientWithMultiUseOperators(self):
    """Check the backward pass when a multi-use blob also feeds StopGradient.

    Forward graph:
        in               -> hidden    (Direct)
        hidden           -> hidden2   (Direct)
        hidden           -> hidden3   (StopGradient)
        hidden2, hidden3 -> out       (Direct)
    """
    forward_ops = [
        CreateOperator('Direct', 'in', 'hidden'),
        CreateOperator('Direct', 'hidden', 'hidden2'),
        CreateOperator('StopGradient', 'hidden', 'hidden3'),
        CreateOperator('Direct', ['hidden2', 'hidden3'], 'out'),
    ]
    # No gradient operator is expected for the StopGradient op, and
    # 'hidden_grad' is expected to come from 'hidden2_grad' alone (no Sum).
    expected_grad_ops = [
        CreateOperator(
            'DirectGradient', 'out_grad', ['hidden2_grad', 'hidden3_grad']),
        CreateOperator('DirectGradient', 'hidden2_grad', 'hidden_grad'),
        CreateOperator('DirectGradient', 'hidden_grad', 'in_grad'),
    ]
    # 'hidden3' is still expected to appear in the blob -> gradient map.
    expected_grad_map = {
        'out': 'out_grad',
        'hidden2': 'hidden2_grad',
        'hidden3': 'hidden3_grad',
        'hidden': 'hidden_grad',
        'in': 'in_grad',
    }
    seed_grads = {'out': 'out_grad'}
    backward_ops, blob_to_grad = GradientRegistry.GetBackwardPass(
        forward_ops, seed_grads)
    self.assertEqual(backward_ops, expected_grad_ops)
    self.assertEqual(blob_to_grad, expected_grad_map)
def testMultiUseInput(self, device_option):
    """Check the backward pass when one input feeds two operators.

    Forward graph:
        in               -> hidden1
        in               -> hidden2
        hidden1, hidden2 -> out

    Args:
        device_option: when truthy, it is copied onto every forward
            operator and every expected gradient operator before the
            comparison.
    """
    def stamp_device(op_list):
        # Nothing to copy when the caller supplied no (truthy) option.
        if device_option:
            for each_op in op_list:
                each_op.device_option.CopyFrom(device_option)

    forward_ops = [
        CreateOperator('Direct', 'in', 'hidden1'),
        CreateOperator('Direct', 'in', 'hidden2'),
        CreateOperator('Direct', ['hidden1', 'hidden2'], 'out'),
    ]
    stamp_device(forward_ops)

    # The two partial gradients of 'in' are expected under autosplit names
    # and then combined by a Sum operator into 'in_grad'.
    split_grads = ['_in_grad_autosplit_0', '_in_grad_autosplit_1']
    expected_grad_ops = [
        CreateOperator(
            'DirectGradient', 'out_grad', ['hidden1_grad', 'hidden2_grad']),
        CreateOperator('DirectGradient', 'hidden2_grad', split_grads[0]),
        CreateOperator('DirectGradient', 'hidden1_grad', split_grads[1]),
        CreateOperator('Sum', split_grads, 'in_grad'),
    ]
    stamp_device(expected_grad_ops)

    backward_ops, _ = GradientRegistry.GetBackwardPass(
        forward_ops, {"out": "out_grad"})
    self.assertEqual(backward_ops, expected_grad_ops)
def AddUseInputGradient(op, g_output):
    """Build the 'UseInputGradient' operator for ``op``.

    The gradient operator's inputs are ``op``'s own inputs followed by
    ``NeedAll(op, g_output)``; its outputs are ``GIS(op)``.

    Args:
        op: the forward operator.
        g_output: passed through to ``NeedAll`` together with ``op``.

    Returns:
        A 2-tuple ``(gradient_operator, GIS(op))`` where the operator has
        been passed through ``CopyDeviceOption(..., op)``.
    """
    grad_inputs = list(op.input) + NeedAll(op, g_output)
    grad_op = CreateOperator('UseInputGradient', grad_inputs, GIS(op))
    grad_op = CopyDeviceOption(grad_op, op)
    # GIS is called again so the returned value is a separate result from
    # the one handed to CreateOperator.
    return grad_op, GIS(op)
def AddDirectGradient(op, g_output):
    """Build the 'DirectGradient' operator for ``op``.

    The gradient operator's inputs are ``NeedAll(op, g_output)`` and its
    outputs are ``GIS(op)``.

    Args:
        op: the forward operator.
        g_output: passed through to ``NeedAll`` together with ``op``.

    Returns:
        A 2-tuple ``(gradient_operator, GIS(op))`` where the operator has
        been passed through ``CopyDeviceOption(..., op)``.
    """
    grad_inputs = NeedAll(op, g_output)
    grad_op = CreateOperator('DirectGradient', grad_inputs, GIS(op))
    grad_op = CopyDeviceOption(grad_op, op)
    # GIS is called again so the returned value is a separate result from
    # the one handed to CreateOperator.
    return grad_op, GIS(op)
def testMultiUseInputAndMultipleVersionsBig(self):
    """Test gradient for a blob that is both multi-use and overwritten.

    Forward graph ('in' is rewritten in place twice):
        in                        -> in
        in                        -> hidden1, hidden2
        hidden1, hidden2          -> in
        in                        -> hidden3, hidden4, hidden5
        hidden3, hidden4, hidden5 -> out
    """
    operators = [
        CreateOperator('Direct', 'in', 'in'),
        CreateOperator('Direct', 'in', 'hidden1'),
        CreateOperator('Direct', 'in', 'hidden2'),
        CreateOperator('Direct', ['hidden1', 'hidden2'], 'in'),
        CreateOperator('Direct', 'in', 'hidden3'),
        CreateOperator('Direct', 'in', 'hidden4'),
        CreateOperator('Direct', 'in', 'hidden5'),
        CreateOperator('Direct', ['hidden3', 'hidden4', 'hidden5'], 'out'),
    ]
    desired_grad_operators = [
        CreateOperator(
            'DirectGradient', 'out_grad',
            ['hidden3_grad', 'hidden4_grad', 'hidden5_grad']),
        # Last version of 'in': three consumers, so three partial
        # gradients are expected to be summed into 'in_grad'.
        CreateOperator('DirectGradient', 'hidden5_grad', 'in_grad'),
        CreateOperator(
            'DirectGradient', 'hidden4_grad', '_in_grad_autosplit_0'),
        CreateOperator(
            'DirectGradient', 'hidden3_grad', '_in_grad_autosplit_1'),
        CreateOperator(
            'Sum',
            ['in_grad', '_in_grad_autosplit_0', '_in_grad_autosplit_1'],
            'in_grad'),
        CreateOperator(
            'DirectGradient', 'in_grad', ['hidden1_grad', 'hidden2_grad']),
        # Earlier version of 'in': two consumers; the expected list reuses
        # the '_in_grad_autosplit_0' name here.
        CreateOperator('DirectGradient', 'hidden2_grad', 'in_grad'),
        CreateOperator(
            'DirectGradient', 'hidden1_grad', '_in_grad_autosplit_0'),
        CreateOperator(
            'Sum', ['in_grad', '_in_grad_autosplit_0'], 'in_grad'),
        CreateOperator('DirectGradient', 'in_grad', 'in_grad'),
    ]
    gradients, _ = GradientRegistry.GetBackwardPass(
        operators, {'out': 'out_grad'})
    # The leftover debug loop that printed every gradient operator was
    # removed: it wrote to stdout on every run.
    self.assertOperatorListEqual(gradients, desired_grad_operators)