def testIsSubscribedIdentity(self):
  """Check that only the tensor returned by subscribe() is flagged.

  Two constants, their sum and a hand-built identity of that sum must not be
  reported as subscribed identities, whereas the tensor returned by
  subscribing to the sum must be.
  """
  lhs = constant_op.constant(1)
  rhs = constant_op.constant(2)
  total = math_ops.add(lhs, rhs)
  plain_identity = array_ops.identity(total)
  subscribed_total = subscribe.subscribe(total, [])

  # Tensors that were not produced by subscribe() must not be detected.
  for tensor in (lhs, total, plain_identity):
    self.assertFalse(subscribe._is_subscribed_identity(tensor))

  self.assertTrue(subscribe._is_subscribed_identity(subscribed_total))
def testResourceType(self):
  """Confirm that subscribe correctly handles tensors with 'resource' type.

  Subscribing to a TensorArray handle is expected to hand back the very same
  tensor, and the side-effect callback must never be invoked when the array
  is subsequently read.
  """
  # NOTE(review): another `testResourceType` definition follows this one in
  # the file; if both end up in the same scope the later one replaces this
  # one and this test never runs -- confirm and rename/remove as needed.
  tensor_array = tensor_array_ops.TensorArray(
      dtype=dtypes.float32,
      tensor_array_name='test',
      size=3,
      infer_shape=False)
  writer = tensor_array.write(0, [[4.0, 5.0]])
  reader = writer.read(0)

  # Collects every value passed to the side-effect callback.
  shared = []

  def sub(t):
    shared.append(t)
    return t

  # TensorArray's handle output tensor has a 'resource' type and cannot be
  # subscribed as it's not 'numpy compatible' (see dtypes.py).
  # Expect that the original tensor is returned when subscribing to it.
  tensor_array_sub = subscribe.subscribe(
      tensor_array.handle, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
  self.assertIs(tensor_array_sub, tensor_array.handle)
  self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))

  # The session object is never referenced directly (self.evaluate() is
  # called without it), so the context manager result is not bound to a name.
  with self.cached_session():
    self.evaluate([reader])

  # The callback must not have fired.
  self.assertEqual(0, len(shared))
def testResourceType(self):
  """Subscribing to a 'resource'-typed tensor must leave it untouched.

  The handle of a TensorArray is subscribed; the test expects the original
  handle tensor back and expects the side-effect callback to stay unused
  after the array has been read.
  """
  # NOTE(review): an earlier method in this file has the same name; if both
  # live in the same scope this definition shadows the earlier one -- confirm.
  array = tensor_array_ops.TensorArray(
      dtype=dtypes.float32, tensor_array_name='test', size=3,
      infer_shape=False)
  read_back = array.write(0, [[4.0, 5.0]]).read(0)

  captured = []

  def record(value):
    captured.append(value)
    return value

  def side_effect(t):
    return script_ops.py_func(record, [t], [t.dtype])

  # The handle output of a TensorArray is of 'resource' type, which is not
  # 'numpy compatible' (see dtypes.py) and therefore cannot be subscribed.
  # subscribe() is expected to return the handle itself in that case.
  handle = array.handle
  subscribed_handle = subscribe.subscribe(handle, side_effect)
  self.assertIs(subscribed_handle, array.handle)
  self.assertFalse(subscribe._is_subscribed_identity(array.handle))

  with self.cached_session() as sess:
    sess.run([read_back])

  num_callbacks = len(captured)
  self.assertEqual(0, num_callbacks)
def testSubscribeVariable(self):
  """Confirm that variables can be subscribed.

  A side-effect callback is attached to `v1`; the test then checks that the
  callback fires when an op reading the variable is evaluated, does not fire
  for the assign op, and that it received the expected values.
  """
  # NOTE(review): another `testSubscribeVariable` definition follows this one
  # in the file; if both end up in the same scope the later one replaces this
  # one and this test never runs -- confirm and rename/remove as needed.
  v1 = variables.VariableV1(0.0)
  v2 = variables.VariableV1(4.0)
  add = math_ops.add(v1, v2)
  assign_v1 = v1.assign(3.0)

  # Collects every value passed to the side-effect callback.
  shared = []

  def sub(t):
    shared.append(t)
    return t

  v1_sub = subscribe.subscribe(
      v1, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
  self.assertTrue(subscribe._is_subscribed_identity(v1_sub))

  # The session object is never referenced directly (self.evaluate() is
  # called without it), so the context manager result is not bound to a name.
  with self.cached_session():
    # Initialize the variables first.
    self.evaluate([v1.initializer])
    self.evaluate([v2.initializer])

    # Expect the side effects to be triggered when evaluating the add op as
    # it will read the value of the variable.
    self.evaluate([add])
    self.assertEqual(1, len(shared))

    # Expect the side effect not to be triggered when evaluating the assign
    # op as it will not access the 'read' output of the variable.
    self.evaluate([assign_v1])
    self.assertEqual(1, len(shared))

    self.evaluate([add])
    self.assertEqual(2, len(shared))

    # Make sure the values read from the variable match the expected ones.
    self.assertEqual([0.0, 3.0], shared)
def testSubscribeVariable(self):
  """Check that a side effect can be attached to a variable.

  The callback records each value it is handed. It should run whenever the
  add op (which reads the variable) is executed, and should not run for the
  assign op.
  """
  # NOTE(review): an earlier method in this file has the same name; if both
  # live in the same scope this definition shadows the earlier one -- confirm.
  watched = variables.Variable(0.0)
  other = variables.Variable(4.0)
  total = math_ops.add(watched, other)
  assign_watched = watched.assign(3.0)

  observed = []

  def record(value):
    observed.append(value)
    return value

  def side_effect(t):
    return script_ops.py_func(record, [t], [t.dtype])

  watched_sub = subscribe.subscribe(watched, side_effect)
  self.assertTrue(subscribe._is_subscribed_identity(watched_sub))

  # Each entry is (op to run, number of recorded values expected afterwards):
  #  - the add op reads the variable, so the callback fires;
  #  - the assign op does not access the 'read' output of the variable, so
  #    the count stays the same;
  #  - running the add op again fires the callback once more.
  steps = (
      (total, 1),
      (assign_watched, 1),
      (total, 2),
  )

  with self.cached_session() as sess:
    # Variables have to be initialized before anything reads them.
    for variable in (watched, other):
      sess.run([variable.initializer])

    for fetch, expected_count in steps:
      sess.run([fetch])
      self.assertEqual(expected_count, len(observed))

    # The recorded values are the variable's value before and after assign.
    self.assertEqual([0.0, 3.0], observed)
def _ExpectSubscribedIdentities(self, container):
  """Assert that every element of `container` is a subscribed identity.

  Args:
    container: iterable of objects to pass to
      `subscribe._is_subscribed_identity`.
  """
  # The generator keeps evaluation lazy: checking stops at the first element
  # that is not a subscribed identity.
  every_item_subscribed = all(
      subscribe._is_subscribed_identity(item) for item in container)
  self.assertTrue(every_item_subscribed)