Example #1
    def testFeedTwoHandlesDirectly(self):
        with self.cached_session() as sess:
            a = constant_op.constant(10.0)
            b = constant_op.constant(5.0)
            c = math_ops.multiply(a, b)
            d = math_ops.div(a, b)
            e = math_ops.subtract(c, d)

            h_c = sess.run(session_ops.get_session_handle(c))
            h_d = sess.run(session_ops.get_session_handle(d))

            self.assertAllClose(48.0, sess.run(e, feed_dict={c: h_c, d: h_d}))
            self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))
Example #2
  def testFeedTwoHandlesDirectly(self):
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      d = math_ops.div(a, b)
      e = math_ops.subtract(c, d)

      h_c = sess.run(session_ops.get_session_handle(c))
      h_d = sess.run(session_ops.get_session_handle(d))

      self.assertAllClose(48.0, sess.run(e, feed_dict={c: h_c, d: h_d}))
      self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))
Example #3
    def testHandleForLoop(self):
        with self.cached_session() as sess:
            # Initialize a handle.
            a = constant_op.constant(0)
            h = session_ops.get_session_handle(a)
            h = self.evaluate(h)

            # Do some computation.
            f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
            # Must define the loop body outside the loop.
            h_x = session_ops.get_session_handle(math_ops.add(x, 1))
            for _ in range(100):
                # This exercises garbage collection.
                h = sess.run(h_x, feed_dict={f: h.handle})

            self.assertEqual(100, h.eval())
Example #4
  def testDirectHandleFeedOverlappingWithFetches(self):
    with self.cached_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      h_c = sess.run(session_ops.get_session_handle(c))
      d = array_ops.identity(c)

      c_val = sess.run(c, feed_dict={c: h_c})
      self.assertAllClose(50.0, c_val)

      d_val = sess.run(d, feed_dict={c: h_c})
      self.assertAllClose(50.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: 60.0})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(60.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: 60.0, d: h_c})
      self.assertAllClose(60.0, c_val)
      self.assertAllClose(50.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: h_c})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(50.0, d_val)
Example #5
  def numpy_to_handle(self, array):
    """Upload numpy array into TensorFlow runtime.

    Args:
      array: numpy array to convert to TensorHandle

    Returns:
      TensorHandle corresponding to given numpy array.
    """

    tf_dtype = dtypes.as_dtype(array.dtype)
    current_device = get_current_device_string(self.g)
    current_device_sanitized = current_device.replace(":", "")
    key = ("numpy_to_handle", tf_dtype.name, current_device)

    if key in self.op_cache:
      holder, handle_op = self.op_cache[key]
    else:
      if self.PRINT_CACHE_MISSES:
        print("Imperative cache miss for %s"%(str(key)))

      op_prefix = "numpy_to_handle.%s.%s" % (tf_dtype.name,
                                             current_device_sanitized)
      with self.g.as_default():
        holder = array_ops.placeholder(dtype=array.dtype,
                                       name=op_prefix+".holder")
        handle_op = session_ops.get_session_handle(holder,
                                                   name=op_prefix+".handle")
      self.op_cache[key] = (holder, handle_op)

    handle = self.run(handle_op, feed_dict={holder: array})
    return handle
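
The numpy_to_handle helper above builds the same two-op subgraph used throughout these tests (a placeholder feeding get_session_handle) and caches it per dtype and device. As a rough, self-contained sketch of the round trip this enables, assuming the TF1 graph-mode API (tf.compat.v1) and omitting the caching and ITensor/env machinery:

# Simplified illustration only; not the env/ITensor code from the example above.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Session() as sess:
  # Upload: feed the numpy array into a placeholder and persist it, getting
  # back a TensorHandle that outlives this single run() call.
  array = np.array([1.0, 2.0, 3.0], dtype=np.float32)
  holder = tf.placeholder(dtype=array.dtype)
  handle = sess.run(tf.get_session_handle(holder), feed_dict={holder: array})

  # Reuse: convert the handle back into a graph tensor and compute with it,
  # feeding only the small handle string instead of re-uploading the array.
  h_p, h_t = tf.get_session_tensor(handle.handle, tf.float32)
  doubled = tf.multiply(h_t, 2.0)
  print(sess.run(doubled, feed_dict={h_p: handle.handle}))  # [2. 4. 6.]
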
Example #6
    def testDirectHandleFeedOverlappingWithFetches(self):
        with self.cached_session() as sess:
            a = constant_op.constant(10.0)
            b = constant_op.constant(5.0)
            c = math_ops.multiply(a, b)
            h_c = sess.run(session_ops.get_session_handle(c))
            d = array_ops.identity(c)

            c_val = sess.run(c, feed_dict={c: h_c})
            self.assertAllClose(50.0, c_val)

            d_val = sess.run(d, feed_dict={c: h_c})
            self.assertAllClose(50.0, d_val)

            c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: 60.0})
            self.assertAllClose(50.0, c_val)
            self.assertAllClose(60.0, d_val)

            c_val, d_val = sess.run([c, d], feed_dict={c: 60.0, d: h_c})
            self.assertAllClose(60.0, c_val)
            self.assertAllClose(50.0, d_val)

            c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: h_c})
            self.assertAllClose(50.0, c_val)
            self.assertAllClose(50.0, d_val)
Example #7
  def testHandleForLoop(self):
    with self.test_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      for _ in range(100):
        # This exercises garbage collection.
        h = sess.run(h_x, feed_dict={f: h.handle})

      self.assertEqual(100, h.eval())
Example #8
    def testHandleDelete(self):
        with self.cached_session() as sess:
            # Return a handle.
            a = constant_op.constant(10)
            b = constant_op.constant(5)
            c = math_ops.multiply(a, b)
            h = session_ops.get_session_handle(c)
            sess.run(h).delete()
Example #9
  def testMultiDevices(self):
    with self.cached_session() as sess:
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(1.0)
        a_handle = self.evaluate(session_ops.get_session_handle(a))
      with ops.device("/cpu:0"):
        b = constant_op.constant(2.0)
        b_handle = self.evaluate(session_ops.get_session_handle(b))

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)
      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())
Example #10
  def testHandleDelete(self):
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      sess.run(h).delete()
Example #11
  def tensor_to_itensor(self, tensor):

    op_prefix = "tensor_to_itensor"
    with self.g.as_default():
      handle_op = session_ops.get_session_handle(tensor,
                                                 name=op_prefix+".handle")
    handle = self.run(handle_op)
    return ITensor(self, handle)
Example #12
  def testMultiDevices(self):
    with self.test_session() as sess:
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(1.0)
        a_handle = sess.run(session_ops.get_session_handle(a))
      with ops.device("/cpu:0"):
        b = constant_op.constant(2.0)
        b_handle = sess.run(session_ops.get_session_handle(b))

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)
      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())
Example #13
  def testHandleWhileLoop(self):
    with self.test_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      b = constant_op.constant(100)
      p = math_ops.less(x, b)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      while True:
        rp, h = sess.run([p, h_x], feed_dict={f: h.handle})
        if not rp:
          break

      self.assertEqual(101, h.eval())
Example #14
  def make_function(self, inputs, outputs, name=""):
    """Create callable that accept argument ITensors in the same order as
    inputs argument, and produces tuple of outputs which are ITensors
    corresponding to outputs.

    Example usage:
    x0 = env.tf.ones()       # create ITensor
    x = env.make_input(x0) # create Tensor
    y = env.make_input(x0) # create Tensor
    z1 = tf.add(x, y)         # operate on Tensors
    z2 = tf.sub(x, y)         # operate on Tensors
    f = env.make_function(inputs=[x, y], outputs=[z1, z2])

    print(f(x0, x0*5))       # feed ITensors, get result back as ITensors
    """

    input_holders = []
    for input_ in inputs:
      input_holders.append(self.input_dict[input_])

    output_handle_ops = []
    if is_list_or_tuple(outputs):
      for (i,tensor) in enumerate(outputs):
        op_name = "custom_function_%s.output.%s"%(name, i)
        output_handle_ops.append(session_ops.get_session_handle(tensor,
                                                                op_name))
    # special-case single output
    else:
      op_name = "custom_function_%s.output"%(name)
      output_handle_ops = session_ops.get_session_handle(outputs, op_name)

    def func(*args):
      feed_dict = {}
      for (i, arg) in enumerate(args):
        feed_dict[input_holders[i]] = arg.tf_handle

      tensor_handles = self.sess.run(output_handle_ops, feed_dict=feed_dict)
      if is_list_or_tuple(tensor_handles):
        return [ITensor(self, t) for t in tensor_handles]
      else:
        return ITensor(self, tensor_handles)
      
    return func
Example #15
  def testFeedOneHandleDirectly(self):
    with self.test_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      d = math_ops.multiply(c, c)

      h_c = sess.run(session_ops.get_session_handle(c))

      self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))
Example #16
  def testHandlePlacement(self):
    with self.cached_session() as sess:
      a = constant_op.constant(1.0)
      a_handle_op = session_ops.get_session_handle(a)
      b = constant_op.constant(2.0)
      b_handle_op = session_ops.get_session_handle(b)

      a_handle = self.evaluate(a_handle_op)
      b_handle = self.evaluate(b_handle_op)

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)

      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())
Example #17
    def testHandleWhileLoop(self):
        with self.cached_session() as sess:
            # Initialize a handle.
            a = constant_op.constant(0)
            h = session_ops.get_session_handle(a)
            h = self.evaluate(h)

            # Do some computation.
            f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
            b = constant_op.constant(100)
            p = math_ops.less(x, b)
            # Must define the loop body outside the loop.
            h_x = session_ops.get_session_handle(math_ops.add(x, 1))
            while True:
                rp, h = sess.run([p, h_x], feed_dict={f: h.handle})
                if not rp:
                    break

            self.assertEqual(101, h.eval())
Example #18
  def testHandlePlacement(self):
    with self.test_session() as sess:
      a = constant_op.constant(1.0)
      a_handle_op = session_ops.get_session_handle(a)
      b = constant_op.constant(2.0)
      b_handle_op = session_ops.get_session_handle(b)

      a_handle = sess.run(a_handle_op)
      b_handle = sess.run(b_handle_op)

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)

      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())
Example #19
    def testFeedOneHandleDirectly(self):
        with self.cached_session() as sess:
            a = constant_op.constant(10.0)
            b = constant_op.constant(5.0)
            c = math_ops.multiply(a, b)
            d = math_ops.multiply(c, c)

            h_c = sess.run(session_ops.get_session_handle(c))

            self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))
Example #20
    def testHandleMover(self):
        with self.cached_session() as sess:
            # Return a handle.
            a = constant_op.constant(10)
            b = constant_op.constant(5)
            c = math_ops.multiply(a, b)
            h = session_ops.get_session_handle(c)
            h = self.evaluate(h)

            # Feed a tensor handle.
            f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
            y = math_ops.multiply(x, 10)
            self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

            # Feed another tensor handle.
            with ops.device(test.gpu_device_name()):
                a = constant_op.constant(10)
                h = session_ops.get_session_handle(a)
                h = self.evaluate(h)
                self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))
Example #21
    def testHandleEval(self):
        with self.cached_session() as sess:
            # Return a handle.
            a = constant_op.constant(10)
            b = constant_op.constant(5)
            c = math_ops.multiply(a, b)
            h = session_ops.get_session_handle(c)
            h = self.evaluate(h)

            # Get the tensor from its handle.
            self.assertEqual(50, h.eval())
Example #22
  def testHandleMover(self):
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

      # Feed another tensor handle.
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(10)
        h = session_ops.get_session_handle(a)
        h = sess.run(h)
        self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))
Example #23
  def testHandleEval(self):
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Get the tensor from its handle.
      self.assertEqual(50, h.eval())
Example #24
    def testHandleAndValue(self):
        with self.cached_session() as sess:
            # Return a handle and a value.
            a = constant_op.constant(10)
            b = constant_op.constant(5)
            c = math_ops.multiply(a, b)
            h = session_ops.get_session_handle(c)
            v = math_ops.multiply(a, c)
            h, v = sess.run([h, v])

            self.assertEqual(50, h.eval())
            self.assertEqual(500, v)
Example #25
  def testHandleAndValue(self):
    with self.test_session() as sess:
      # Return a handle and a value.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      v = math_ops.multiply(a, c)
      h, v = sess.run([h, v])

      self.assertEqual(50, h.eval())
      self.assertEqual(500, v)
Example #26
  def testHandleBasic(self):
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))
Example #27
  def testFeedHandleToVariableDirectly(self):
    with self.test_session() as sess:
      a = variables.Variable(12.0)
      inc_a = state_ops.assign_add(a, 2.0)
      b = math_ops.add(a, 5.0)
      sess.run(a.initializer)

      h_a_read = sess.run(session_ops.get_session_handle(a.read_value()))
      self.assertAllClose(12.0, sess.run(a))

      self.assertAllClose(17.0, sess.run(b, feed_dict={a: h_a_read}))
      sess.run(inc_a)
      self.assertAllClose(19.0, sess.run(b, feed_dict={a: h_a_read}))
Example #28
    def testHandleBasic(self):
        with self.cached_session() as sess:
            # Return a handle.
            a = constant_op.constant(10)
            b = constant_op.constant(5)
            c = math_ops.multiply(a, b)
            h = session_ops.get_session_handle(c)
            h = self.evaluate(h)

            # Feed a tensor handle.
            f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
            y = math_ops.multiply(x, 10)
            self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))
Example #29
  def testHandleDeleteRaw(self):
    with self.test_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Delete using a raw tensor handle.
      raw_h = h.get_raw_handle()
      f, x = session_ops.delete_session_tensor(raw_h)
      sess.run(x, feed_dict={f: raw_h})
Example #30
    def testHandleDeleteRaw(self):
        with self.cached_session() as sess:
            # Return a handle.
            a = constant_op.constant(10)
            b = constant_op.constant(5)
            c = math_ops.multiply(a, b)
            h = session_ops.get_session_handle(c)
            h = self.evaluate(h)

            # Delete using a raw tensor handle.
            raw_h = h.get_raw_handle()
            f, x = session_ops.delete_session_tensor(raw_h)
            sess.run(x, feed_dict={f: raw_h})
Example #31
    def testFeedHandleToVariableDirectly(self):
        with self.cached_session() as sess:
            a = variables.Variable(12.0)
            inc_a = state_ops.assign_add(a, 2.0)
            b = math_ops.add(a, 5.0)
            self.evaluate(a.initializer)

            h_a_read = sess.run(session_ops.get_session_handle(a.read_value()))
            self.assertAllClose(12.0, self.evaluate(a))

            self.assertAllClose(17.0, sess.run(b, feed_dict={a: h_a_read}))
            sess.run(inc_a)
            self.assertAllClose(19.0, sess.run(b, feed_dict={a: h_a_read}))
Example #32
  def testHandleGC(self):
    with self.cached_session() as sess:
      # initial values live on CPU
      with ops.device("/cpu:0"):
        one = constant_op.constant(1, dtype=dtypes.float32)
        one_handle = self.evaluate(session_ops.get_session_handle(one))
        x_handle = self.evaluate(session_ops.get_session_handle(one))

      # addition lives on GPU
      with ops.device(test.gpu_device_name()):
        add_h1, add_t1 = session_ops.get_session_tensor(one_handle.handle,
                                                        dtypes.float32)
        add_h2, add_t2 = session_ops.get_session_tensor(x_handle.handle,
                                                        dtypes.float32)
        add_op = math_ops.add(add_t1, add_t2)
        add_output = session_ops.get_session_handle(add_op)

      # add 1 to tensor 20 times
      for _ in range(20):
        x_handle = sess.run(
            add_output,
            feed_dict={add_h1: one_handle.handle,
                       add_h2: x_handle.handle})
Example #33
  def testHandleGC(self):
    with self.test_session() as sess:
      # initial values live on CPU
      with ops.device("/cpu:0"):
        one = constant_op.constant(1, dtype=dtypes.float32)
        one_handle = sess.run(session_ops.get_session_handle(one))
        x_handle = sess.run(session_ops.get_session_handle(one))

      # addition lives on GPU
      with ops.device(test.gpu_device_name()):
        add_h1, add_t1 = session_ops.get_session_tensor(one_handle.handle,
                                                        dtypes.float32)
        add_h2, add_t2 = session_ops.get_session_tensor(x_handle.handle,
                                                        dtypes.float32)
        add_op = math_ops.add(add_t1, add_t2)
        add_output = session_ops.get_session_handle(add_op)

      # add 1 to tensor 20 times
      for _ in range(20):
        x_handle = sess.run(
            add_output,
            feed_dict={add_h1: one_handle.handle,
                       add_h2: x_handle.handle})
Example #34
  def testHandleCond(self):
    with self.test_session() as sess:
      # Return a handle and a value
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      p = math_ops.less(a, b)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      p, h = sess.run([p, h])

      # Run by feeding a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      if p:
        y = math_ops.multiply(x, 10)
      else:
        y = math_ops.multiply(x, 100)
      result = sess.run(y, feed_dict={f: h.handle})

      self.assertEqual(5000, result)
Example #35
    def testHandleCond(self):
        with self.cached_session() as sess:
            # Return a handle and a value
            a = constant_op.constant(10)
            b = constant_op.constant(5)
            p = math_ops.less(a, b)
            c = math_ops.multiply(a, b)
            h = session_ops.get_session_handle(c)
            p, h = sess.run([p, h])

            # Run by feeding a tensor handle.
            f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
            if p:
                y = math_ops.multiply(x, 10)
            else:
                y = math_ops.multiply(x, 100)
            result = sess.run(y, feed_dict={f: h.handle})

            self.assertEqual(5000, result)
Example #36
  def sum1(self, input_itensor):
    """Create a specialized op that sums over 1 dimensional vector.
    This avoids having to create Rank/Range ops that initialize indices
    in the default tf.reduce_sum."""

    op_type_name = "sum1"
    tf_dtype = input_itensor.dtype
    current_device = get_current_device_string(self.g)
    current_device_sanitized = current_device.replace(":", "")
    key = (op_type_name, tf_dtype.name, current_device_sanitized)

    if key in self.op_cache:
      if self.PRINT_CACHE_HITS:
        print("Imperative cache hit for %s"%(str(key)))
      op = self.op_cache[key]
    else:
      if self.PRINT_CACHE_MISSES:
        print("Imperative cache miss for %s"%(str(key)))
      with self.g.as_default():
        op_prefix = op_type_name + "." + tf_dtype.name
        holder, tensor = session_ops.get_session_tensor(
            input_itensor.tf_handle, input_itensor.dtype, name=op_prefix+".0")
        input_holders = {"input": holder}
        reduction_indices = constant_op.constant([0], dtype=dtypes.int32,
                                                 name=op_prefix+".1")
        output = gen_math_ops._sum(input=tensor,
                                   reduction_indices=reduction_indices,
                                   keep_dims=False, name=op_prefix+".op")
        op_prefix = op_prefix+".out"
        output_handle = session_ops.get_session_handle(output,
                                                       op_prefix+".handle")

      op = Op(self, input_holders, output_handle)
      self.cache_add(key, op)

    return op(input=input_itensor)
Example #37
  def apply_op(self, op_type_name, name=None, **keywords):
    """Wrapper for op_def_library apply_op with caching.

    This method aims to be semantically identical to "apply_op" of OpDefLibrary
    but works with ITensor objects instead of Tensor objects.

    Brief overview:

    1. Extract input arguments from keywords and convert Python types into
       corresponding ITensors using the type constraints of the corresponding
       OpDef.
    2. Figure out the OpDef that would've been constructed for this op if the
       original op_def_library were called, by looking at inferred/explicit
       attributes, argument device locations, and the current device context.
    3. Fetch the corresponding Op from the cache if it was already constructed.
    4. Otherwise construct the op and wrap it in an Op object.
    5. Save the Op object in the cache, and run it to produce the ITensor
       result.
    """

    op_def = self._lookup_opdef_for_type(op_type_name)

    # names of input arguments, i.e. "x", "y" for the Add op
    input_names = [arg.name for arg in op_def.input_arg]

    # convert any python inputs in keywords into ITensors
    convert_to_itensors_with_type_inference(op_def, keywords,
                                            self.env.numpy_to_itensor)

    current_device = get_current_device_string(self.env.g)
    key = create_opdef_key(op_def, keywords, current_device)
    op = self.env.cache_lookup(key)

    # Found op in cache, run it and return the results
    if op:
      return op(**keywords)

    # Couldn't find op in graph cache, create it and add to cache
    if self.env.PRINT_CACHE_MISSES:
      print("Imperative cache miss for %s" %(str(key)))


    # Graph construction overview:
    # The new operation must reproduce old operation, except that inputs
    # and outputs must be string tensor handles instead of Tensors
    # 1. Convert input string tensor handles into Tensors
    # 2. Run the op
    # 3. Convert output tensors into string tensor handles

    # prefix to use for node names in graph, like "Add.float32"
    if len(input_names) > 0 and isinstance(keywords[input_names[0]],
                                           ITensor):
      op_prefix = op_type_name + "."+keywords[input_names[0]].dtype.name
    else:
      op_prefix = op_type_name + ".no_dtype"

    # keywords for original apply_op, ITensor entries will be replaced with
    # Tensors
    opdeflib_keywords = dict(keywords)

    # Graph construction 1/3: inputs
    # replace ITensor inputs with tensorhandle->tensor converters
    with self.env.g.as_default():
      input_holders = {}  # placeholders for string tensor handle feeding

      for input_num, input_name in enumerate(sorted(input_names)):
        op_name = op_prefix + "." + str(input_num)
        itensor_input = keywords[input_name]
        # single tensor input
        if isinstance(itensor_input, ITensor):
          holder, tensor = session_ops.get_session_tensor(
              itensor_input.tf_handle, itensor_input.dtype, name=op_name)
          input_holders[input_name] = holder
          opdeflib_keywords[input_name] = tensor

        # list input, such as for tf.concat, add converter for each element
        else:
          assert is_list_or_tuple(itensor_input)
          holder_list = []
          tensor_list = []
          for subinput_num, subinput in enumerate(itensor_input):
            op_name = op_name + "_" + str(subinput_num)
            holder, tensor = session_ops.get_session_tensor(subinput.tf_handle,
                                                            subinput.dtype,
                                                            name=op_name)
            holder_list.append(holder)
            tensor_list.append(tensor)
            opdeflib_keywords[input_name] = tensor_list
          input_holders[input_name] = holder_list

      # Graph construction 2/3: op
      # call original apply_op to create the op
      output = self.original_op_def_library.apply_op(op_type_name,
                                                     name=op_prefix+".op",
                                                     **opdeflib_keywords)


      # Graph construction 3: outputs
      # attach tensor->tensorhandle conversion to outputs
      op_name = op_prefix+".out"

      # single Tensor output
      if isinstance(output, ops_lib.Tensor):
        output_handle = session_ops.get_session_handle(output,
                                                       op_name+".handle")
      # operation output, e.g. as with control_dependencies
      elif isinstance(output, ops_lib.Operation):
        assert False, "Imperative mode only supports ops that produce tensors"

      else:  # list of Tensors, such as for tf.split
        assert is_list_or_tuple(output)
        output_handle = []
        for output_num, output_tensor in enumerate(output):
          op_name = op_name + "_" + str(output_num)
          output_single_handle = session_ops.get_session_handle(output_tensor,
                                                                (op_name+
                                                                 ".handle"))
          output_handle.append(output_single_handle)

    # save our newly created op in cache
    op = Op(self.env, input_holders, output_handle)
    self.env.cache_add(key, op)

    # execute the op
    return op(**keywords)
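
The "Brief overview" in the docstring above amounts to a handle-in/handle-out wrapper around each op, built once and then reused: convert incoming tensor handles to graph tensors with get_session_tensor, apply the op, attach get_session_handle to the output, and cache that subgraph under a key. A stripped-down, hypothetical illustration for a single binary add (the cached_add helper and its cache key are invented here for illustration; the real apply_op additionally keys on attributes and devices and handles list inputs/outputs; assumes the TF1 compat API):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

_op_cache = {}  # key -> (placeholder_a, placeholder_b, output handle op)

def cached_add(sess, handle_a, handle_b, dtype=tf.float32):
  # Hypothetical helper: add two persisted tensors given their TensorHandles
  # and return a TensorHandle for the result. The subgraph (handle -> tensor,
  # add, tensor -> handle) is built once per dtype and reused afterwards.
  key = ("Add", tf.as_dtype(dtype).name)
  if key not in _op_cache:
    p_a, t_a = tf.get_session_tensor(handle_a.handle, dtype)
    p_b, t_b = tf.get_session_tensor(handle_b.handle, dtype)
    out_handle_op = tf.get_session_handle(tf.add(t_a, t_b))
    _op_cache[key] = (p_a, p_b, out_handle_op)
  p_a, p_b, out_handle_op = _op_cache[key]
  return sess.run(out_handle_op,
                  feed_dict={p_a: handle_a.handle, p_b: handle_b.handle})

with tf.Session() as sess:
  h = sess.run(tf.get_session_handle(tf.constant(1.0)))
  for _ in range(3):
    h = cached_add(sess, h, h)  # second and later calls hit the cache
  print(h.eval())  # 8.0
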
Example #38
    def cont(self,
             target,
             use_tensor_handles=True,
             use_overrides=True,
             restore_variable_values=False):
        """Continue till the completion of the specified target tensor.

        Args:
          target: A single fetched Tensor or Op, or a name (str) representing the
            Tensor or Op. In the case of a name str, the graph will be searched
            to find the corresponding Tensor or Op.
            # TODO(cais): Support multiple fetches as in Session.run() interface.
          use_tensor_handles: (bool) Whether this cont() run will use cached tensor
            handles to avoid recomputation. Default: True.
          use_overrides: (bool) Whether the overriding tensor values supplied by
            the client are to be used in this cont() call. Default: True.
          restore_variable_values: (bool) Whether the old values of the variables
            (before any cont() calls in this object) are to be restored.

        Returns:
          Value from Session.run() of the target.

        Raises:
          ValueError: If the target is specified as a string and the string does
            not correspond to any tensors in the Session graph.
            Or if the target of this cont() is not in the input list of the Stepper
            object's target.
            Or if target is a Placeholder.
        """

        self._last_feed_types = {}

        # The feeds to be used in the Session.run() call.
        feeds = {}

        if isinstance(target, six.string_types):
            # Fetch target is a string. Assume it is the name of the Tensor or Op and
            # will attempt to find it in the Session's graph.
            target_name = target
        else:
            target_name = target.name

        graph_element = self._sess.graph.as_graph_element(target_name)
        # Any additional tensor handles to obtain in this cont() action.
        additional_handle_requests = []

        if (isinstance(graph_element, ops.Tensor)
                and graph_element.op.type == "Placeholder"):
            self._last_feed_types[graph_element.name] = self.FEED_TYPE_CLIENT
            return self._client_feed_dict[graph_element.name]
        elif (isinstance(graph_element, ops.Operation)
              and graph_element.type == "Placeholder"):
            tensor_name = graph_element.name + ":0"
            self._last_feed_types[tensor_name] = self.FEED_TYPE_CLIENT
            return self._client_feed_dict[tensor_name]

        if isinstance(graph_element, ops.Operation) and graph_element.outputs:
            # Check if this op has any output tensors that also fall into this
            # stepper's transitive closure.
            node_outputs = [
                output.name for output in graph_element.outputs
                if output.name in self._closure_elements
            ]
            if node_outputs:
                # The target is an op with at least one output within the transitive
                # closure. The cont() action will amount to using the 0-th
                # output Tensor as the target, as well as obtaining handles to it
                # and to the rest of the outputs tensors in the transitive closure
                # (if any).
                target_name = node_outputs[0]
                additional_handle_requests = node_outputs[1:]

        # Verify that the target is in the transitive closure of the stepper's
        # fetch.
        target_node_name = self._get_node_name(target_name)
        if target_node_name not in self._transitive_closure_set:
            raise ValueError(
                "Target \"%s\" is not in the transitive closure for the fetch of the "
                "stepper: \"%s\"." % (target_name, repr(self._fetch_names)))

        # Check if a cached tensor handle can be used on the fetch directly.
        if use_tensor_handles and target_name in self._tensor_handles:
            self._last_feed_types[target_name] = self.FEED_TYPE_HANDLE
            return self._tensor_handles[target_name].eval()

        # Check if an overriding tensor value can be used directly.
        if use_overrides and target_name in self._override_tensors:
            # Override is available. Return the value right away.
            self._last_feed_types[target_name] = self.FEED_TYPE_OVERRIDE
            return self._override_tensors[target_name]

        # Keep track of which variables are restored in this cont() call.
        restored_variables = set()

        # Keep track of which variables are "touched" (i.e., possibly updated) in
        # this cont() call.
        touched_variables = set()

        # =========================================================================
        # Use a non-recursive method to trace the inputs from the node and set up
        # the feeds.
        fetched = self._sess.graph.as_graph_element(target_name)
        elem_stack = [fetched]
        done = set()

        while elem_stack:
            curr_elem = elem_stack.pop()
            curr_node = self._get_node(curr_elem)

            done.add(curr_node.name)

            non_control_inputs = [inp for inp in curr_node.inputs]
            control_inputs = [inp for inp in curr_node.control_inputs]
            all_inputs = set(non_control_inputs + control_inputs)

            # Iterate through the (non-control) inputs.
            for inp in all_inputs:
                # Determine whether the input is feedable. Reference-type tensors,
                # e.g., Variables, should not be fed, because they can change.
                if isinstance(inp, ops.Tensor):
                    is_inp_ref = inp.dtype._is_ref_dtype  # pylint: disable=protected-access
                    can_feed = self._sess.graph.is_feedable(
                        inp) and not is_inp_ref
                else:
                    is_inp_ref = False
                    can_feed = False

                if (restore_variable_values
                        and inp.name in self._dirty_variables
                        and inp.name not in restored_variables
                        and inp.name not in touched_variables):
                    # Do not restore Variables touched or restored previously in this
                    # cont() call.
                    initializer_op = self._variable_initializers[inp.name]
                    initial_value_tensor = self._variable_initial_values[
                        inp.name]
                    self._sess.run(initializer_op,
                                   feed_dict={
                                       initial_value_tensor:
                                       self._cached_variable_values[inp.name]
                                   })

                    # Mark the variable as restored.
                    restored_variables.add(inp.name)

                # Determine if this is a reference-type input from a variable, and
                # the recipient node is not Identity. In that case, the Variable
                # needs to be marked as dirty and its current value recorded, due to
                # the fact that the receiving op may mutate the value of the Variable.
                if (is_inp_ref and inp.op.type in ["Variable", "VariableV2"]
                        and curr_node.type != "Identity"):
                    # Mark the variable as dirty.
                    touched_variables.add(inp.name)

                    # Obtain the old value of the variable and cache it.
                    if inp.name not in self._cached_variable_values:
                        old_value = self._sess.run(inp)
                        self._cached_variable_values[inp.name] = old_value

                # N.B.: The order of the logical branches matters. For example,
                # _client_feed_dict comes after _tensor_handles, so that tensor
                # handles stored in cont() calls can override the original client
                # feeds. Also for example, _override_tensors comes first, so
                # that a manual override, if one exists, always takes effect.
                if use_overrides and can_feed and inp.name in self._override_tensors:
                    # Use client-supplied overriding tensor value.
                    feeds[inp] = self._override_tensors[inp.name]
                    self._last_feed_types[inp.name] = self.FEED_TYPE_OVERRIDE
                elif (use_tensor_handles and can_feed
                      and inp.name in self._tensor_handles
                      and inp not in feeds):
                    # Tensor handle found in cache.
                    feeds[inp] = self._tensor_handles[inp.name].eval()
                    self._last_feed_types[inp.name] = self.FEED_TYPE_HANDLE
                elif inp.name in self._client_feed_dict:
                    # This input is available in the client feed_dict.
                    feeds[inp] = self._client_feed_dict[inp.name]
                    self._last_feed_types[inp.name] = self.FEED_TYPE_CLIENT
                else:
                    # There is no feed available for this input. So keep tracing its
                    # input(s).
                    inp_node = self._get_node(inp)
                    if inp_node.name in done:
                        # Already visited.
                        continue

                    elem_stack.append(inp)
                    done.add(inp_node.name)

        # =========================================================================

        if touched_variables:
            self._dirty_variables.update(touched_variables)

        for variable in restored_variables:
            self._dirty_variables.remove(variable)

        # Prepare RunOptions for DebugTensorWatches
        run_options = config_pb2.RunOptions()
        # TODO(cais): Add fields for watching intermediate tensors.

        if isinstance(fetched, ops.Operation):
            # The fetched is an Operation: Will not get tensor handle.
            self._sess.run(fetched, feed_dict=feeds, options=run_options)
            # No return value for a run of an Operation
        else:
            # This is a Tensor: Will get tensor handle and cache it.
            # Will also get the additional requested tensor handles (if any).
            tensors_to_get_handles_for = [fetched]
            handle_names = [target_name]

            tensors_to_get_handles_for.extend([
                self._sess.graph.as_graph_element(h)
                for h in additional_handle_requests
            ])
            handle_names.extend(additional_handle_requests)

            for handle_name, tensor in zip(handle_names,
                                           tensors_to_get_handles_for):
                handle = self._sess.run(session_ops.get_session_handle(tensor),
                                        feed_dict=feeds,
                                        options=run_options)
                self._tensor_handles[handle_name] = handle

            return self._tensor_handles[target_name].eval()

        # Invalidate caches at the end.
        for touched_variable in touched_variables:
            self._invalidate_transitively_outgoing_cache(touched_variable)
Example #39
  def cont(self,
           target,
           use_tensor_handles=True,
           use_overrides=True,
           restore_variable_values=False):
    """Continue till the completion of the specified target tensor.

    Args:
      target: A single fetched Tensor or Op, or a name (str) representing the
        Tensor or Op. In the case of a name str, the graph will be searched
        to find the corresponding Tensor or Op.
        # TODO(cais): Support multiple fetches as in Session.run() interface.
      use_tensor_handles: (bool) Whether this cont() run will use cached tensor
        handles to avoid recomputation. Default: True.
      use_overrides: (bool) Whether the overriding tensor values supplied by
        the client are to be used in this cont() call. Default: True.
      restore_variable_values: (bool) Whether the old values of the variables
        (before any cont() calls in this object) are to be restored.

    Returns:
      Value from Session.run() of the target.

    Raises:
      ValueError: If the target is specified as a string and the string does
        not correspond to any tensors in the Session graph.
        Or if the target of this cont() is not in the input list of the Stepper
        object's target.
        Or if target is a Placeholder.
    """

    self._last_feed_types = {}

    # The feeds to be used in the Session.run() call.
    feeds = {}

    if isinstance(target, six.string_types):
      # Fetch target is a string. Assume it is the name of the Tensor or Op and
      # will attempt to find it in the Session's graph.
      target_name = target
    else:
      target_name = target.name

    graph_element = self._sess.graph.as_graph_element(target_name)
    # Any additional tensor handles to obtain in this cont() action.
    additional_handle_requests = []

    if (isinstance(graph_element, ops.Tensor) and
        graph_element.op.type == "Placeholder"):
      self._last_feed_types[graph_element.name] = self.FEED_TYPE_CLIENT
      return self._client_feed_dict[graph_element.name]
    elif (isinstance(graph_element, ops.Operation) and
          graph_element.type == "Placeholder"):
      tensor_name = graph_element.name + ":0"
      self._last_feed_types[tensor_name] = self.FEED_TYPE_CLIENT
      return self._client_feed_dict[tensor_name]

    if isinstance(graph_element, ops.Operation) and graph_element.outputs:
      # Check if this op has any output tensors that also fall into this
      # stepper's transitive closure.
      node_outputs = [
          output.name for output in graph_element.outputs
          if output.name in self._closure_elements
      ]
      if node_outputs:
        # The target is an op with at least one output within the transitive
        # closure. The cont() action will amount to using the 0-th
        # output Tensor as the target, as well as obtaining handles to it
        # and to the rest of the outputs tensors in the transitive closure
        # (if any).
        target_name = node_outputs[0]
        additional_handle_requests = node_outputs[1:]

    # Verify that the target is in the transitive closure of the stepper's
    # fetch.
    target_node_name = self._get_node_name(target_name)
    if target_node_name not in self._transitive_closure_set:
      raise ValueError(
          "Target \"%s\" is not in the transitive closure for the fetch of the "
          "stepper: \"%s\"." % (target_name, repr(self._fetch_names)))

    # Check if a cached tensor handle can be used on the fetch directly.
    if use_tensor_handles and target_name in self._tensor_handles:
      self._last_feed_types[target_name] = self.FEED_TYPE_HANDLE
      return self._tensor_handles[target_name].eval()

    # Check if an overriding tensor value can be used directly.
    if use_overrides and target_name in self._override_tensors:
      # Override is available. Return the value right away.
      self._last_feed_types[target_name] = self.FEED_TYPE_OVERRIDE
      return self._override_tensors[target_name]

    # Keep track of which variables are restored in this cont() call.
    restored_variables = set()

    # Keep track of which variables are "touched" (i.e., possibly updated) in
    # this cont() call.
    touched_variables = set()

    # =========================================================================
    # Use a non-recursive method to trace the inputs from the node and set up
    # the feeds.
    fetched = self._sess.graph.as_graph_element(target_name)
    elem_stack = [fetched]
    done = set()

    while elem_stack:
      curr_elem = elem_stack.pop()
      curr_node = self._get_node(curr_elem)

      done.add(curr_node.name)

      non_control_inputs = [inp for inp in curr_node.inputs]
      control_inputs = [inp for inp in curr_node.control_inputs]
      all_inputs = set(non_control_inputs + control_inputs)

      # Iterate through the (non-control) inputs.
      for inp in all_inputs:
        # Determine whether the input is feedable. Reference-type tensors,
        # e.g., Variables, should not be fed, because they can change.
        if isinstance(inp, ops.Tensor):
          is_inp_ref = inp.dtype._is_ref_dtype   # pylint: disable=protected-access
          can_feed = self._sess.graph.is_feedable(inp) and not is_inp_ref
        else:
          is_inp_ref = False
          can_feed = False

        if (restore_variable_values and inp.name in self._dirty_variables and
            inp.name not in restored_variables and
            inp.name not in touched_variables):
          # Do not restore Variables touched or restored previously in this
          # cont() call.
          initializer_op = self._variable_initializers[inp.name]
          initial_value_tensor = self._variable_initial_values[inp.name]
          self._sess.run(initializer_op,
                         feed_dict={
                             initial_value_tensor:
                                 self._cached_variable_values[inp.name]
                         })

          # Mark the variable as restored.
          restored_variables.add(inp.name)

        # Determine if this is a reference-type input from a variable, and
        # the recipient node is not Identity. In that case, the Variable
        # needs to be marked as dirty and its current value recorded, due to
        # the fact that the receiving op may mutate the value of the Variable.
        if (is_inp_ref and inp.op.type in ["Variable", "VariableV2"] and
            curr_node.type != "Identity"):
          # Mark the variable as dirty.
          touched_variables.add(inp.name)

          # Obtain the old value of the variable and cache it.
          if inp.name not in self._cached_variable_values:
            old_value = self._sess.run(inp)
            self._cached_variable_values[inp.name] = old_value

        # N.B.: The order of the logical branches matters. For example,
        # _client_feed_dict comes after _tensor_handles, so that tensor
        # handles stored in cont() calls can override the original client
        # feeds. Also for example, _override_tensors comes first, so that
        # a manual override, if one exists, always takes effect.
        if use_overrides and can_feed and inp.name in self._override_tensors:
          # Use client-supplied overriding tensor value.
          feeds[inp] = self._override_tensors[inp.name]
          self._last_feed_types[inp.name] = self.FEED_TYPE_OVERRIDE
        elif (use_tensor_handles and can_feed and
              inp.name in self._tensor_handles and inp not in feeds):
          # Tensor handle found in cache.
          feeds[inp] = self._tensor_handles[inp.name].eval()
          self._last_feed_types[inp.name] = self.FEED_TYPE_HANDLE
        elif inp.name in self._client_feed_dict:
          # This input is available in the client feed_dict.
          feeds[inp] = self._client_feed_dict[inp.name]
          self._last_feed_types[inp.name] = self.FEED_TYPE_CLIENT
        else:
          # There is no feed available for this input. So keep tracing its
          # input(s).
          inp_node = self._get_node(inp)
          if inp_node.name in done:
            # Already visited.
            continue

          elem_stack.append(inp)
          done.add(inp_node.name)

    # =========================================================================

    if touched_variables:
      self._dirty_variables.update(touched_variables)

    for variable in restored_variables:
      self._dirty_variables.remove(variable)

    # Prepare RunOptions for DebugTensorWatches
    run_options = config_pb2.RunOptions()
    # TODO(cais): Add fields for watching intermediate tensors.

    if isinstance(fetched, ops.Operation):
      # The fetched is an Operation: Will not get tensor handle.
      self._sess.run(fetched, feed_dict=feeds, options=run_options)
      # No return value for a run of an Operation
    else:
      # This is a Tensor: Will get tensor handle and cache it.
      # Will also get the additional requested tensor handles (if any).
      tensors_to_get_handles_for = [fetched]
      handle_names = [target_name]

      tensors_to_get_handles_for.extend([
          self._sess.graph.as_graph_element(h)
          for h in additional_handle_requests
      ])
      handle_names.extend(additional_handle_requests)

      for handle_name, tensor in zip(handle_names, tensors_to_get_handles_for):
        handle = self._sess.run(session_ops.get_session_handle(tensor),
                                feed_dict=feeds,
                                options=run_options)
        self._tensor_handles[handle_name] = handle

      return self._tensor_handles[target_name].eval()

    # Invalidate caches at the end.
    for touched_variable in touched_variables:
      self._invalidate_transitively_outgoing_cache(touched_variable)