def test_repeat_fn_exterior_value_static(self): add_op = LambdaOp(inputs='x', outputs=('x', 'y'), fn=lambda w: (w + 1, w * w), mode='eval') repeat_op = Repeat(add_op, repeat=lambda y, z: y + z < 25) repeat_op.build('tf') with self.subTest('Check op inputs'): self.assertListEqual(repeat_op.inputs, ['x', 'z']) with self.subTest('Check op outputs'): self.assertListEqual(repeat_op.outputs, ['x', 'y']) with self.subTest('Check op mode'): self.assertSetEqual(repeat_op.mode, {'eval'}) x = [tf.ones([1]), 10 + tf.ones([1])] output = tf.function(lambda y: repeat_op.forward(data=y, state={ "deferred": {}, "mode": "eval" })) output(x) output = output(x) with self.subTest('Check output type'): self.assertEqual(type(output), list) with self.subTest('Check output value (x)'): self.assertEqual(5, output[0]) with self.subTest('Check output value (y)'): self.assertEqual(16, output[1])
def test_repeat_fn_exterior_value_tf(self): add_op = LambdaOp(inputs='x', outputs=('x', 'y'), fn=lambda x: (x + 1, x * x), mode='eval') repeat_op = Repeat(add_op, repeat=lambda y, z: y + z < 25) repeat_op.build('tf') with self.subTest('Check op inputs'): self.assertListEqual(repeat_op.inputs, ['x', 'z']) with self.subTest('Check op outputs'): self.assertListEqual(repeat_op.outputs, ['x', 'y']) with self.subTest('Check op mode'): self.assertSetEqual(repeat_op.mode, {'eval'}) with tf.GradientTape(persistent=True) as tape: output = repeat_op.forward(data=[tf.ones([1]), 10 + tf.ones([1])], state={ "deferred": {}, "mode": "eval", "tape": tape }) with self.subTest('Check output type'): self.assertEqual(type(output), list) with self.subTest('Check output value (x)'): self.assertEqual(5, output[0]) with self.subTest('Check output value (y)'): self.assertEqual(16, output[1])
def test_single_repeat_fn_interior_value_static(self): add_op = LambdaOp(inputs='x', outputs=('x', 'y'), fn=lambda z: (z + 1, z * z), mode='eval') repeat_op = Repeat(add_op, repeat=lambda y: y < 1) repeat_op.build('tf') with self.subTest('Check op inputs'): self.assertListEqual(repeat_op.inputs, ['x']) with self.subTest('Check op outputs'): self.assertListEqual(repeat_op.outputs, ['x', 'y']) with self.subTest('Check op mode'): self.assertSetEqual(repeat_op.mode, {'eval'}) x = [tf.ones([1])] output = tf.function(lambda y: repeat_op.forward(data=y, state={ "deferred": {}, "mode": "eval" })) output(x) # build the graph output = output(x) with self.subTest('Check output type'): self.assertEqual(type(output), list) with self.subTest('Check output value (x)'): self.assertEqual(2, output[0]) with self.subTest('Check output value (y)'): self.assertEqual(1, output[1])
def test_multi_repeat_fn_interior_value_tf(self): add_op = LambdaOp(inputs='x', outputs=('x', 'y'), fn=lambda x: (x + 1, x * x), mode='eval') repeat_op = Repeat(add_op, repeat=lambda y: y < 25) repeat_op.build('tf') with self.subTest('Check op inputs'): self.assertListEqual(repeat_op.inputs, ['x']) with self.subTest('Check op outputs'): self.assertListEqual(repeat_op.outputs, ['x', 'y']) with self.subTest('Check op mode'): self.assertSetEqual(repeat_op.mode, {'eval'}) output = repeat_op.forward(data=[tf.ones([1])], state={"deferred": {}, "mode": "eval"}) with self.subTest('Check output type'): self.assertEqual(type(output), list) with self.subTest('Check output value (x)'): self.assertEqual(6, output[0]) with self.subTest('Check output value (y)'): self.assertEqual(25, output[1])