コード例 #1
0
ファイル: subscribe_test.py プロジェクト: Wajih-O/tensorflow
  def testIsSubscribedIdentity(self):
    """Check that only the identity inserted by subscribe() is recognized."""
    lhs = constant_op.constant(1)
    rhs = constant_op.constant(2)
    total = math_ops.add(lhs, rhs)
    # A plain identity on `total`, built before subscribing; it must not be
    # reported as a subscribed identity.
    plain_identity = array_ops.identity(total)
    subscribed = subscribe.subscribe(total, [])

    for tensor in (lhs, total, plain_identity):
      self.assertFalse(subscribe._is_subscribed_identity(tensor))
    self.assertTrue(subscribe._is_subscribed_identity(subscribed))
コード例 #2
0
    def testIsSubscribedIdentity(self):
        """Only the identity op produced by subscribe() counts as subscribed."""
        const_one = constant_op.constant(1)
        const_two = constant_op.constant(2)
        summed = math_ops.add(const_one, const_two)
        # Ordinary identity created before the subscription is installed.
        unrelated_identity = array_ops.identity(summed)
        summed_sub = subscribe.subscribe(summed, [])

        is_sub = subscribe._is_subscribed_identity
        self.assertFalse(is_sub(const_one))
        self.assertFalse(is_sub(summed))
        self.assertFalse(is_sub(unrelated_identity))
        self.assertTrue(is_sub(summed_sub))
コード例 #3
0
ファイル: subscribe_test.py プロジェクト: Wajih-O/tensorflow
  def testResourceType(self):
    """Confirm that subscribe correctly handles tensors with 'resource' type.

    Subscribing to a 'resource'-typed tensor (here a TensorArray handle) is
    expected to leave the graph alone. The original handle is returned as-is,
    the handle is not a subscribed identity, and the side-effect function is
    never invoked when the TensorArray is read.
    """
    tensor_array = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name='test',
        size=3,
        infer_shape=False)
    writer = tensor_array.write(0, [[4.0, 5.0]])
    reader = writer.read(0)

    # Records every value seen by the subscription side effect.
    shared = []

    def sub(t):
      shared.append(t)
      return t

    # TensorArray's handle output tensor has a 'resource' type and cannot be
    # subscribed as it's not 'numpy compatible' (see dtypes.py).
    # Expect that the original tensor is returned when subscribing to it.
    tensor_array_sub = subscribe.subscribe(
        tensor_array.handle, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
    self.assertIs(tensor_array_sub, tensor_array.handle)
    self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))

    # self.evaluate runs in the cached session; no explicit handle is needed.
    with self.cached_session():
      self.evaluate([reader])
    self.assertEqual(0, len(shared))
コード例 #4
0
  def testResourceType(self):
    """Subscribing to a 'resource' tensor must be a transparent no-op.

    The TensorArray handle is 'resource'-typed, so subscribe() should hand it
    back unchanged and the recording callback must never fire.
    """
    tensor_array = tensor_array_ops.TensorArray(
        dtype=dtypes.float32,
        tensor_array_name='test',
        size=3,
        infer_shape=False)
    write_op = tensor_array.write(0, [[4.0, 5.0]])
    read_op = write_op.read(0)

    captured = []

    def record(value):
      captured.append(value)
      return value

    def side_effect(t):
      return script_ops.py_func(record, [t], [t.dtype])

    # 'resource' tensors are not numpy compatible (see dtypes.py), so they
    # cannot be routed through py_func; subscribe() returns the input as-is.
    handle = tensor_array.handle
    result = subscribe.subscribe(handle, side_effect)
    self.assertIs(result, handle)
    self.assertFalse(subscribe._is_subscribed_identity(handle))

    with self.cached_session() as sess:
      sess.run([read_op])
    self.assertEqual(0, len(captured))
コード例 #5
0
ファイル: subscribe_test.py プロジェクト: Wajih-O/tensorflow
  def testSubscribeVariable(self):
    """Confirm that variables can be subscribed.

    The subscription side effect must fire only when the variable's value is
    read (e.g. by the `add` op), not when the variable is merely assigned to.
    The recorded values must match the variable's value at each read.
    """
    v1 = variables.VariableV1(0.0)
    v2 = variables.VariableV1(4.0)
    add = math_ops.add(v1, v2)
    assign_v1 = v1.assign(3.0)

    # Records every value seen by the subscription side effect.
    shared = []

    def sub(t):
      shared.append(t)
      return t

    v1_sub = subscribe.subscribe(
        v1, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
    self.assertTrue(subscribe._is_subscribed_identity(v1_sub))

    # self.evaluate runs in the cached session; no explicit handle is needed.
    with self.cached_session():
      # Initialize the variables first.
      self.evaluate([v1.initializer])
      self.evaluate([v2.initializer])

      # Expect the side effects to be triggered when evaluating the add op as
      # it will read the value of the variable.
      self.evaluate([add])
      self.assertEqual(1, len(shared))

      # Expect the side effect not to be triggered when evaluating the assign
      # op as it will not access the 'read' output of the variable.
      self.evaluate([assign_v1])
      self.assertEqual(1, len(shared))

      self.evaluate([add])
      self.assertEqual(2, len(shared))

      # Make sure the values read from the variable match the expected ones.
      self.assertEqual([0.0, 3.0], shared)
コード例 #6
0
  def testSubscribeVariable(self):
    """A subscribed variable fires its side effect on reads, not on assigns."""
    v1 = variables.Variable(0.0)
    v2 = variables.Variable(4.0)
    add = math_ops.add(v1, v2)
    assign_v1 = v1.assign(3.0)

    values_seen = []

    def record(value):
      values_seen.append(value)
      return value

    subscribed_v1 = subscribe.subscribe(
        v1, lambda t: script_ops.py_func(record, [t], [t.dtype]))
    self.assertTrue(subscribe._is_subscribed_identity(subscribed_v1))

    # Each entry: op to run, and how many side-effect calls are expected
    # in total afterwards.
    steps = (
        (add, 1),        # `add` reads v1 -> side effect fires.
        (assign_v1, 1),  # Assigning does not read v1 -> no new call.
        (add, 2),        # Reading again fires once more.
    )

    with self.cached_session() as sess:
      for initializer in (v1.initializer, v2.initializer):
        sess.run([initializer])

      for op, expected_calls in steps:
        sess.run([op])
        self.assertEqual(expected_calls, len(values_seen))

      # The values observed must be v1 before and after the assignment.
      self.assertEqual([0.0, 3.0], values_seen)
コード例 #7
0
ファイル: subscribe_test.py プロジェクト: Wajih-O/tensorflow
 def _ExpectSubscribedIdentities(self, container):
   """Assert that every element of `container` is a subscribed identity.

   Each element is checked individually so that a failure reports which
   element is not a subscribed identity, instead of the uninformative
   "False is not true" produced by asserting on a single `all(...)`.
   An empty container passes, as before.

   Args:
     container: An iterable of values (presumably tensors returned by
       `subscribe.subscribe`) to check.
   """
   for index, item in enumerate(container):
     self.assertTrue(
         subscribe._is_subscribed_identity(item),
         msg='Element %d (%r) is not a subscribed identity.' % (index, item))
コード例 #8
0
 def _ExpectSubscribedIdentities(self, container):
   """Assert that each item yielded by `container` is a subscribed identity."""
   every_item_subscribed = all(
       subscribe._is_subscribed_identity(item) for item in container)
   self.assertTrue(every_item_subscribed)