def testGradSerialTwoLoops(self):
    with self.test_session(use_gpu=True):
      num_steps = 100
      acc = tensor_array_ops.TensorArray(
      i = constant_op.constant(0, name="i")
      x = constant_op.constant(2.0, name="x")

      c = lambda i, acc: i < 5

      def b(i, acc):
        x1 = control_flow_ops.cond(
            math_ops.equal(i, 0), lambda: x,
            lambda: math_ops.multiply( - 1), 2.0))
        return i + 1, acc.write(i, x1)

      i1, acc1 = control_flow_ops.while_loop(c, b, [i, acc])

      z = constant_op.constant(0.0)

      def fn(i, acc):
        return i + 1, acc.write(i, z)

      _, acc2 = control_flow_ops.while_loop(lambda i, acc: i < num_steps, fn,
                                            [i1, acc1])

      r = acc2.stack()
      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllClose(31.0, grad.eval())
  def _testStackWhileSwap(self, use_gpu):
    """Fills a stack in one while loop and drains it in a second one.

    The first loop runs ten times and pushes a 2000-element vector of ones
    onto the stack on every pass, with `swap_memory=True`. The second loop
    counts back down from the first loop's result, popping one entry per pass
    and adding it to a running total that starts at zero. Every element of the
    total must therefore equal 10.

    Args:
      use_gpu: Forwarded to `self.test_session`.
    """
    with self.test_session(use_gpu=use_gpu):
      counter = constant_op.constant(0)
      stack_handle = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")

      def push_cond(step):
        return math_ops.less(step, 10)

      def push_body(step):
        # The push is sequenced after `step`, and the increment after the push.
        with ops.control_dependencies([step]):
          ones = constant_op.constant(np.ones(2000), dtype=dtypes.float32)
          pushed = gen_data_flow_ops._stack_push(
              stack_handle, ones, swap_memory=True)
        with ops.control_dependencies([pushed]):
          return math_ops.add(step, 1)

      num_pushed = control_flow_ops.while_loop(push_cond, push_body, [counter])

      zeros = constant_op.constant(np.zeros(2000), dtype=dtypes.float32)

      def pop_cond(remaining, unused_total):
        return math_ops.greater(remaining, 0)

      def pop_body(remaining, total):
        next_remaining = math_ops.subtract(remaining, 1)
        next_total = total + gen_data_flow_ops._stack_pop(
            stack_handle, dtypes.float32)
        return [next_remaining, next_total]

      # The running total is given an unknown shape invariant.
      shape_invariants = [num_pushed.get_shape(), tensor_shape.unknown_shape()]
      _, summed = control_flow_ops.while_loop(
          pop_cond, pop_body, [num_pushed, zeros], shape_invariants)
      self.assertAllClose(np.ones(2000) * 10.0, summed.eval())
 def testWhileContext(self):
     """Checks that each control-flow context round-trips through its proto.

     Builds a simple counting while loop, then for every op in the session
     graph that has a control-flow context, serializes the context with
     `to_proto()`, rebuilds a `WhileContext` from that proto and serializes it
     again. The two protos must compare equal.
     """
     with self.test_session() as sess:
         i = constant_op.constant(0)
         c = lambda i: math_ops.less(i, 10)
         b = lambda i: math_ops.add(i, 1)
         control_flow_ops.while_loop(c, b, [i])
         for op in sess.graph.get_operations():
             # Use a dedicated name so the loop condition `c` is not rebound.
             context = op._get_control_flow_context()
             if context:
                 context_proto = context.to_proto()
                 round_tripped = control_flow_ops.WhileContext.from_proto(
                     context_proto).to_proto()
                 # ProtoEq only reports the comparison result; without an
                 # assertion a mismatch would go unnoticed.
                 self.assertTrue(compare.ProtoEq(context_proto, round_tripped))
  def testControlFlowInitialization(self):
    """A variable initializer built inside a while loop must be rejected.

    The loop body creates a `Variable` whose initial value is produced inside
    the loop; building the loop is expected to raise a `ValueError` whose
    message mentions "inside a control-flow".
    """

    def keep_looping(step, _):
      return step < 10

    def create_variable_in_loop(step, _):
      initial = array_ops.zeros([], dtype=dtypes.int32)
      var = variables.Variable(initial_value=initial)
      return (step + 1, var.read_value())

    loop_vars = [0, 0]
    with self.assertRaisesRegexp(ValueError, "inside a control-flow"):
      control_flow_ops.while_loop(
          keep_looping, create_variable_in_loop, loop_vars)
  def testGradientInsideLoop(self):
    """A GradientTape opened inside a while-loop body must yield a gradient.

    The loop condition is constantly False; the assertion lives in the body
    function that is handed to `while_loop`.
    """
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(1.0)

      def never_continue(unused_i):
        return False

      def loop_body(_):
        # Touch the variable first so that it is read inside the loop context.
        _ = var + 1.0
        with backprop.GradientTape() as tape:
          doubled = var * 2
        grad = tape.gradient(doubled, var)
        self.assertTrue(grad is not None)
        return 1.0

      control_flow_ops.while_loop(never_continue, loop_body, [1.0])
 def _testWhileContextHelper(self, maximum_iterations=None):
   """Builds a counting while loop and checks its contexts round-trip.

   NOTE(review): the `while_loop(` opener and the body of the final `if` were
   missing from this block. They were reconstructed from the dangling
   argument line and from the sibling `testWhileContext` test; confirm
   against the upstream file.

   Args:
     maximum_iterations: Forwarded to `control_flow_ops.while_loop`.
   """
   with self.test_session() as sess:
     i = constant_op.constant(0)
     c = lambda i: math_ops.less(i, 10)
     b = lambda i: math_ops.add(i, 1)
     control_flow_ops.while_loop(
         c, b, [i], maximum_iterations=maximum_iterations)
     for op in sess.graph.get_operations():
       context = op._get_control_flow_context()
       if context:
         # Serializing, rebuilding and serializing again must be lossless.
         self.assertProtoEquals(
             context.to_proto(),
             control_flow_ops.WhileContext.from_proto(
                 context.to_proto()).to_proto())
  def _get_enqueue_op_per_host(self, host_id, multi_worker_iterator,
                               input_shapes, iterations):
    """Create an enqueue op for a single host identified using host_id.

    The while_loop op returned will run `iterations` times and in each run
    enqueue batches for each shard.

      host_id: integer, id of the host to run the enqueue ops on.
      multi_worker_iterator: MultiWorkerDataIterator to read the input data.
      input_shapes: shape of inputs to be enqueue on the queue. This is same as
        the value of `nest.flatten(iterator.output_shapes)`.
      iterations: integer, number of iterations to be run; determines the
        number of batches to be enqueued.

      while_loop_op running `iterations` times; in each run we enqueue a batch
      on the infeed queue from the host with id `host_id` for each device shard.
    host = self.get_host_cpu_device(host_id)
    # TODO(sourabhbajaj): Possibly make changes to MultiWorkerDataset
    # to work with TPU Prefetch so clean up this code.
    iterator = (
        multi_worker_iterator.get_iterator(self.get_host(host_id))._iterator)  # pylint: disable=protected-access

    def _infeed_enqueue_ops_fn():
      """Enqueue ops for one iteration."""
      control_deps = []
      sharded_inputs = []
      enqueue_ops = []

      with ops.device(host):
        for _ in range(self.num_replicas_per_host):
          # Use control dependencies to ensure a deterministic ordering.
          with ops.control_dependencies(control_deps):
            inputs = nest.flatten(iterator.get_next())

      for core_id, shard_input in enumerate(sharded_inputs):
      return enqueue_ops

    def enqueue_ops_loop_body(i):
      """Callable for the loop body of the while_loop instantiated below."""
      with ops.control_dependencies(_infeed_enqueue_ops_fn()):
        return i + 1

    with ops.device(host):
      enqueue_op_per_host = control_flow_ops.while_loop(
          lambda i: i < iterations,

    return enqueue_op_per_host
def _maximal_eigenvector_power_method(matrix,
  """Returns the maximal right-eigenvector of `matrix` using the power method.

    matrix: 2D Tensor, the matrix of which we will find the maximal
    epsilon: nonnegative float, if two iterations of the power method differ (in
      L2 norm) by no more than epsilon, we will terminate.
    maximum_iterations: nonnegative int, if we perform this many iterations, we
      will terminate.

    The maximal right-eigenvector of `matrix`.

    ValueError: If the `matrix` tensor is not floating-point, or if the
      `epsilon` or `maximum_iterations` parameters violate their bounds.
  if not matrix.dtype.is_floating:
    raise ValueError("multipliers must have a floating-point dtype")
  if epsilon <= 0.0:
    raise ValueError("epsilon must be strictly positive")
  if maximum_iterations <= 0:
    raise ValueError("maximum_iterations must be strictly positive")

  def while_loop_condition(iteration, eigenvector, old_eigenvector):
    """Returns false if the while loop should terminate."""
    not_done = (iteration < maximum_iterations)
    not_converged = (standard_ops.norm(eigenvector - old_eigenvector) > epsilon)
    return standard_ops.logical_and(not_done, not_converged)

  def while_loop_body(iteration, eigenvector, old_eigenvector):
    """Performs one iteration of the power method."""
    del old_eigenvector  # Needed by the condition, but not the body.
    iteration += 1
    # We need to use tf.matmul() and tf.expand_dims(), instead of
    # tf.tensordot(), since the former will infer the shape of the result, while
    # the latter will not (tf.while_loop() needs the shapes).
    new_eigenvector = standard_ops.matmul(
        matrix, standard_ops.expand_dims(eigenvector, 1))[:, 0]
    new_eigenvector /= standard_ops.norm(new_eigenvector)
    return (iteration, new_eigenvector, eigenvector)

  iteration = standard_ops.constant(0)
  eigenvector = standard_ops.ones_like(matrix[:, 0])
  eigenvector /= standard_ops.norm(eigenvector)

  # We actually want a do-while loop, so we explicitly call while_loop_body()
  # once before tf.while_loop().
  iteration, eigenvector, old_eigenvector = while_loop_body(
      iteration, eigenvector, eigenvector)
  iteration, eigenvector, old_eigenvector = control_flow_ops.while_loop(
      loop_vars=(iteration, eigenvector, old_eigenvector),

  return eigenvector
 def create_while_loop():
   r = control_flow_ops.while_loop(
       lambda *_: True,
       outer_body, (0, 1.0),
   return array_ops.identity(r[1])
def gcd(a, b, name=None):
  """Returns the greatest common divisor via Euclid's algorithm.

  Args:
    a: The dividend. A scalar integer `Tensor`.
    b: The divisor. A scalar integer `Tensor`.
    name: An optional name for the operation.

  Returns:
    A scalar `Tensor` representing the greatest common divisor between `a` and
    `b`.

  Raises:
    ValueError: If `a` or `b` are not scalar integers.
  """
  with ops.name_scope(name, 'gcd', [a, b]):
    a = ops.convert_to_tensor(a)
    b = ops.convert_to_tensor(b)

    # Enforce the documented scalar contract; a statically known non-zero
    # rank raises ValueError here.
    a.shape.assert_has_rank(0)
    b.shape.assert_has_rank(0)

    if not a.dtype.is_integer:
      raise ValueError('a must be an integer type. Got: %s' % a.dtype)
    if not b.dtype.is_integer:
      raise ValueError('b must be an integer type. Got: %s' % b.dtype)

    # Euclid's algorithm: replace (a, b) with (b, a mod b) while b > 0.
    cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b))
    body = lambda a, b: [b, math_ops.mod(a, b)]
    a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False)
    return a
  def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Checks gradients through a while loop gathering from a placeholder.

    For float32 and float64, a loop copies each element of a placeholder that
    has no static shape into a dynamically sized TensorArray. The stacked
    result is summed and differentiated with respect to the placeholder.
    Feeding [1, 3, 2] must give a sum of 6 and a gradient of all ones.
    """
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        inputs = array_ops.placeholder(dtype=dtype)
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, dynamic_size=True, size=1)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(Cond, Body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        # Convert so the gradient can be fetched and compared as a tensor.
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)
def _repeat_range(counts, name=None):
  """Repeat integers given by range(len(counts)) each the given number of times.

  Example behavior:
  [0, 1, 2, 3] -> [1, 2, 2, 3, 3, 3]

    counts: 1D tensor with dtype=int32.
    name: optional name for operation.

    1D tensor with dtype=int32 and dynamic length giving the repeated integers.
  with ops.name_scope(name, 'repeat_range', [counts]) as scope:
    counts = ops.convert_to_tensor(counts, name='counts')

    def cond(unused_output, i):
      return i < size

    def body(output, i):
      value = array_ops.fill(counts[i:i+1], i)
      return (output.write(i, value), i + 1)

    size = array_ops.shape(counts)[0]
    init_output_array = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=size, infer_shape=False)
    output_array, num_writes = control_flow_ops.while_loop(
        cond, body, loop_vars=[init_output_array, 0])

    return control_flow_ops.cond(
        num_writes > 0,
        lambda: array_ops.zeros(shape=[0], dtype=dtypes.int32),
 def _timeit(iterations, _):
   (_, final) = control_flow_ops.while_loop(
       lambda t, _: t < iterations,
       body, (t0, v0),
   return [final]
  def createAndRunGraphWithWhileLoop(self):
    """Create and run a TensorFlow Graph with a while loop to generate dumps."""

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      self.traceback_first_line = line_number_above()

      loop_cond = lambda i: math_ops.less(i, 16)

      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata(), options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
def _do_maximum_mean(samples, envelope, high, name=None):
  """Common code between maximum_mean and minimum_mean."""
  with ops.name_scope(name, "do_maximum_mean", [samples, envelope, high]):
    n = array_ops.rank(samples)
    # Move the batch dimension of `samples` to the rightmost position,
    # where the _batch_sort_vector function wants it.
    perm = array_ops.concat([math_ops.range(1, n), [0]], axis=0)
    samples = array_ops.transpose(samples, perm)

    samples = _batch_sort_vector(samples)
    batch_shape = array_ops.shape(samples)[:-1]
    n = array_ops.shape(samples)[-1]
    step = 1. / math_ops.cast(n, dtype=samples.dtype.base_dtype)

    def _loop_body(iter_, total, to_skip):
      total = array_ops.where(
          step <= to_skip,
              to_skip > 0.,
              total + (step - to_skip) * samples[..., iter_],
              total + step * samples[..., iter_]))
      to_skip = array_ops.where(step <= to_skip, to_skip - step, 0.)
      return [iter_ + 1, total, to_skip]

    _, total, _ = control_flow_ops.while_loop(
        cond=lambda iter_, *args: iter_ < n,
            array_ops.zeros(batch_shape, dtype=samples.dtype.base_dtype),
            envelope,  # to_skip

  return total + envelope * high
 def _forward(self, x):
   """Computes the forward pass by looping over the last dimension of `x`.

   Starting from zeros, the loop runs once per element of the last dimension
   of `x`. Each pass feeds the previous result to
   `self._shift_and_log_scale_fn` and recomputes
   `y = x * exp(log_scale) + shift`, skipping whichever of the two terms is
   `None`.

   Args:
     x: Input `Tensor`; the size of its last dimension sets the number of
       loop iterations.

   Returns:
     The `Tensor` produced by the final loop iteration.
   """
   event_size = array_ops.shape(x)[-1]
   y0 = array_ops.zeros_like(x, name="y0")
   # call the template once to ensure creation
   _ = self._shift_and_log_scale_fn(y0)
   def _loop_body(index, y0):
     """While-loop body for autoregression calculation."""
     # Set caching device to avoid re-getting the tf.Variable for every while
     # loop iteration.
     with variable_scope_lib.variable_scope(
         variable_scope_lib.get_variable_scope()) as vs:
       if vs.caching_device is None:
         vs.set_caching_device(lambda op: op.device)
       shift, log_scale = self._shift_and_log_scale_fn(y0)
     y = x
     if log_scale is not None:
       y *= math_ops.exp(log_scale)
     if shift is not None:
       y += shift
     return index + 1, y
   # `body` is a required argument of while_loop; without it `_loop_body`
   # above would never be used.
   _, y = control_flow_ops.while_loop(
       cond=lambda index, _: index < event_size,
       body=_loop_body,
       loop_vars=[0, y0])
   return y
  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Checks gradients through a while loop gathering from a shaped input.

    For float32 and float64, a loop of `num_steps` iterations copies each
    element of a placeholder of static shape `[num_steps]` into a TensorArray.
    The stacked result is summed and differentiated with respect to the
    placeholder. The fed values sum to 20 and the gradient must be all ones.
    """
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        num_steps = 9

        inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, size=num_steps)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(cond, body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        # Convert so the gradient can be fetched and compared as a tensor.
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        self.assertEqual(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)
def _loop_vars_intertwined(x0, y0, functor_x, functor_y):
  """Runs a while loop whose variables swap positions on every iteration.

  The loop state is `(i, j, x, y)` with both counters starting at 0. Each
  iteration yields `(j + 1, i + 1, functor_y(y), functor_x(x))`, so the two
  counters and the two values trade places. Looping continues while the
  second counter is below 4.

  Args:
    x0: Initial value of the third loop variable.
    y0: Initial value of the fourth loop variable.
    functor_x: Callable applied to the third loop variable.
    functor_y: Callable applied to the fourth loop variable.

  Returns:
    The final loop variables as returned by `control_flow_ops.while_loop`.
  """

  def keep_going(unused_i, j, unused_x, unused_y):
    return j < 4

  def swap_and_step(i, j, x, y):
    return (j + 1, i + 1, functor_y(y), functor_x(x))

  initial_state = (constant_op.constant(0), constant_op.constant(0), x0, y0)
  return control_flow_ops.while_loop(keep_going, swap_and_step, initial_state)
  def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],

      def Cond(it, _):
        return it < 5

      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = control_flow_ops.cond(
            math_ops.equal(it, 3), lambda: math_ops.square(cost),
            lambda: cost + math_ops.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          Cond, Body, [constant_op.constant(0), constant_op.constant(0.0)])

      dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = math_ops.segment_sum(dynamic_grads.values,

      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = math_ops.square(
          math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
          math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
      static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
      static_grads = math_ops.segment_sum(static_grads.values,

      with self.test_session() as sess:
        self.assertAllEqual(*[static_grads, dynamic_grads]))
  def testIndexedSlicesGradientInCondInWhileLoop(self):
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],

      def Cond(it, _):
        return it < 5
      def Body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = tf.cond(tf.equal(it, 3),
                       lambda: tf.square(cost),
                       lambda: cost + tf.reduce_sum(embedding))
        return it + 1, cost
      _, cost = control_flow_ops.while_loop(
          Cond, Body, [tf.constant(0), tf.constant(0.0)])

      dynamic_grads = tf.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = tf.segment_sum(dynamic_grads.values,

      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = tf.square(
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding) +
          tf.reduce_sum(embedding)) + tf.reduce_sum(embedding)
      static_grads = tf.gradients(static, [embedding_matrix])[0]
      static_grads = tf.segment_sum(static_grads.values, static_grads.indices)

      with self.test_session() as sess:
        self.assertAllEqual(*[static_grads, dynamic_grads]))
  def testDifferentShapesGraph(self):
    """Runs one CudnnGRU layer on several input shapes in graph mode."""
    # Tests that a single kernel instance presented with multiple input shapes
    # does not crash with graph execution.
    with ops.device("gpu:0"):
      layer = cudnn_rnn.CudnnGRU(1, 100)
      layer(array_ops.zeros([28, 100, 100]))

      def _Cond(index, accumulation):
        del accumulation  # unused
        return math_ops.less(index, 4)

      def _Body(index, accumulation):
        # Alternate the slice start between 10 and 20 so that successive
        # iterations feed the layer inputs of different widths.
        layer_input = accumulation[:, :, 10 * (1 + index % 2):]
        output, _ = layer(layer_input)
        return index + 1, accumulation + output

      original_input = array_ops.zeros([28, 100, 100])
      _, accumulation = control_flow_ops.while_loop(_Cond, _Body,
                                                    [0, original_input])
      grad, = gradients.gradients(
          math_ops.reduce_sum(accumulation), (original_input,))
    init_op = variables.global_variables_initializer()
    with self.test_session() as sess:
      # Run the initializer built above before fetching anything that depends
      # on the layer's variables.
      sess.run(init_op)
      accumulation_eval, grad_eval = sess.run((accumulation, grad))
      self.assertAllEqual([28, 100, 100], accumulation_eval.shape)
      self.assertAllEqual([28, 100, 100], grad_eval.shape)
def _simple_loop(x, functor):
  """Applies `functor` to `x` repeatedly inside a four-iteration while loop.

  Args:
    x: Initial value of the second loop variable.
    functor: Callable producing the next value of the second loop variable.

  Returns:
    The final `(counter, value)` loop variables as returned by
    `control_flow_ops.while_loop`.
  """
  initial_state = (constant_op.constant(0), x)

  def not_finished(i, unused_j):
    return i < 4

  def step(i, j):
    return (i + 1, functor(j))

  return control_flow_ops.while_loop(not_finished, step, initial_state)
def run_while(cond_fn, body_fn, init_args):
  """Type-dependent functional while loop.

  Args:
    cond_fn: A Python callable implementing the stop conditions of the loop.
    body_fn: A Python callable implementing the body of the loop.
    init_args: The initial values of the arguments that will be passed to both
      cond_fn and body_fn.

  Returns:
    result: A list of values with the same shape and type as init_args. If any
    of the init_args, or any variables closed-over in cond_fn are Tensors,
    tf.while_loop will be used, otherwise a Python while loop will be run.

  Raises:
    ValueError: if init_args is not a tuple or list with one or more elements.
  """
  if not isinstance(init_args, (tuple, list)) or not init_args:
    # Wrap in a 1-tuple: formatting an empty tuple directly with `%` would
    # raise TypeError instead of the documented ValueError.
    raise ValueError(
        'init_args must be a non-empty list or tuple, found %s' % (init_args,))

  # TODO(alexbw): statically determine all active variables in cond_fn,
  # and pass them directly
  closure_vars = tuple(
      c.cell_contents for c in six.get_function_closure(cond_fn) or [])
  possibly_tensors = tuple(init_args) + closure_vars
  if is_tensor(*possibly_tensors):
    return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
  # No tensors involved: fall back to the plain Python loop.
  return py_while_loop(cond_fn, body_fn, init_args)
  def test_while_jacobian(self):
    """Compares a pfor-computed Jacobian of a while loop to a manual result.

    The while loop multiplies `x` by `y` four times. `pfor` then stacks, for
    each of the three output columns, the flattened gradient of that column
    with respect to `x`. The expected value is built directly as the transpose
    of `y @ y @ y @ y`.
    """
    x = random_ops.random_uniform([1, 3])
    y = random_ops.random_uniform([3, 3])

    # out = x @ y @ y @ y @ y, where @ is matmul operator.
    _, out = control_flow_ops.while_loop(
        lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
        [0, x])

    def loop_fn(i):
      out_i = array_ops.gather(out, i, axis=1)
      return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])

    out = pfor_control_flow_ops.pfor(loop_fn, iters=3)

    # The above code does not work with tf.while_loop instead of pfor. So we
    # manually compute the expected output here.
    # Note that the gradient of the output w.r.t. x is (y @ y @ y @ y)^T.
    expected_output = y
    for _ in range(3):
      expected_output = math_ops.matmul(expected_output, y)
    expected_output = array_ops.transpose(expected_output, [1, 0])

    with session.Session() as sess:
      # Fetch both in a single run so they share the same random `x` and `y`.
      out, expected = sess.run([out, expected_output])
      self.assertAllClose(expected, out)
  def loop_fn(i):
    sequence_length_i = array_ops.gather(sequence_length, i)

    def body_fn(t, state, ta):
      inputs_t = array_ops.expand_dims(
          array_ops.gather(, i), 0)
      output, new_state = cell(inputs_t, state)
      output = array_ops.reshape(output, [-1])
      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
      # array_ops.where when t < min(sequence_length). Doing that requires
      # supporting tf.cond pfor conversion.
      done = t >= sequence_length_i
      output = array_ops.where(done, zeros, output)
      ta = ta.write(t, output)
      new_state = [array_ops.where(done, s, ns) for s, ns in
                   zip(nest.flatten(state), nest.flatten(new_state))]
      new_state = nest.pack_sequence_as(state, new_state)
      return t + 1, new_state, ta

    def condition_fn(t, _, unused):
      del unused
      return t < max_steps

    initial_state = cell.zero_state(1, dtypes.float32)
    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
        0, initial_state,
        tensor_array_ops.TensorArray(dtypes.float32, max_steps)

    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
    new_state = nest.pack_sequence_as(initial_state, new_state)
    return ta.stack(), new_state
  def testWhileLoopWithSingleVariable(self):
    """A single-variable loop counting up from 0 must stop at 10."""
    start = constant_op.constant(0)

    def below_ten(value):
      return math_ops.less(value, 10)

    def increment(value):
      return math_ops.add(value, 1)

    final = control_flow_ops.while_loop(below_ten, increment, [start])

    self.assertEqual(self.evaluate(final), 10)
    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        self.bounder_state_h0 = tf.zeros([batch_size, self.units])

        input_x = tf.transpose(inputs, [2, 3, 0, 1])
        input_x = tf.reshape(input_x, [-1,])
        input_x = tf.split(axis=0, num_or_size_splits=self.text1_maxlen * self.text2_maxlen, value=input_x)
        inputs_ta = tf.TensorArray(dtype=tf.float32, size=self.text1_maxlen * self.text2_maxlen, name='input_ta')
        states_ta = tf.TensorArray(dtype=tf.float32, size=(self.text1_maxlen + 1) * (self.text2_maxlen + 1),
                                   name='state_ta', clear_after_read=False)

        for i in range(self.text2_maxlen + 1):
            states_ta = states_ta.write(i, self.bounder_state_h0)
        for i in range(self.text1_maxlen):
            states_ta = states_ta.write((i + 1) * (self.text2_maxlen + 1), self.bounder_state_h0)
        inputs_ta = inputs_ta.unstack(input_x)
        _, _, _, hij, _ = control_flow_ops.while_loop(
            cond=lambda _0, _1, i, _3, _4: i < self.recurrent_step,
                inputs_ta, states_ta, tf.Variable(0, dtype=tf.int32), self.bounder_state_h0, self.bounder_state_h0),
        return hij
  def _v1_nested_while_saved_model(self):
    export_graph = ops.Graph()
    with export_graph.as_default():

      def _inner_while(loop_iterations):
        _, output = control_flow_ops.while_loop(
            lambda index, accum: index <= loop_iterations,
            lambda index, accum: (index + 1, accum + index),
            [constant_op.constant(0), constant_op.constant(0)])
        return output

      loop_iterations = array_ops.placeholder(
          name="loop_iterations", shape=[], dtype=dtypes.int32)
      _, output = control_flow_ops.while_loop(
          lambda index, accum: index <= loop_iterations,
          lambda index, accum: (index + 1, accum + _inner_while(index)),
          [constant_op.constant(0), constant_op.constant(0)])
      with session_lib.Session() as session:
        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
            inputs={"loop_iterations": loop_iterations},
            outputs={"output": output})
    return path
    def body(it, cost):
      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      cost = control_flow_ops.cond(
          math_ops.equal(it, 3), lambda: math_ops.square(cost),
          (lambda: cost + math_ops.reduce_sum(embedding)))
      return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          cond, body, [constant_op.constant(0),

      dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = math_ops.segment_sum(dynamic_grads.values,

      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = math_ops.square(
          math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
          math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
      static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
      static_grads = math_ops.segment_sum(static_grads.values,

      with self.cached_session():
        self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads]))
  def testScanInsideWhile(self):

    def loop_cond(idx_step, *unused_args):
      return idx_step < 1

    def loop_body(idx_step, y):
      x = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
      x = functional_ops.scan(
          initializer=array_ops.zeros([20, 30], dtype=dtypes.float32),

      with ops.device('/cpu:0'):
        y = array_ops.identity(x)

        return idx_step + 1, y

    if test.is_gpu_available(cuda_only=True):
      init_y = array_ops.zeros([10, 20, 30], dtype=dtypes.float32)
      _, y = control_flow_ops.while_loop(
          loop_vars=[0, init_y],
      with session.Session() as sess:
        y_v = self.evaluate(y)
        self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
def generator(x_real, temperature, vocab_size, batch_size, seq_len,
              gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim,
    start_tokens = tf.constant([start_token] * batch_size, dtype=tf.int32)
    output_size = mem_slots * head_size * num_heads

    # build relation memory module
    g_embeddings = tf.get_variable(
        shape=[vocab_size, gen_emb_dim],
    gen_mem = RelationalMemory(mem_slots=mem_slots,
    g_output_unit = create_output_unit(output_size, vocab_size)

    # initial states
    init_states = gen_mem.initial_state(batch_size)

    # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
    gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
    gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
    gen_x_onehot_adv = tensor_array_ops.TensorArray(
        dtype=tf.float32, size=seq_len, dynamic_size=False,
        infer_shape=True)  # generator output (relaxed of gen_x)

    # the generator recurrent module used for adversarial training
    def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)  # hidden_memory_tuple
        o_t = g_output_unit(mem_o_t)  # batch x vocab, logits not probs
        gumbel_t = add_gumbel(o_t)
        next_token = tf.stop_gradient(
            tf.argmax(gumbel_t, axis=1, output_type=tf.int32))
        next_token_onehot = tf.one_hot(next_token, vocab_size, 1.0, 0.0)

        x_onehot_appr = tf.nn.softmax(tf.multiply(
            gumbel_t, temperature))  # one-hot-like, [batch_size x vocab_size]

        # x_tp1 = tf.matmul(x_onehot_appr, g_embeddings)  # approximated embeddings, [batch_size x emb_dim]
        x_tp1 = tf.nn.embedding_lookup(
            g_embeddings, next_token)  # embeddings, [batch_size x emb_dim]

        gen_o = gen_o.write(i,
                                tf.multiply(next_token_onehot, x_onehot_appr),
                                1))  # [batch_size], prob
        gen_x = gen_x.write(i, next_token)  # indices, [batch_size]

        gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

        return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv

    # build a graph for outputting sequential tokens
    _, _, _, gen_o, gen_x, gen_x_onehot_adv = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3, _4, _5: i < seq_len,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, gen_o, gen_x, gen_x_onehot_adv))

    gen_o = tf.transpose(gen_o.stack(), perm=[1, 0])  # batch_size x seq_len
    gen_x = tf.transpose(gen_x.stack(), perm=[1, 0])  # batch_size x seq_len

    gen_x_onehot_adv = tf.transpose(
        perm=[1, 0, 2])  # batch_size x seq_len x vocab_size

    # ----------- pre-training for generator -----------------
    x_emb = tf.transpose(tf.nn.embedding_lookup(g_embeddings, x_real),
                         perm=[1, 0, 2])  # seq_len x batch_size x emb_dim
    g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,

    ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32, size=seq_len)
    ta_emb_x = ta_emb_x.unstack(x_emb)

    # the generator recurrent module used for pre-training
    def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
        mem_o_t, h_t = gen_mem(x_t, h_tm1)
        o_t = g_output_unit(mem_o_t)
        g_predictions = g_predictions.write(
            i, tf.nn.softmax(o_t))  # batch_size x vocab_size
        x_tp1 =
        return i + 1, x_tp1, h_t, g_predictions

    # build a graph for outputting sequential tokens
    _, _, _, g_predictions = control_flow_ops.while_loop(
        cond=lambda i, _1, _2, _3: i < seq_len,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   tf.nn.embedding_lookup(g_embeddings, start_tokens),
                   init_states, g_predictions))

    g_predictions = tf.transpose(
        perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

    # pre-training loss
    pretrain_loss = -tf.reduce_sum(
        tf.one_hot(tf.to_int32(tf.reshape(x_real, [-1])), vocab_size, 1.0, 0.0)
        * tf.log(
            tf.clip_by_value(tf.reshape(g_predictions, [-1, vocab_size]),
                             1e-20, 1.0))) / (seq_len * batch_size)

    return gen_x_onehot_adv, gen_x, pretrain_loss, gen_o
def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Runs `loop_fn` `iters` times and stacks the outputs.

  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks corresponding outputs of the different runs.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of tensor
      objects. The shape of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes for the outputs of `loop_fn`.
    iters: Number of iterations for which to run `loop_fn`.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked output tensor objects with the same
    nested structure as the output of `loop_fn`. Returns `None` when no
    outputs were collected (the loop body was never traced).

  Raises:
    ValueError: If the number of outputs of `loop_fn` does not match the
      number of entries in `loop_fn_dtypes`.
  """

  flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
  # Filled in by `while_body` while the loop is traced; records which outputs
  # of `loop_fn` were None so they can be skipped when stacking.
  is_none_list = []

  def while_body(i, *ta_list):
    """Body of while loop."""
    fn_output = nest.flatten(loop_fn(i))
    if len(fn_output) != len(flat_loop_fn_dtypes):
      raise ValueError(
          "Number of expected outputs, %d, does not match the number of "
          "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
                                                len(fn_output)))
    outputs = []
    del is_none_list[:]
    is_none_list.extend(x is None for x in fn_output)
    for out, ta in zip(fn_output, ta_list):
      # TODO(agarwal): support returning Operation objects from loop_fn.
      if out is not None:
        # Add a leading axis so that `concat()` below stacks the iterations.
        ta = ta.write(i, array_ops.expand_dims(out, 0))
      # Every TensorArray is threaded through the loop, written to or not.
      outputs.append(ta)
    return tuple([i + 1] + outputs)

  if parallel_iterations is not None:
    extra_args = {"parallel_iterations": parallel_iterations}
  else:
    extra_args = {}
  # Drop the loop counter (index 0); keep only the TensorArrays.
  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters,
      while_body,
      [0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
             for dtype in flat_loop_fn_dtypes],
      **extra_args)[1:]

  # TODO(rachelim): enable this for sparse tensors

  output = [None if is_none else ta.concat()
            for ta, is_none in zip(ta_list, is_none_list)]
  assert len(output) in (0, len(flat_loop_fn_dtypes))
  if not output:
    # This may happen for the case where iters == 0.
    return None
  else:
    return nest.pack_sequence_as(loop_fn_dtypes, output)
def foldl(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
    """foldl on the list of tensors unpacked from `elems` on dimension 0.

    This foldl operator repeatedly applies the callable `fn` to a sequence
    of elements from first to last. The elements are made of the tensors
    unpacked from `elems` on dimension 0. The callable fn takes two tensors as
    arguments. The first argument is the accumulated value computed from the
    preceding invocation of fn, and the second is the value at the current
    position of `elems`. If `initializer` is None, `elems` must contain at
    least one element, and its first element is used as the initializer.

    Suppose that `elems` is unpacked into `values`, a list of tensors. The
    shape of the result tensor is `fn(initializer, values[0]).shape`.

    This method also allows multi-arity `elems` and output of `fn`.  If `elems`
    is a (possibly nested) list or tuple of tensors, then each of these tensors
    must have a matching first (unpack) dimension.  The signature of `fn` may
    match the structure of `elems`.  That is, if `elems` is
    `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
    `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

    Args:
      fn: The callable to be performed.
      elems: A tensor or (possibly nested) sequence of tensors, each of which
        will be unpacked along their first dimension.  The nested sequence of
        the resulting slices will be the first argument to `fn`.
      initializer: (optional) A tensor or (possibly nested) sequence of
        tensors, as the initial value for the accumulator.
      parallel_iterations: (optional) The number of iterations allowed to run
        in parallel.
      back_prop: (optional) True enables support for back propagation.
      swap_memory: (optional) True enables GPU-CPU memory swapping.
      name: (optional) Name prefix for the returned tensors.

    Returns:
      A tensor or (possibly nested) sequence of tensors, resulting from
      applying `fn` consecutively to the list of tensors unpacked from `elems`,
      from first to last.

    Raises:
      TypeError: if `fn` is not callable.

    Example:
      ```python
      elems = tf.constant([1, 2, 3, 4, 5, 6])
      sum = foldl(lambda a, x: a + x, elems)
      # sum == 21
      ```
    """
    if not callable(fn):
        raise TypeError("fn must be callable.")

    def create_ta(elem):
        # `n` is bound below, before this helper is first invoked.
        return tensor_array_ops.TensorArray(
            dtype=elem.dtype, size=n, dynamic_size=False,
            infer_shape=True).unstack(elem)

    in_graph_mode = not context.executing_eagerly()
    with ops.name_scope(name, "foldl", [elems]):
        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode:
            # Any get_variable calls in fn will cache the first call locally
            # and not issue repeated network I/O requests for each iteration.
            varscope = vs.get_variable_scope()
            varscope_caching_device_was_none = False
            if varscope.caching_device is None:
                # TODO(ebrevdo): Change to using colocate_with here and in other
                # methods.
                varscope.set_caching_device(lambda op: op.device)
                varscope_caching_device_was_none = True

        # Convert elems to tensor array. n may be known statically.
        elems_flat = [
            ops.convert_to_tensor(elem, name="elem")
            for elem in nest.flatten(elems)
        ]
        n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
             or array_ops.shape(elems_flat[0])[0])

        elems_ta = nest.map_structure(create_ta, elems)

        if initializer is None:
            # Seed the accumulator with the first slice and start at index 1.
            a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
            i = constant_op.constant(1)
        else:
            a = initializer
            i = constant_op.constant(0)

        def compute(i, a):
            elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
            a = fn(a, elem_i)
            return [i + 1, a]

        _, r_a = control_flow_ops.while_loop(
            lambda i, a: i < n,
            compute, [i, a],
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            swap_memory=swap_memory,
            maximum_iterations=n)

        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode and varscope_caching_device_was_none:
            # Undo the temporary caching device installed above.
            varscope.set_caching_device(None)

        return r_a
def scan(fn,
         elems,
         initializer=None,
         parallel_iterations=10,
         back_prop=True,
         swap_memory=False,
         infer_shape=True,
         reverse=False,
         name=None):
    """scan on the list of tensors unpacked from `elems` on dimension 0.

    See also `tf.map_fn`.

    The simplest version of `scan` repeatedly applies the callable `fn` to a
    sequence of elements from first to last. The elements are made of the
    tensors unpacked from `elems` on dimension 0. The callable fn takes two
    tensors as arguments. The first argument is the accumulated value computed
    from the preceding invocation of fn, and the second is the value at the
    current position of `elems`. If `initializer` is None, `elems` must contain
    at least one element, and its first element is used as the initializer.

    Suppose that `elems` is unpacked into `values`, a list of tensors. The
    shape of the result tensor is
    `[len(values)] + fn(initializer, values[0]).shape`.
    If reverse=True, it's fn(initializer, values[-1]).shape.

    This method also allows multi-arity `elems` and accumulator.  If `elems`
    is a (possibly nested) list or tuple of tensors, then each of these tensors
    must have a matching first (unpack) dimension.  The second argument of
    `fn` must match the structure of `elems`.

    If no `initializer` is provided, the output structure and dtypes of `fn`
    are assumed to be the same as its input; and in this case, the first
    argument of `fn` must match the structure of `elems`.

    If an `initializer` is provided, then the output of `fn` must have the same
    structure as `initializer`; and the first argument of `fn` must match
    this structure.

    For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
    `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
    `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a
    list, `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and
    the one that works in `python3`, is:
    `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

    Args:
      fn: The callable to be performed.  It accepts two arguments.  The first
        will have the same structure as `initializer` if one is provided,
        otherwise it will have the same structure as `elems`.  The second will
        have the same (possibly nested) structure as `elems`.  Its output must
        have the same structure as `initializer` if one is provided, otherwise
        it must have the same structure as `elems`.
      elems: A tensor or (possibly nested) sequence of tensors, each of which
        will be unpacked along their first dimension.  The nested sequence of
        the resulting slices will be the first argument to `fn`.
      initializer: (optional) A tensor or (possibly nested) sequence of
        tensors, initial value for the accumulator, and the expected output
        type of `fn`.
      parallel_iterations: (optional) The number of iterations allowed to run
        in parallel.
      back_prop: (optional) True enables support for back propagation.
      swap_memory: (optional) True enables GPU-CPU memory swapping.
      infer_shape: (optional) False disables tests for consistent output
        shapes.
      reverse: (optional) True scans the tensor last to first (instead of
        first to last).
      name: (optional) Name prefix for the returned tensors.

    Returns:
      A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
      results of applying `fn` to tensors unpacked from `elems` along the first
      dimension, and the previous accumulator value(s), from first to last (or
      last to first, if `reverse=True`).

    Raises:
      TypeError: if `fn` is not callable or the structure of the output of
        `fn` and `initializer` do not match.
      ValueError: if the lengths of the output of `fn` and `initializer`
        do not match.

    Examples:
      ```python
      elems = np.array([1, 2, 3, 4, 5, 6])
      sum = scan(lambda a, x: a + x, elems)
      # sum == [1, 3, 6, 10, 15, 21]
      sum = scan(lambda a, x: a + x, elems, reverse=True)
      # sum == [21, 20, 18, 15, 11, 6]
      ```

      ```python
      elems = np.array([1, 2, 3, 4, 5, 6])
      initializer = np.array(0)
      sum_one = scan(
          lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
      # sum_one == [1, 2, 3, 4, 5, 6]
      ```

      ```python
      elems = np.array([1, 0, 0, 0, 0, 0])
      initializer = (np.array(0), np.array(1))
      fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
      # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
      ```
    """
    if not callable(fn):
        raise TypeError("fn must be callable.")

    input_is_sequence = nest.is_sequence(elems)
    input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

    def input_pack(x):
        return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

    if initializer is None:
        # Without an initializer the accumulator mirrors the input structure.
        output_is_sequence = input_is_sequence
        output_flatten = input_flatten
        output_pack = input_pack
    else:
        output_is_sequence = nest.is_sequence(initializer)
        output_flatten = lambda x: nest.flatten(
            x) if output_is_sequence else [x]

        def output_pack(x):
            return (nest.pack_sequence_as(initializer, x)
                    if output_is_sequence else x[0])

    elems_flat = input_flatten(elems)

    in_graph_mode = not context.executing_eagerly()
    with ops.name_scope(name, "scan", elems_flat):
        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode:
            # Any get_variable calls in fn will cache the first call locally
            # and not issue repeated network I/O requests for each iteration.
            varscope = vs.get_variable_scope()
            varscope_caching_device_was_none = False
            if varscope.caching_device is None:
                # TODO(ebrevdo): Change to using colocate_with here and in other
                # methods.
                varscope.set_caching_device(lambda op: op.device)
                varscope_caching_device_was_none = True

        # Convert elems to tensor array.
        elems_flat = [
            ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
        ]

        # Convert elems to tensor array. n may be known statically.
        n = tensor_shape.dimension_value(elems_flat[0].shape[0])
        if n is None:
            n = array_ops.shape(elems_flat[0])[0]

        # TensorArrays are always flat
        elems_ta = [
            tensor_array_ops.TensorArray(
                dtype=elem.dtype,
                size=n,
                dynamic_size=False,
                element_shape=elem.shape[1:],
                infer_shape=True) for elem in elems_flat
        ]
        # Unpack elements
        elems_ta = [
            elem_ta.unstack(elem)
            for elem_ta, elem in zip(elems_ta, elems_flat)
        ]

        if initializer is None:
            # The first scanned element doubles as the initial accumulator, so
            # the loop proper starts one step in.
            a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
            i = 1
        else:
            initializer_flat = output_flatten(initializer)
            a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
            i = 0

        # Create a tensor array to store the intermediate values.
        accs_ta = [
            tensor_array_ops.TensorArray(
                dtype=init.dtype,
                size=n,
                element_shape=init.shape if infer_shape else None,
                dynamic_size=False,
                infer_shape=infer_shape) for init in a_flat
        ]

        if initializer is None:
            accs_ta = [
                acc_ta.write(n - 1 if reverse else 0, a)
                for (acc_ta, a) in zip(accs_ta, a_flat)
            ]

        def compute(i, a_flat, tas):
            """The loop body of scan.

            Args:
              i: the loop counter.
              a_flat: the accumulator value(s), flattened.
              tas: the output accumulator TensorArray(s), flattened.

            Returns:
              [i + 1, a_flat, tas]: the updated counter + new accumulator
                values + updated TensorArrays

            Raises:
              TypeError: if initializer and fn() output structure do not match
              ValueError: if initializer and fn() output lengths do not match
            """
            packed_elems = input_pack(
                [elem_ta.read(i) for elem_ta in elems_ta])
            packed_a = output_pack(a_flat)
            a_out = fn(packed_a, packed_elems)
            nest.assert_same_structure(
                elems if initializer is None else initializer, a_out)
            flat_a_out = output_flatten(a_out)
            tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
            if reverse:
                next_i = i - 1
            else:
                next_i = i + 1
            return (next_i, flat_a_out, tas)

        if reverse:
            initial_i = n - 1 - i
            condition = lambda i, _1, _2: i >= 0
        else:
            initial_i = i
            condition = lambda i, _1, _2: i < n
        _, _, r_a = control_flow_ops.while_loop(
            condition,
            compute, (initial_i, a_flat, accs_ta),
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            swap_memory=swap_memory,
            maximum_iterations=n)

        results_flat = [r.stack() for r in r_a]

        # Propagate the statically known leading dimension to the results.
        n_static = tensor_shape.Dimension(
            tensor_shape.dimension_value(
                elems_flat[0].get_shape().with_rank_at_least(1)[0]))
        for elem in elems_flat[1:]:
            n_static.merge_with(
                tensor_shape.Dimension(
                    tensor_shape.dimension_value(
                        elem.get_shape().with_rank_at_least(1)[0])))
        for r in results_flat:
            r.set_shape(
                tensor_shape.TensorShape(n_static).concatenate(
                    r.get_shape()[1:]))

        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode and varscope_caching_device_was_none:
            # Undo the temporary caching device installed above.
            varscope.set_caching_device(None)

        return output_pack(results_flat)
def _dopri5(func,
    """Solve an ODE for `odeint` using method='dopri5'."""
    # NOTE(review): the parameter list above is cut off after `func`. The body
    # reads ys0, ts, us, rtol, atol, full_output, first_step, safety, ifactor,
    # dfactor, max_num_steps and name -- restore the full signature (order and
    # defaults) from version control before relying on this function.

    # `us` is optional; when given, its first entry is used for the initial
    # state and the initial function evaluation below.
    us0 = us if us is None else us[0]
    if first_step is None:
        first_step = _select_initial_step(ts)

    with ops.name_scope(
            name, 'dopri5',
        [ys0, ts, us, rtol, atol, safety, ifactor, dfactor, max_num_steps
         ]) as scope:
        # NOTE(review): four of the next five convert_to_tensor calls are
        # truncated (dtype/name arguments and closing parenthesis missing).
        first_step = ops.convert_to_tensor(first_step,
        safety = ops.convert_to_tensor(safety, dtype=ts.dtype, name='safety')
        ifactor = ops.convert_to_tensor(ifactor,
        dfactor = ops.convert_to_tensor(dfactor,
        max_num_steps = ops.convert_to_tensor(max_num_steps,

        with _assert_monotonicity(ts):
            num_times = array_ops.size(ts)
            # Keep the step sign when `ts` is non-decreasing; negate it
            # otherwise so that integration proceeds towards later entries.
            first_step = control_flow_ops.cond(
                math_ops.reduce_all(ts[1:] >= ts[:-1]), lambda: first_step,
                lambda: -first_step)

        def adaptive_runge_kutta_step(rk_state, history, n_steps):
            """Take an adaptive Runge-Kutta step to integrate the ODE."""
            ys0, fs0, _, t0, us0, dt, interp_coeff = rk_state
            with ops.name_scope('assertions'):
                # NOTE(review): `and` / `or` here operate on the results of
                # tensor comparisons; if these are symbolic graph tensors,
                # Python boolean evaluation is not allowed -- confirm whether
                # math_ops.logical_and / logical_or was intended.
                check_underflow = control_flow_ops.Assert(
                    (t0 + dt > t0 and first_step > 0)
                    or (t0 + dt < t0 and first_step < 0),
                    ['underflow in dt', dt])
                check_max_num_steps = control_flow_ops.Assert(
                    n_steps < max_num_steps, ['max_num_steps exceeded'])
                # NOTE(review): the Assert below has lost its condition
                # argument (only the message list remains).
                check_numerics = _traverse_and_return_flattened(
                    ys0, lambda y, _: control_flow_ops.Assert(
                        ['non-finite values in state `y`', y]))
            # The Runge-Kutta step only runs once all assertions have passed.
            with ops.control_dependencies(
                [check_underflow, check_max_num_steps] + check_numerics):
                ys1, fs1, ys1_error, ks = _runge_kutta_step(
                    func, ys0, fs0, t0, us0, dt)

            with ops.name_scope('error_ratio'):
                # We use the same approach as the dopri5 fortran code.
                error_tol = _multi_traverse_and_return_nested(
                    [ys0, ys1],
                    lambda y0, y1, _: atol + rtol * math_ops.maximum(
                        abs(y0), abs(y1)))
                tensor_error_ratio = _multi_traverse_and_return_nested(
                    [ys1_error, error_tol],
                    lambda err, tol, _: _abs_square(err) / _abs_square(tol))
                # Could also use reduce_maximum here.
                # NOTE(review): the expression below is truncated -- the
                # reduction applied across the per-tensor means is missing.
                error_ratio = math_ops.sqrt(
                            lambda err, _: math_ops.reduce_mean(err))))
                accept_step = error_ratio <= 1

            with ops.name_scope('update/rk_state'):
                # If we don't accept the step, the _RungeKuttaState will be useless
                # (covering a time-interval of size 0), but that's OK, because in such
                # cases we always immediately take another Runge-Kutta step.
                ys_next = control_flow_ops.cond(accept_step, lambda: ys1,
                                                lambda: ys0)
                fs_next = control_flow_ops.cond(accept_step, lambda: fs1,
                                                lambda: fs0)
                ts_next = control_flow_ops.cond(accept_step, lambda: t0 + dt,
                                                lambda: t0)
                # The control input is carried through a step unchanged.
                us_next = us0
                interp_coeff = control_flow_ops.cond(
                    accept_step, lambda: _interp_fit_rk(ys0, ys1, ks, dt),
                    lambda: interp_coeff)
                # NOTE(review): call truncated (trailing argument(s) and
                # closing parenthesis missing).
                dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor,
                rk_state = _RungeKuttaState(ys_next, fs_next, t0, ts_next,
                                            us_next, dt_next, interp_coeff)

            with ops.name_scope('update/history'):
                # The attempted end point and error ratio are recorded whether
                # or not the step was accepted.
                history = _History(
                    _ta_append(history.integrate_points, t0 + dt),
                    _ta_append(history.error_ratio, error_ratio))
            return rk_state, history, n_steps + 1

        def interpolate(solution, history, rk_state, i):
            """Interpolate through the next time point, integrating as necessary."""
            with ops.name_scope('interpolate'):
                # Pick the control for output index `i`; a single-entry `us`
                # is reused for every index.
                us1 = None if us is None else (
                    us[0] if len(us) == 1 else us[i])
                ys1, fs1, t0, t1, _, dt, interp_coeff = rk_state
                # NOTE(review): constructor call truncated (last argument and
                # closing parenthesis missing).
                rk_state = _RungeKuttaState(ys1, fs1, t0, t1, us1, dt,

                # Step until the state's interval reaches ts[i], in the
                # direction given by the sign of `first_step`.
                # NOTE(review): same Python `and` / `or` on tensors concern as
                # in the assertions above; the call is also truncated after
                # the loop variables.
                rk_state, history, _ = control_flow_ops.while_loop(
                    lambda rk_s, *_: (ts[i] > rk_s.t1 and first_step > 0) or
                    (ts[i] < rk_s.t1 and first_step < 0),
                    adaptive_runge_kutta_step, (rk_state, history, 0),

                ys = _interp_evaluate(rk_state.interp_coeff, rk_state.t0,
                                      rk_state.t1, ts[i])
                solution = _multi_traverse_and_return_nested(
                    [solution, ys], lambda sol, y, _: sol.write(i, y))

                return solution, history, rk_state, i + 1

        # One TensorArray per state tensor, pre-filled with the initial value
        # at index 0.
        solution = _traverse_and_return_nested(
            ys0, lambda y, _: tensor_array_ops.TensorArray(
                y.dtype, size=num_times).write(0, y))
        # NOTE(review): the _History(...) arguments are missing.
        history = _History(
        rk_state = _RungeKuttaState(ys0, func(ys0, ts[0], us0), ts[0], ts[0],
                                    us0, first_step, [ys0] * 5)

        # Index 0 already holds the initial state, so start at 1.
        # NOTE(review): call truncated after the loop variables.
        solution, history, _, _ = control_flow_ops.while_loop(
            lambda _, __, ___, i: i < num_times,
            interpolate, (solution, history, rk_state, 1),

        ys = _traverse_and_return_nested(solution,
                                         lambda s, _: s.stack(name=scope))
        # NOTE(review): the set_shape(...) argument is missing.
        _multi_traverse_and_do([ys, ys0], lambda y, y0, _: y.set_shape(
        if not full_output:
            return ys
            # NOTE(review): an `else:` appears to be missing above this line,
            # and the dict literal below is never closed.
            integrate_points = history.integrate_points.stack()
            info_dict = {
                'num_func_evals': 6 * array_ops.size(integrate_points) + 1,
                'integrate_points': integrate_points,
                'error_ratio': history.error_ratio.stack()
            return (ys, info_dict)
    def __init__(self, num_emb, batch_size, emb_dim, hidden_dim, sequence_length, start_token, params):
        """Build the sampling and teacher-forced graphs of the LSTM.

        Args:
            num_emb: Vocabulary size; used as the depth of the one-hot targets.
            batch_size: Number of sequences per batch.
            emb_dim: Embedding dimension (stored on the instance).
            hidden_dim: Width of the recurrent hidden state.
            sequence_length: Number of tokens generated / scored per sequence.
            start_token: Token id fed at step 0 for every batch element.
            params: Parameter collection; `params[0]` initialises the
                embedding matrix. The rest is presumably consumed by the
                unit-construction helpers -- confirm against those methods.
        """
        self.num_emb = num_emb
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = tf.constant([start_token] * self.batch_size, dtype=tf.int32)
        self.g_params = []
        self.temperature = 1.0
        self.params = params

        with tf.variable_scope('generator'):
            self.g_embeddings = tf.Variable(self.params[0])
            self.g_recurrent_unit = self.create_recurrent_unit(self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(self.g_params)  # maps h_t to o_t (output token logits)

        # placeholder definition
        self.x = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length])  # sequence of tokens generated by generator

        # processed for batch
        with tf.device("/cpu:0"):
            self.processed_x = tf.transpose(
                tf.nn.embedding_lookup(self.g_embeddings, self.x),
                perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        # initial states
        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        # generator on initial randomness
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)

        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)  # batch x emb_dim
            gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0),
                                                             tf.nn.softmax(o_t)), 1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, gen_o, gen_x

        _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token),
                       self.h0, gen_o, gen_x))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x, perm=[1, 0])  # batch_size x seq_length

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length,
            dynamic_size=False, infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(i, tf.nn.softmax(o_t))  # batch x vocab_size
            # Teacher forcing: the next input is the embedded ground-truth
            # token at step i, not a sampled one.
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token),
                       self.h0, g_predictions))

        self.g_predictions = tf.transpose(
            self.g_predictions.stack(), perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        # pretraining loss
        self.pretrain_loss = -tf.reduce_sum(
            tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log(
                tf.reshape(self.g_predictions, [-1, self.num_emb]))) / (self.sequence_length * self.batch_size)

        # Per-sequence negative log-likelihood: sum over the vocabulary for
        # each token, reshape to (batch, seq_len), then sum over time.
        self.out_loss = tf.reduce_sum(
            tf.reshape(
                -tf.reduce_sum(
                    tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_emb, 1.0, 0.0) * tf.log(
                        tf.reshape(self.g_predictions, [-1, self.num_emb])), 1
                ), [-1, self.sequence_length]
            ), 1
        )  # batch_size
    def __init__(self, num_vocabulary, batch_size, emb_dim, hidden_dim,
                 sequence_length, start_token,
                 learning_rate=0.01, reward_gamma=0.95):
        """Build the generator graph: sampling, pretraining and reward training.

        Args:
            num_vocabulary: Vocabulary size; first dimension of the embedding
                matrix and depth of the one-hot targets.
            batch_size: Number of sequences per batch.
            emb_dim: Embedding dimension.
            hidden_dim: Width of the recurrent hidden state.
            sequence_length: Number of tokens generated / scored per sequence.
            start_token: Token id fed at step 0 for every batch element.
            learning_rate: Initial value of the non-trainable learning-rate
                variable handed to `self.g_optimizer`.
            reward_gamma: Stored on the instance; not used in this method.
        """
        self.num_vocabulary = num_vocabulary
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = tf.constant([start_token] * self.batch_size, dtype=tf.int32)
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.reward_gamma = reward_gamma
        self.g_params = []
        self.d_params = []
        self.temperature = 1.0
        self.grad_clip = 5.0
        self.alpha = 1.0

        self.expected_reward = tf.Variable(tf.zeros([self.sequence_length]))

        with tf.variable_scope('generator'):
            self.g_embeddings = tf.Variable(self.init_matrix([self.num_vocabulary, self.emb_dim]))
            self.g_recurrent_unit = self.create_recurrent_unit(self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(self.g_params)  # maps h_t to o_t (output token logits)

        # placeholder definition
        self.x = tf.placeholder(tf.int32, shape=[self.batch_size,
                                                 self.sequence_length])  # sequence of tokens generated by generator
        self.rewards = tf.placeholder(tf.float32, shape=[self.batch_size,
                                                         self.sequence_length])  # get from rollout policy and discriminator

        # processed for batch
        with tf.device("/cpu:0"):
            self.processed_x = tf.transpose(tf.nn.embedding_lookup(self.g_embeddings, self.x),
                                            perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        # Initial states
        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)

        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)  # batch x emb_dim
            gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_vocabulary, 1.0, 0.0),
                                                             tf.nn.softmax(o_t)), 1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, gen_o, gen_x

        _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o, gen_x))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x, perm=[1, 0])  # batch_size x seq_length

        ########## temp sweep ###########

        gen_o_temp = tensor_array_ops.TensorArray(dtype=tf.float32, size=self.sequence_length,
                                                  dynamic_size=False, infer_shape=True)
        gen_x_temp = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                                  dynamic_size=False, infer_shape=True)

        def _g_recurrence_temperature(i, x_t, h_tm1, gen_o_temp, gen_x_temp, alpha):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            # Logits are divided by `alpha` before the softmax.
            o_t = self.g_output_unit(h_t)/alpha  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)  # batch x emb_dim
            gen_o_temp = gen_o_temp.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_vocabulary, 1.0, 0.0),
                                                                       tf.nn.softmax(o_t)), 1))  # [batch_size] , prob
            gen_x_temp = gen_x_temp.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, gen_o_temp, gen_x_temp, alpha

        _, _, _, self.gen_o_temp, self.gen_x_temp, _ = control_flow_ops.while_loop(
            cond=lambda j, _1, _2, _3, _4, _5: j < self.sequence_length,
            body=_g_recurrence_temperature,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, gen_o_temp, gen_x_temp, self.alpha))
        self.gen_x_temp = self.gen_x_temp.stack()  # seq_length x batch_size
        self.gen_x_temp = tf.transpose(self.gen_x_temp, perm=[1, 0])  # batch_size x seq_length

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length,
            dynamic_size=False, infer_shape=True)

        ta_emb_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(i, tf.nn.softmax(o_t))  # batch x vocab_size
            # Teacher forcing: the next input is the embedded ground-truth
            # token at step i, not a sampled one.
            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, h_t, g_predictions

        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token),
                       self.h0, g_predictions))

        self.g_predictions = tf.transpose(self.g_predictions.stack(),
                                          perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        # pretraining loss
        # Probabilities are clipped to [1e-20, 1.0] so that log() stays finite.
        self.pretrain_loss = -tf.reduce_sum(
            tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_vocabulary, 1.0, 0.0) * tf.log(
                tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_vocabulary]), 1e-20, 1.0)
            )
        ) / (self.sequence_length * self.batch_size)

        # training updates
        pretrain_opt = self.g_optimizer(self.learning_rate)

        self.pretrain_grad, _ = tf.clip_by_global_norm(tf.gradients(self.pretrain_loss, self.g_params), self.grad_clip)
        self.pretrain_updates = pretrain_opt.apply_gradients(zip(self.pretrain_grad, self.g_params))

        #  Unsupervised Training
        # Per-token log-likelihood (inner sum over the vocabulary) weighted by
        # the per-token reward, summed over every token in the batch.
        self.g_loss = -tf.reduce_sum(
            tf.reduce_sum(
                tf.one_hot(tf.to_int32(tf.reshape(self.x, [-1])), self.num_vocabulary, 1.0, 0.0) * tf.log(
                    tf.clip_by_value(tf.reshape(self.g_predictions, [-1, self.num_vocabulary]), 1e-20, 1.0)
                ), 1) * tf.reshape(self.rewards, [-1])
        )

        g_opt = self.g_optimizer(self.learning_rate)

        self.g_grad, _ = tf.clip_by_global_norm(tf.gradients(self.g_loss, self.g_params), self.grad_clip)
        self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `dtype` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`. Use
      `RaggedTensorType` to declare an output of type `RaggedTensor`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A possibly nested sequence of potentially ragged tensors.  Each
    tensor packs the results of applying `fn` to tensors unpacked from `elems`
    along the first dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  #### Examples:

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```

    ```python
    elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
    mean = map_fn(tf.reduce_mean, elems)
    # mean == [2, 4, 6]
    ```

    ```python
    elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64)
    out = map_fn(lambda x: x+1, elems,
      dtype=ragged.RaggedTensorType(dtype=tf.int64, ragged_rank=0))
    # out = ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  if isinstance(elems, sparse_tensor.SparseTensor):
    raise TypeError(
        "To perform a map on the values of a sparse tensor use either "
        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
        " SparseTensor(input.indices, map_fn(fn, input.values), "
        "input.dense_shape)")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1

  if not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
                        "effect when executing eagerly. Consider calling map_fn"
                        " with tf.contrib.eager.defun to execute fn in "
                        "parallel.", 1)
    parallel_iterations = 1

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    # Inverse of input_flatten: restore the structure of `elems`.
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  elems_flat = input_flatten(elems)
  elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ragged_tensor.convert_to_tensor_or_ragged_tensor(elem, name="elem")
        for elem in elems_flat
    ]

    # We can either infer the output, or we can assume that it will be the same
    # as the input structure.
    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])

    # Find the number of iterations, n may be known statically.
    if isinstance(elems_flat[0], ragged_tensor.RaggedTensor):
      n = elems_flat[0].nrows(out_type=dtypes.int32)
    else:
      static_shape = elems_flat[0].shape
      if static_shape.ndims is not None and static_shape.ndims < 1:
        if len(elems_flat) == 1:
          raise ValueError(
              "elems must be a 1+ dimensional Tensor, not a scalar")
        else:
          raise ValueError(
              "elements in elems must be 1+ dimensional Tensors, not scalars")
      # Prefer the static outer dimension; fall back to the dynamic shape.
      n = (tensor_shape.dimension_value(static_shape[0]) or
           array_ops.shape(elems_flat[0])[0])

    n = math_ops.cast(n, dtype=dtypes.int32)
    # Create a flat list of TAs.

    # Flatten the dtype structure to a list.
    dtype_flat = nest.flatten(dtype)

    # decompose to components
    dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat]
    dtype_components_flat = nest.flatten(dtype_components)

    # Create TensorArrays.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=t, dynamic_size=False, infer_shape=infer_shape, size=n)
        for t in dtype_components_flat
    ]

    i = constant_op.constant(0, dtype=dtypes.int32)

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      # Get Tensors or RaggedTensors sliced at i, then pack it back to the
      # original structure.
      packed_values = input_pack([elem_flat[i] for elem_flat in elems_flat])
      packed_fn_values = fn(packed_values)

      # Check that the structure of the output matches what was declared or
      # inferred.
      # nest.assert_same_structure(dtype or elems, packed_fn_values)

      # Flatten and decompose to a list of Tensors
      flat_fn_values = nest.flatten(packed_fn_values)

      # If we declared that we are expecting a RaggedTensor output, but we get a
      # Tensor output. We should try to convert it to a RaggedTensor.
      flat_fn_composite_tensors = list(
          _convert_declared(flat_fn_values, dtype_flat))

      flat_fn_components = [
          _maybe_decompose_tensor(t) for t in flat_fn_composite_tensors
      ]
      flat_fn_tensors = nest.flatten(flat_fn_components)

      # Write to TAs.
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_tensors)]

      return (i + 1, tas)

    # NOTE(review): the keyword arguments below were missing from the damaged
    # source and are restored from the documented parameters -- confirm.
    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n, compute, (i, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      # Undo the temporary caching device installed above.
      varscope.set_caching_device(None)

    # Pack back into a list of components
    results_as_components = nest.pack_sequence_as(dtype_components, r_a)

    # Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs.
    def _stack_or_concat(e):
      if isinstance(e, _RaggedTensorComponents):
        return _concat_ragged_tensor_components(e)
      else:
        result = e.stack()
        return result

    results_flat_components = [
        _stack_or_concat(e) for e in results_as_components
    ]

    results_packed = [
        _maybe_recompose_tensor(c) for c in results_flat_components
    ]
    results_packed = nest.pack_sequence_as(dtype, results_packed)
    return results_packed
    def _testScopedExport(self, test_dir, exported_filenames):
        """Builds a three-layer graph and exports each name scope separately.

        Args:
          test_dir: Directory the scoped meta graph files are written to.
          exported_filenames: Sequence of three file names, used in order for
            the "hidden1", "hidden2" and "softmax_linear" scopes.

        Returns:
          A list with the three exported meta graphs, in scope order.
        """
        # Function-scope import so the block does not rely on a file-level
        # `import random` that is not visible from this section.
        import random

        graph = ops.Graph()
        with graph.as_default():
            # Creates an inference graph.
            # Hidden 1
            colocate_constraint = constant_op.constant(1.2, name="constraint")
            images = constant_op.constant(1.2,
                                          shape=[100, 28],
                                          name="images")
            with ops.name_scope("hidden1"):
                with graph.colocate_with(colocate_constraint.op):
                    weights1 = variables.Variable(random_ops.truncated_normal(
                        [28, 128], stddev=1.0 / math.sqrt(float(28))),
                                                  name="weights")
                # The use of control_flow_ops.cond here is purely for adding test
                # coverage the save and restore of control flow context (which doesn't
                # make any sense here from a machine learning perspective).  The typical
                # biases is a simple Variable without the conditions.
                # NOTE(review): the predicate was truncated in the damaged
                # source (only "0.5)" survived); reconstructed -- confirm.
                biases1 = variables.Variable(control_flow_ops.cond(
                    math_ops.less(random.random(),
                                  0.5), lambda: array_ops.ones([128]),
                    lambda: array_ops.zeros([128])),
                                             name="biases")
                hidden1 = nn_ops.relu(
                    math_ops.matmul(images, weights1) + biases1)

            # Hidden 2
            with ops.name_scope("hidden2"):
                weights2 = variables.Variable(random_ops.truncated_normal(
                    [128, 32], stddev=1.0 / math.sqrt(float(128))),
                                              name="weights")

                # The use of control_flow_ops.while_loop here is purely for adding test
                # coverage the save and restore of control flow context (which doesn't
                # make any sense here from a machine learning perspective).  The typical
                # biases is a simple Variable without the conditions.
                def loop_cond(it, _):
                    return it < 2

                def loop_body(it, biases2):
                    biases2 += constant_op.constant(0.1, shape=[32])
                    return it + 1, biases2

                _, biases2 = control_flow_ops.while_loop(
                    loop_cond, loop_body, [
                        constant_op.constant(0),
                        variables.Variable(array_ops.zeros([32]),
                                           name="biases")
                    ])
                hidden2 = nn_ops.relu(
                    math_ops.matmul(hidden1, weights2) + biases2)
            # Linear
            with ops.name_scope("softmax_linear"):
                weights3 = variables.Variable(random_ops.truncated_normal(
                    [32, 10], stddev=1.0 / math.sqrt(float(32))),
                                              name="weights")
                biases3 = variables.Variable(array_ops.zeros([10]),
                                             name="biases")
                logits = math_ops.matmul(hidden2, weights3) + biases3
                ops.add_to_collection("logits", logits)

            # Exports each sub-graph.
            # Exports the first one with unbound_inputs_col_name set to default.
            orig_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
                filename=os.path.join(test_dir, exported_filenames[0]),
                graph=ops.get_default_graph(),
                export_scope="hidden1")
            # Keys of var_list are scope-relative; Variable.name is absolute.
            self.assertEqual(["biases:0", "weights:0"],
                             sorted(var_list.keys()))
            var_names = [v.name for _, v in var_list.items()]
            self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
                             sorted(var_names))

            # Exports the rest with no unbound_inputs_col_name.
            orig_meta_graph2, _ = meta_graph.export_scoped_meta_graph(
                filename=os.path.join(test_dir, exported_filenames[1]),
                graph=ops.get_default_graph(),
                export_scope="hidden2",
                unbound_inputs_col_name=None)
            orig_meta_graph3, _ = meta_graph.export_scoped_meta_graph(
                filename=os.path.join(test_dir, exported_filenames[2]),
                graph=ops.get_default_graph(),
                export_scope="softmax_linear",
                unbound_inputs_col_name=None)

        return [orig_meta_graph1, orig_meta_graph2, orig_meta_graph3]
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  See also `tf.scan`.

  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
  calls `fn` to transform each element; and then stacks the transformed
  values back together.

  #### Mapping functions with single-Tensor inputs and outputs

  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
  then `map_fn(fn, elems)` is equivalent to
  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`.  E.g.:

  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.

  #### Mapping functions with multi-arity inputs and outputs

  `map_fn` also supports functions with multi-arity inputs and outputs:

  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
    must all have the same outer-dimension size (`num_elems`); and `fn` is
    used to transform each tuple (or structure) of corresponding slices from
    `elems`.  E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
    transform each tuple of slices `(t1[i], t2[i], t3[i])`
    (where `0 <= i < num_elems`).

  * If `fn` returns a tuple (or nested structure) of tensors, then the
    result is formed by stacking corresponding elements from those structures.

  #### Specifying `fn`'s output signature

  If `fn`'s input and output signatures are different, then the output
  signature must be specified using `fn_output_signature`.  (The input and
  output signatures differ if their structures, dtypes, or tensor types do
  not match).  E.g.:

  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
  ...           elems=tf.constant(["hello", "moon"]),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
  ...           fn_output_signature=tf.string)
  <tf.Tensor: shape=(2,), dtype=string,
   numpy=array([b'TheDog', b'ACat'], dtype=object)>

  `fn_output_signature` can be specified using any of the following:

  * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
  * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
  * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
  * A (possibly nested) tuple, list, or dict containing the above types.

  #### RaggedTensors

  `map_fn` supports `tf.RaggedTensor` inputs and outputs.  In particular:

  * If `elems` is a `RaggedTensor`, then `fn` will be called with each
    row of that ragged tensor.
    * If `elems` has only one ragged dimension, then the values passed to
      `fn` will be `tf.Tensor`s.
    * If `elems` has multiple ragged dimensions, then the values passed to
      `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.

  * If the result of `map_fn` should be a `RaggedTensor`, then use a
    `tf.RaggedTensorSpec` to specify `fn_output_signature`.
    * If `fn` returns `tf.Tensor`s with varying sizes, then use a
      `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
      single ragged tensor (which will have ragged_rank=1).
    * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
      with the same `ragged_rank`.

  >>> # Example: RaggedTensor input
  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>

  >>> # Example: RaggedTensor output
  >>> elems = tf.constant([3, 5, 0, 2])
  >>> tf.map_fn(tf.range, elems,
  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
  ...                                                   dtype=tf.int32))
  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `RaggedTensor`.  If you wish to map a function over the
  individual values, then you should use:

  * `tf.ragged.map_flat_values(fn, rt)`
    (if fn is expressible as TensorFlow ops)
  * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
    (otherwise)

  E.g.:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>

  #### SparseTensors

  `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs.  In particular:

  * If `elems` is a `SparseTensor`, then `fn` will be called with each row
    of that sparse tensor. In particular, the value passed to `fn` will be a
    `tf.sparse.SparseTensor` with one fewer dimension than `elems`.

  * If the result of `map_fn` should be a `SparseTensor`, then use a
    `tf.SparseTensorSpec` to specify `fn_output_signature`.  The individual
    `SparseTensor`s returned by `fn` will be stacked into a single
    `SparseTensor` with one more dimension.

  >>> # Example: SparseTensor input
  >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>

  >>> # Example: SparseTensor output
  >>> tf.sparse.to_dense(
  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
    array([[[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],
           [[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]], dtype=float32)>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `SparseTensor`.  If you wish to map a function over the nonzero
  values, then you should use:

  * If the function is expressible as TensorFlow ops, use:
    ```python
    tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
    ```
  * Otherwise, use:
    ```python
    tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values),
                           st.dense_shape)
    ```

  #### `map_fn` vs. vectorized operations

  `map_fn` will apply the operations used by `fn` to each element of `elems`,
  resulting in `O(elems.shape[0])` total operations.  This is somewhat
  mitigated by the fact that `map_fn` can process elements in parallel.
  However, a transform expressed using `map_fn` is still typically less
  efficient than an equivalent transform expressed using vectorized operations.

  `map_fn` should typically only be used if one of the following is true:

  * It is difficult or expensive to express the desired transform with
    vectorized operations.
  * `fn` creates large intermediate values, so an equivalent vectorized
    transform would take too much memory.
  * Processing elements in parallel is more efficient than an equivalent
    vectorized transform.
  * Efficiency of the transform is not critical, and using `map_fn` is
    more readable.

  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
  across `elems` could be rewritten more efficiently using vectorized ops:

  >>> elems = tf.constant([3, 5, 2])
  >>> tf.range(3) + tf.expand_dims(elems, 1)
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  In some cases, `tf.vectorized_map` can be used to automatically convert a
  function to a vectorized equivalent.

  #### Eager execution

  When executing eagerly, `map_fn` does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  >>> fn=lambda t: tf.range(t, t + 3)
  >>> @tf.function
  ... def func(elems):
  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
  >>> func(tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  Note: if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  `tf.function` for more details. The recommendation would be to debug without
  `tf.function` but switch to it to get performance benefits of running `map_fn`
  in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `fn_output_signature` if one is provided; otherwise it
      must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unstacked along their first dimension.  `fn` will be applied to the
      nested sequence of the resulting slices.  `elems` may include ragged and
      sparse tensors. `elems` must consist of at least one tensor.
    dtype: Deprecated: Equivalent to `fn_output_signature`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) False disables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.
    fn_output_signature: The output signature of `fn`. Must be specified if
      `fn`'s input and output signatures are different (i.e., if their
      structures, dtypes, or tensor types do not match).
      `fn_output_signature` can be specified using any of the following:

      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
      * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
      * A (possibly nested) tuple, list, or dict containing the above types.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor stacks the
    results of applying `fn` to tensors unstacked from `elems` along the first
    dimension, from first to last.  The result may include ragged and sparse
    tensors.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `fn_output_signature` do not match.
    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
      do not match, or if the `elems` does not contain any tensor.

  Examples:

    >>> elems = np.array([1, 2, 3, 4, 5, 6])
    >>> tf.map_fn(lambda x: x * x, elems)
    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>

    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>

    >>> elems = np.array([1, 2, 3])
    >>> tf.map_fn(lambda x: (x, -x), elems,
    ...          fn_output_signature=(tf.int64, tf.int64))
    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
  """
  # This function uses a `while_loop` to call `fn` on each value of the input
  # tensor(s) (unstacked on dimension 0).  The following sequence of variables
  # are used to transform the input tensor(s) (`elems`) into the output
  # tensor(s) (`result`):
  #   - Preparing and unstacking input values for the while_loop:
  #     - elems: The input tensor(s) to map_fn. May include composite tensors.
  #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
  #                   May include composite tensors.
  #     - elems_batchable: Concatenation of "batchable tensor lists" for each
  #                        tensor in elems_flat.  This "boxes" composite tensors
  #                        into sliceable tf.Tensor objects.  For more info see:
  #                        TensorSpec._to_batched_tensor_list
  #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
  #                           in elems_batchable into elems_value_batchable.
  #   - Calling `fn` on each unstacked value in the body of the while_loop:
  #     - elems_value_batchable: Single unstacked value from elems_batchable.
  #     - elems_value_flat: Single unstacked value from elems_flat,
  #                         constructed from elems_value_batchable (using
  #                         TensorSpec._from_tensor_list).
  #     - elems_value: Single unstacked value from elems (the input to fn).
  #     - result_value: Result of calling `fn(elems_value)`.  May contain
  #                     composite tensors.
  #     - result_value_flat: Flattened list of tensors from result_value.
  #                          May contain composite tensors.
  #     - result_value_batchable: Concatenation of batchable tensor lists for
  #                               each tensor in result_value_flat
  #                               (using TensorSpec._to_tensor_list).
  #   - Collecting and stacking output values from the while_loop:
  #     - result_batchable_ta: List of TensorArrays used to stack each tensor
  #                            in result_value_batchable into result_batchable.
  #     - result_batchable: Stacked tensors from result_batchable_ta.
  #     - result_flat: Flat list of tensors for the result, constructed from
  #                    result_batchable (using TensorSpec._from_tensor_list).
  #     - result: Structured result value packed from result_flat
  #               (using nest.pack_sequence_as).

  # `dtype` is the deprecated spelling of `fn_output_signature`.
  if fn_output_signature is None:
    fn_output_signature = dtype

  if not callable(fn):
    raise TypeError("fn must be callable.")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1
  elif not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(
        logging.WARN, "Setting parallel_iterations > 1 has no "
        "effect when executing eagerly. Consider calling map_fn"
        " with tf.function to execute fn in "
        "parallel.", 1)
    parallel_iterations = 1

  # Flatten the input tensors, and get the TypeSpec for each one.
  elems_flat = nest.flatten(elems)

  # Check in case this is an empty list
  if not elems_flat:
    raise ValueError(
        "elems must be a Tensor or (possibly nested) sequence of Tensors. "
        "Got {}, which does not contain any Tensors.".format(elems))

  elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
  elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)

  # Flatten fn's output signature.
  if fn_output_signature is None:
    # If fn_output_signature was not specified, then assume that it matches the
    # input signature.
    result_flat_signature = [
        _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
        for s in elems_flat_signature
    ]
    result_unflatten = elems_unflatten
  else:
    result_flat_signature = [
        _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
    ]
    result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
    ]

    # Check that inputs are not scalars.
    first_elem = elems_flat[0]
    if isinstance(first_elem, np_arrays.ndarray):
      # NOTE(review): right-hand side was missing in the damaged source;
      # presumably unwraps the underlying tensor -- confirm.
      first_elem = first_elem.data
    elems_static_shape = first_elem.shape
    if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
      if len(elems_flat) == 1:
        raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
      else:
        raise ValueError(
            "elements in elems must be 1+ dimensional Tensors, not scalars"
        )

    # Box any composite tensors into tensor lists.
    elems_batchable = _elems_flat_to_batchable(elems_flat)

    # Find the number of iterations, n.  (may be known statically.)
    n_static = tensor_shape.Dimension(
        tensor_shape.dimension_value(
            elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
    for tensor in elems_batchable[1:]:
      # merge_with raises if the outer dimensions are statically incompatible.
      n_static.merge_with(
          tensor_shape.Dimension(
              tensor_shape.dimension_value(
                  tensor.get_shape().with_rank_at_least(1)[0])))
    n = n_static.value or array_ops.shape(elems_batchable[0])[0]

    # Convert elems to tensor array.
    # TODO(edloper): Should we set infer_shape=False for composite tensors?
    elems_batchable_ta = [
        tensor_array_ops.TensorArray(
            dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
        for t in elems_batchable
    ]
    # Unpack elements
    elems_batchable_ta = [
        ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
    ]

    i = constant_op.constant(0)

    # Prepare result tensor array.
    # TODO(edloper): Should we set infer_shape=False for composite tensors?
    # NOTE(review): helper name reconstructed from the damaged source -- confirm.
    result_batchable_tensor_spec = (
        _result_flat_signature_to_batchable_tensor_spec(result_flat_signature))
    result_batchable_ta = []
    for spec in result_batchable_tensor_spec:
      result_batchable_ta.append(
          tensor_array_ops.TensorArray(
              dtype=spec.dtype, size=n, dynamic_size=False,
              infer_shape=infer_shape, element_shape=spec.shape))

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if fn_output_signature and result_value structure don't match
        ValueError: if fn_output_signature and result_value lengths don't match
      """
      elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
      elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable,
                                                        elems_flat_signature)
      elems_value = elems_unflatten(elems_value_flat)
      ag_ctx = autograph_ctx.control_status_ctx()
      autographed_fn = autograph.tf_convert(fn, ag_ctx)
      result_value = autographed_fn(elems_value)
      nest.assert_same_structure(fn_output_signature or elems, result_value)
      result_value_flat = nest.flatten(result_value)
      result_value_batchable = _result_value_flat_to_batchable(
          result_value_flat, result_flat_signature)
      tas = [
          ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable)
      ]
      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n,
        compute, (i, result_batchable_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)
    result_batchable = [r.stack() for r in r_a]

    # Update each output tensor w/ static shape info about the outer dimension.
    for r in result_batchable:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      # Undo the temporary caching device installed above.
      varscope.set_caching_device(None)

    # NOTE(review): trailing arguments reconstructed from the damaged source;
    # verify against the helper's signature.
    result_flat = _result_batchable_to_flat(result_batchable,
                                            result_flat_signature,
                                            n_static)
    result = result_unflatten(result_flat)
    return result
# NOTE(review): this copy of `raw_rnn` is damaged and will not parse as-is.
# Sites where text is visibly missing:
#   * the signature stops after `cell,`; the body also reads `loop_fn`,
#     `parallel_iterations` and `scope`, which are therefore presumably
#     parameters -- TODO confirm against the original definition;
#   * the docstring has no closing triple quote, and its section headings
#     (before "cell: An instance of RNNCell.", before the returned tuple, and
#     before "TypeError: ...") appear to have been dropped, as have parts of the
#     pseudo-code examples (e.g. `state =, state, next_state)`);
#   * `for input_shape_i in input_shape:` has no statement in its body;
#   * the `nest.pack_sequence_as(structure=...,` calls are never closed;
#   * the `flat_emit_size`, `flat_emit_ta`, `flat_zero_emit`, `result_flat` and
#     `emit_ta_flat` list literals have no closing bracket, and the element
#     expressions of `flat_emit_ta` and `result_flat` have lost their beginning;
#   * `raise ValueError('emit_structure is None')` follows the `if` branch
#     directly -- presumably an `else:` line is missing;
#   * the signature of the inner `body` function, the inner `loop_fn(...)` call,
#     the `math_ops.logical_or(...)` call and the final
#     `control_flow_ops.while_loop(...)` call are all truncated.
# The code below is kept exactly as found; restore the missing pieces from the
# original file before relying on it.
def raw_rnn(cell,
    """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.

    **NOTE: This method is still in testing, and the API may change.**

    This function is a more primitive version of `dynamic_rnn` that provides
    more direct access to the inputs each iteration.  It also provides more
    control over when to start and finish reading the sequence, and
    what to emit for the output.

    For example, it can be used to implement the dynamic decoder of a seq2seq

    Instead of working with `Tensor` objects, most operations work with
    `TensorArray` objects directly.

    The operation of `raw_rnn`, in pseudo-code, is basically the following:

    time = tf.constant(0, dtype=tf.int32)
    (finished, next_input, initial_state, _, loop_state) = loop_fn(
        time=time, cell_output=None, cell_state=None, loop_state=None)
    emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
    state = initial_state
    while not all(finished):
      (output, cell_state) = cell(next_input, state)
      (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
          time=time + 1, cell_output=output, cell_state=cell_state,
      # Emit zeros and copy forward state for minibatch entries that are finished.
      state =, state, next_state)
      emit =, tf.zeros_like(emit), emit)
      emit_ta = emit_ta.write(time, emit)
      # If any new minibatch entries are marked as finished, mark these
      finished = tf.logical_or(finished, next_finished)
      time += 1
    return (emit_ta, state, loop_state)

    with the additional properties that output and state may be (possibly nested)
    tuples, as determined by `cell.output_size` and `cell.state_size`, and
    as a result the final `state` and `emit_ta` may themselves be tuples.

    A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:

    inputs = tf.placeholder(shape=(max_time, batch_size, input_depth),
    sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
    inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
    inputs_ta = inputs_ta.unpack(inputs)

    cell = tf.nn.rnn_cell.LSTMCell(num_units)

    def loop_fn(time, cell_output, cell_state, loop_state):
      emit_output = cell_output  # == None for time == 0
      if cell_output is None:  # time == 0
        next_cell_state = cell.zero_state(batch_size, tf.float32)
        next_cell_state = cell_state
      elements_finished = (time >= sequence_length)
      finished = tf.reduce_all(elements_finished)
      next_input = tf.cond(
          lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
      next_loop_state = None
      return (elements_finished, next_input, next_cell_state,
              emit_output, next_loop_state)

    outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
    outputs = outputs_ta.pack()

      cell: An instance of RNNCell.
      loop_fn: A callable that takes inputs
        `(time, cell_output, cell_state, loop_state)`
        and returns the tuple
        `(finished, next_input, next_cell_state, emit_output, next_loop_state)`.
        Here `time` is an int32 scalar `Tensor`, `cell_output` is a
        `Tensor` or (possibly nested) tuple of tensors as determined by
        `cell.output_size`, and `cell_state` is a `Tensor`
        or (possibly nested) tuple of tensors, as determined by the `loop_fn`
        on its first call (and should match `cell.state_size`).
        The outputs are: `finished`, a boolean `Tensor` of
        shape `[batch_size]`, `next_input`: the next input to feed to `cell`,
        `next_cell_state`: the next state to feed to `cell`,
        and `emit_output`: the output to store for this iteration.

        Note that `emit_output` should be a `Tensor` or (possibly nested)
        tuple of tensors with shapes and structure matching `cell.output_size`
        and `cell_output` above.  The parameter `cell_state` and output
        `next_cell_state` may be either a single or (possibly nested) tuple
        of tensors.  The parameter `loop_state` and
        output `next_loop_state` may be either a single or (possibly nested) tuple
        of `Tensor` and `TensorArray` objects.  This last parameter
        may be ignored by `loop_fn` and the return value may be `None`.  If it
        is not `None`, then the `loop_state` will be propagated through the RNN
        loop, for use purely by `loop_fn` to keep track of its own state.
        The `next_loop_state` parameter returned may be `None`.

        The first call to `loop_fn` will be `time = 0`, `cell_output = None`,
        `cell_state = None`, and `loop_state = None`.  For this call:
        The `next_cell_state` value should be the value with which to initialize
        the cell's state.  It may be a final state from a previous RNN or it
        may be the output of `cell.zero_state()`.  It should be a
        (possibly nested) tuple structure of tensors.
        If `cell.state_size` is an integer, this must be
        a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
        If `cell.state_size` is a `TensorShape`, this must be a `Tensor` of
        appropriate type and shape `[batch_size] + cell.state_size`.
        If `cell.state_size` is a (possibly nested) tuple of ints or
        `TensorShape`, this will be a tuple having the corresponding shapes.
        The `emit_output` value may be  either `None` or a (possibly nested)
        tuple structure of tensors, e.g.,
        `(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`.
        If this first `emit_output` return value is `None`,
        then the `emit_ta` result of `raw_rnn` will have the same structure and
        dtypes as `cell.output_size`.  Otherwise `emit_ta` will have the same
        structure, shapes (prepended with a `batch_size` dimension), and dtypes
        as `emit_output`.  The actual values returned for `emit_output` at this
        initializing call are ignored.  Note, this emit structure must be
        consistent across all time steps.

      parallel_iterations: (Default: 32).  The number of iterations to run in
        parallel.  Those operations which do not have any temporal dependency
        and can be run in parallel, will be.  This parameter trades off
        time for space.  Values >> 1 use more memory but take less time,
        while smaller values use less memory but computations take longer.
      swap_memory: Transparently swap the tensors produced in forward inference
        but needed for back prop from GPU to CPU.  This allows training RNNs
        which would typically not fit on a single GPU, with very minimal (or no)
        performance penalty.
      scope: VariableScope for the created subgraph; defaults to "RNN".

      A tuple `(emit_ta, final_state, final_loop_state)` where:

      `emit_ta`: The RNN output `TensorArray`.
         If `loop_fn` returns a (possibly nested) set of Tensors for
         `emit_output` during initialization, (inputs `time = 0`,
         `cell_output = None`, and `loop_state = None`), then `emit_ta` will
         have the same structure, dtypes, and shapes as `emit_output` instead.
         If `loop_fn` returns `emit_output = None` during this call,
         the structure of `cell.output_size` is used:
         If `cell.output_size` is a (possibly nested) tuple of integers
         or `TensorShape` objects, then `emit_ta` will be a tuple having the
         same structure as `cell.output_size`, containing TensorArrays whose
         elements' shapes correspond to the shape data in `cell.output_size`.

      `final_state`: The final cell state.  If `cell.state_size` is an int, this
        will be shaped `[batch_size, cell.state_size]`.  If it is a
        `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
        If it is a (possibly nested) tuple of ints or `TensorShape`, this will
        be a tuple having the corresponding shapes.

      `final_loop_state`: The final loop state as returned by `loop_fn`.

      TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
        a `callable`.

    if not isinstance(cell, rnn_cell.RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "RNN") as varscope:
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input,
         initial_state, emit_structure, init_loop_state) = loop_fn(
             time, None, None,
             None)  # time, cell_output, cell_state, loop_state
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None else
                      constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match

        batch_size = static_batch_size.value
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [
                array_ops.shape(emit) for emit in flat_emit_structure
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
            raise ValueError('emit_structure is None')

        flat_emit_ta = [
                                         name="rnn_output_%d" % i)
            for i, dtype_i in enumerate(flat_emit_dtypes)
        emit_ta = nest.pack_sequence_as(structure=emit_structure,

        flat_zero_emit = [
            array_ops.zeros(size_i, dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)

        zero_emit = nest.pack_sequence_as(structure=emit_structure,

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, emit_ta, state,
            """Internal while loop body for raw_rnn.

              time: time scalar.
              elements_finished: batch-size vector.
              current_input: possibly nested tuple of input tensors.
              emit_ta: possibly nested tuple of output TensorArrays.
              state: possibly nested tuple of state tensors.
              loop_state: possibly nested tuple of loop state tensors.

              Tuple having the same size as Args but with updated values.

            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                result_flat = [
          , current_i, candidate_i)
                    for (current_i,
                         candidate_i) in zip(current_flat, candidate_flat)
                return nest.pack_sequence_as(structure=current,

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            elements_finished = math_ops.logical_or(elements_finished,

            emit_ta_flat = [
                ta.write(time, emit)
                for (ta, emit) in zip(emit_ta_flat, emit_output_flat)

            emit_ta = nest.pack_sequence_as(structure=emit_structure,

            return (next_time, elements_finished, next_input, emit_ta,
                    next_state, loop_state)

        returned = control_flow_ops.while_loop(
                time, elements_finished, next_input, emit_ta, state, loop_state

        (emit_ta, final_state, final_loop_state) = returned[-3:]

        if init_loop_state is None:
            final_loop_state = None

        return (emit_ta, final_state, final_loop_state)
 def f():
     """Run a five-step while-loop that applies `all_reduce_sum` each step.

     The loop state is a pair: a counter starting at a scalar zero tensor and
     a value starting at the constant 1.0. Every iteration increments the
     counter and replaces the value with `all_reduce_sum(value)`.

     Returns:
       Whatever `control_flow_ops.while_loop` returns for the final
       (counter, value) loop variables.
     """

     def keep_going(step, _):
         return step < 5

     def advance(step, total):
         return (step + 1, all_reduce_sum(total))

     initial_state = (array_ops.zeros([]), constant_op.constant(1.0))
     return control_flow_ops.while_loop(keep_going, advance, initial_state)
  def _call_for_each_tower(self, fn, *args, **kwargs):
    """Call `fn` on the TPU, feeding `values.PerIteration` inputs via infeed.

    The positional and keyword arguments are flattened; every flattened value
    that is a `values.PerIteration` is treated as a feed. A host-side
    `while_loop` (pinned to '/task:0/device:CPU:0') enqueues one slice per core
    for `self._iterations_per_step` iterations, while `fn` is repeated on the
    device through `tpu.repeat` / `tpu.batch_parallel`.

    NOTE(review): this copy is damaged -- several statements are truncated
    (see the inline notes) and the block does not parse as-is.

    Args:
      fn: callable invoked as `fn(*args, **kwargs)` with the re-packed inputs.
      *args: positional inputs for `fn`.
      **kwargs: keyword inputs for `fn`; a 'run_concurrently' entry is
        discarded before use.

    Raises:
      ValueError: if any fed input has a shape that is not fully defined.
    """
    # 'run_concurrently' is dropped so it is never forwarded to `fn`.
    kwargs.pop('run_concurrently', None)

    inputs = {'args': args, 'kwargs': kwargs}
    flat_inputs = nest.flatten(inputs)

    # True for each flattened input that has to go through infeed.
    feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs]

    # A callable (not a list) because `itertools.compress` returns a one-shot
    # iterator and the feeds are walked several times below.
    feeds = lambda: itertools.compress(flat_inputs, feed_mask)
    shapes = [f.get_shape() for f in feeds()]
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = [f.get_dtype() for f in feeds()]

    def infeed_input(i):
      """Get input, split it and then enqueue."""
      iteration_inputs = [f.get(i) for f in feeds()]

      # Regroup from per-input lists into one list of inputs per core.
      infeed_inputs = [[inputs_per_core[core_id]
                        for inputs_per_core in iteration_inputs]
                       for core_id in range(self._num_cores_per_host)]

      infeed_ops = []
      # NOTE(review): the loop variable below reuses the name of the enclosing
      # function `infeed_input` (shadowing it inside this scope), and the start
      # of the loop-body statement is missing from this copy -- only the tail
      # of a call taking `inputs=`, `shapes=` and `device_ordinal=` survives.
      for core_id, infeed_input in enumerate(infeed_inputs):
                inputs=infeed_input, shapes=shapes, device_ordinal=core_id))

      # The counter increment only runs once this iteration's infeed ops have.
      with ops.control_dependencies(infeed_ops):
        return i + 1

    with ops.device('/task:0/device:CPU:0'):
      # NOTE(review): the remaining arguments and the closing parenthesis of
      # this `while_loop` call are missing from this copy.
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < self._iterations_per_step,
          infeed_input, [constant_op.constant(0)],

    def dequeueing_fn(*args, **kwargs):
      """Dequeue input arguments and supply them to `fn`."""
      del args, kwargs
      dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      dequeued = iter(dequeued)

      fn_inputs = []
      # NOTE(review): the bodies of this loop / `if` are missing from this copy;
      # presumably they append either the next dequeued tensor or the original
      # input to `fn_inputs` -- TODO confirm against the original file.
      for inp, is_feed in zip(flat_inputs, feed_mask):
        if is_feed:

      fn_inputs = nest.pack_sequence_as(inputs, fn_inputs)
      return fn(*fn_inputs['args'], **fn_inputs['kwargs'])

    def iterate_on_tpu():
      return tpu.repeat(self._iterations_per_step, dequeueing_fn, [])

    with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
      tpu_result = tpu.batch_parallel(
          iterate_on_tpu, [], num_shards=self._num_cores_per_host)

    # NOTE(review): the beginning of the returned expression is missing from
    # this copy; only its last argument, `enqueue_ops`, survives.
    return, enqueue_ops)
    # Runs `fn` for `iterations` steps inside a single `while_loop`, pulling
    # each step's inputs from `iterator.get_next()`, and returns the
    # `values.MultiStepContext` that carries the last-step outputs.
    #
    # NOTE(review): this copy is damaged and does not parse as-is. The
    # signature stops after `self,`; the body also reads `fn`, `iterator`,
    # `iterations` and `initial_loop_values`, which are therefore presumably
    # parameters -- TODO confirm against the original definition. Further
    # truncated statements are flagged inline.
    def _run_steps_on_dataset(self,
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = values.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            # The incoming loop values are ignored; the outputs are rebuilt
            # from `ctx.last_step_outputs` on every step.
            del args
            fn_inputs = iterator.get_next()
            if not isinstance(fn_inputs, tuple):
                fn_inputs = (fn_inputs, )
            fn_result = fn(ctx, *fn_inputs)
            for (name, output) in ctx.last_step_outputs.items():
                # Convert all outputs to tensors, potentially from `DistributedValues`.
                ctx.last_step_outputs[name] = self.unwrap(output)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            # Tie the returned loop values to `fn_result` so the step's work is
            # executed before the counter advances.
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop. This is useful in cases where we might need to exit
        # these contexts and get back to the outer context to do some things, for
        # e.g. create an op which should be evaluated only once at the end of the
        # loop on the host. One such usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        # NOTE(review): this `while_loop` call is truncated in this copy -- the
        # `body` argument is not visible between `cond` and the loop variables,
        # and the trailing arguments / closing parenthesis are missing.
        loop_result = control_flow_ops.while_loop(cond,
                                                  [i] + initial_loop_values,
        del self._outer_control_flow_context

        # NOTE(review): the right-hand side of this assignment is missing.
        ctx.run_op =

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been aggregated, wrap them in a Mirrored
            # container, else in a PerDevice container.
            # NOTE(review): as written, the `assert` and the `output[0]`
            # assignment sit inside the `NONE` branch right after the regroup;
            # presumably an `else:` line is missing -- TODO confirm.
            if aggregation is variables_lib.VariableAggregation.NONE:
                last_step_tensor_outputs_dict[name] = values.regroup(
                    {d: t
                     for d, t in zip(self._devices, output)}, values.PerDevice)
                assert len(output) == 1
                last_step_tensor_outputs_dict[name] = output[0]

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
# NOTE(review): this copy of `dynamic_decode` is damaged and will not parse
# as-is. Sites where text is visibly missing:
#   * the signature stops after `decoder,`; the body also reads
#     `output_time_major`, `impute_finished`, `maximum_iterations` and `scope`,
#     which are therefore presumably parameters -- TODO confirm against the
#     original definition;
#   * neither this function's docstring nor the inner `body` docstring has a
#     closing triple quote, and their section headings appear to be dropped;
#   * the `raise TypeError(... %` expression, `ops.convert_to_tensor(`,
#     `_create_zero_outputs(...`, `array_ops.zeros_like(...`, the `TensorArray`
#     construction in `_create_ta`, both top-level `nest.map_structure(...`
#     calls and the `array_ops.where(` call in `body` are truncated;
#   * `control_flow_ops.while_loop(` has no visible arguments;
#   * several two-branch conditionals show both assignments back to back
#     (`next_finished`, `emit`, `pass_through`, `next_state`, and the returns in
#     `_shape`) -- presumably the `else:` lines are missing;
#   * `except NotImplementedError:` appears without a matching `try:` line and
#     without a handler body.
# The code below is kept exactly as found; restore the missing pieces from the
# original file before relying on it.
def dynamic_decode(decoder,
    """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object.

    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.

    `(final_outputs, final_state, final_sequence_lengths)`.

    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar.
    if not isinstance(decoder, Decoder):
        raise TypeError("Expected decoder to be type Decoder, but saw: %s" %

    with variable_scope.variable_scope(scope, "decoder") as varscope:
        # Properly cache variable values inside the while_loop
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        initial_finished, initial_inputs, initial_state = decoder.initialize()

        zero_outputs = _create_zero_outputs(decoder.output_size,

        if maximum_iterations is not None:
            initial_finished = math_ops.logical_or(initial_finished,
                                                   0 >= maximum_iterations)
        initial_sequence_lengths = array_ops.zeros_like(initial_finished,
        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        def _shape(batch_size, from_shape):
            if not isinstance(from_shape, tensor_shape.TensorShape):
                return tensor_shape.TensorShape(None)
                batch_size = tensor_util.constant_value(
                    ops.convert_to_tensor(batch_size, name="batch_size"))
                return tensor_shape.TensorShape([batch_size

        def _create_ta(s, d):
            return tensor_array_ops.TensorArray(dtype=d,
                                                    decoder.batch_size, s))

        initial_outputs_ta = nest.map_structure(_create_ta,

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            return math_ops.logical_not(math_ops.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
                next_finished = math_ops.logical_or(decoder_finished, finished)
            if maximum_iterations is not None:
                next_finished = math_ops.logical_or(
                    next_finished, time + 1 >= maximum_iterations)
            next_sequence_lengths = array_ops.where(
                array_ops.fill(array_ops.shape(sequence_lengths), time + 1),

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = control_flow_ops.while_loop(

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),

            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,

    return final_outputs, final_state, final_sequence_lengths
# NOTE(review): this copy of `_find_loss_augmented_facility_idx` is damaged and
# will not parse as-is. Sites where text is visibly missing:
#   * the signature stops after `margin_multiplier,`; the docstring documents
#     and the body reads `margin_type`, so it is presumably the last parameter
#     -- TODO confirm against the original definition;
#   * the docstring has no closing triple quote, and its section headings
#     (before the parameter list and before "integer index.") appear dropped;
#   * the `candidate_scores = -1.0 * math_ops.reduce_sum(` expression has lost
#     the opening of its nested calls and its final arguments.
# What the surviving code shows: a `while_loop` over the candidates scores each
# one with `1.0 - compute_clustering_score(...)`, the scores are added
# (scaled by `margin_multiplier`) to `candidate_scores`, and the candidate id
# at the arg-max position is returned.
def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids,
                                      candidate_ids, margin_multiplier,
  """Find the next centroid that maximizes the loss augmented inference.

  This function is a subroutine called from compute_augmented_facility_locations

    pairwise_distances: 2-D Tensor of pairwise distances.
    labels: 1-D Tensor of ground truth cluster assignment.
    chosen_ids: 1-D Tensor of current centroid indices.
    candidate_ids: 1-D Tensor of candidate indices.
    margin_multiplier: multiplication constant.
    margin_type: Type of structured margin to use. Default is nmi.

    integer index.
  num_candidates = array_ops.shape(candidate_ids)[0]

  pairwise_distances_chosen = array_ops.gather(pairwise_distances, chosen_ids)
  pairwise_distances_candidate = array_ops.gather(
      pairwise_distances, candidate_ids)
  pairwise_distances_chosen_tile = array_ops.tile(
      pairwise_distances_chosen, [1, num_candidates])

  candidate_scores = -1.0 * math_ops.reduce_sum(
                  array_ops.reshape(pairwise_distances_candidate, [1, -1])
              ], 0),
              keep_dims=True), [num_candidates, -1]),

  nmi_scores = array_ops.zeros([num_candidates])
  iteration = array_ops.constant(0)

  def func_cond(iteration, nmi_scores):
    del nmi_scores  # Unused in func_cond()
    return iteration < num_candidates

  def func_body(iteration, nmi_scores):
    predictions = get_cluster_assignment(
        array_ops.concat([chosen_ids, [candidate_ids[iteration]]], 0))
    nmi_score_i = compute_clustering_score(labels, predictions, margin_type)
    pad_before = array_ops.zeros([iteration])
    pad_after = array_ops.zeros([num_candidates - 1 - iteration])
    # return 1 - NMI score as the structured loss.
    #   because NMI is higher the better [0,1].
    return iteration + 1, nmi_scores + array_ops.concat(
        [pad_before, [1.0 - nmi_score_i], pad_after], 0)

  _, nmi_scores = control_flow_ops.while_loop(
      func_cond, func_body, [iteration, nmi_scores])

  candidate_scores = math_ops.add(
      candidate_scores, margin_multiplier * nmi_scores)

  argmax_index = math_ops.to_int32(
      math_ops.argmax(candidate_scores, dimension=0))

  return candidate_ids[argmax_index]
# Example #47
    def __init__(self, lstm, update_rate):
        """Build a token-generation graph on top of an existing generator.

        Copies the generator's hyperparameters, start token and embeddings,
        creates the recurrent and output units, defines the `x` and
        `given_num` placeholders, and chains two while-loops that fill a
        `TensorArray` of generated tokens: the first runs while
        `i < given_num`, the second continues while `i < sequence_length`.
        The result is stored in `self.gen_x` with layout
        batch_size x seq_length.

        NOTE(review): this copy is damaged -- `_g_recurrence_1` and both
        `while_loop` calls are truncated (see the inline notes), so the block
        does not parse as-is.

        Args:
          lstm: generator object exposing `num_emb`, `batch_size`, `emb_dim`,
            `hidden_dim`, `sequence_length`, `start_token`, `learning_rate`
            and `g_embeddings`.
          update_rate: stored on the instance; not otherwise used here.
        """
        self.lstm = lstm
        self.update_rate = update_rate

        self.num_emb = self.lstm.num_emb
        self.batch_size = self.lstm.batch_size
        self.emb_dim = self.lstm.emb_dim
        self.hidden_dim = self.lstm.hidden_dim
        self.sequence_length = self.lstm.sequence_length
        self.start_token = tf.identity(self.lstm.start_token)
        self.learning_rate = self.lstm.learning_rate

        self.g_embeddings = tf.identity(self.lstm.g_embeddings)
        self.g_recurrent_unit = self.create_recurrent_unit()  # maps h_tm1 to h_t for generator
        self.g_output_unit = self.create_output_unit()  # maps h_t to o_t (output token logits)

        # placeholder definition
        self.x = tf.placeholder(tf.int32, shape=[self.batch_size, self.sequence_length])
        self.given_num = tf.placeholder(tf.int32)
        # sequence of indices of generated data generated by generator, not including start token

        # processed for batch
        # NOTE(review): `tf.split(1, n, value)`, `tf.pack`, `.unpack()` and
        # `.pack()` look like the pre-1.0 TensorFlow spellings -- confirm the
        # TensorFlow version this file targets before modernising.
        with tf.device("/cpu:0"):
            inputs = tf.split(1, self.sequence_length, tf.nn.embedding_lookup(self.g_embeddings, self.x))
            self.processed_x = tf.pack(
                [tf.squeeze(input_, [1]) for input_ in inputs])  # seq_length x batch_size x emb_dim

        ta_emb_x = tensor_array_ops.TensorArray(
            dtype=tf.float32, size=self.sequence_length)
        ta_emb_x = ta_emb_x.unpack(self.processed_x)

        ta_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length)
        ta_x = ta_x.unpack(tf.transpose(self.x, perm=[1, 0]))

        # Initial state: the zero hidden tensor packed twice.
        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.pack([self.h0, self.h0])

        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.sequence_length,
                                             dynamic_size=False, infer_shape=True)

        # NOTE(review): in this copy the right-hand side of `x_tp1 =` and the
        # second argument of `gen_x.write(i,` are missing. `ta_emb_x` and
        # `ta_x` are built above but never read in the visible code, so they
        # are presumably what is read here -- TODO confirm.
        def _g_recurrence_1(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            x_tp1 =
            gen_x = gen_x.write(i,
            return i + 1, x_tp1, h_t, given_num, gen_x

        # Samples the next token from the output logits and feeds its
        # embedding back in as the next input.
        def _g_recurrence_2(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]), tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings, next_token)  # batch x emb_dim
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, given_num, gen_x

        # NOTE(review): neither `while_loop` call below shows a `body=`
        # argument in this copy; presumably `_g_recurrence_1` drives the first
        # loop and `_g_recurrence_2` the second -- TODO confirm.
        i, x_t, h_tm1, given_num, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, given_num, _4: i < given_num,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings, self.start_token), self.h0, self.given_num, gen_x))

        # The second loop resumes from the counter and state left by the first.
        _, _, _, _, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            loop_vars=(i, x_t, h_tm1, given_num, self.gen_x))

        self.gen_x = self.gen_x.pack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x, perm=[1, 0])  # batch_size x seq_length
# Example #48
    def create_recurrent_adv(self):
        """Build the adversarial generation loop and expose its outputs.

        Runs a `while_loop` for `self.seq_len` steps; each step adds Gumbel
        noise to the output logits, takes the arg-max as the next token and
        records a softmax-based approximate one-hot vector. After the loop
        every `TensorArray` is stacked and transposed to batch-major layout
        and stored as `self.gen_x`, `self.gen_o`, `self.gen_x_onehot_adv`,
        `self.topicness_values` and `self.gen_x_no_lambda`.

        NOTE(review): this copy is damaged -- the `TensorArray` constructors,
        several calls inside `_gen_recurrence` and the `while_loop` call are
        truncated (see the inline notes), so the block does not parse as-is.
        """
        # ---------- generate tokens and approximated one-hot results (Adversarial) ---------
        # NOTE(review): the five constructor calls below are cut off after
        # their `dtype=` argument in this copy.
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
        gen_x_onehot_adv = tensor_array_ops.TensorArray(dtype=tf.float32,
        topicness_values = tensor_array_ops.TensorArray(dtype=tf.float32,
        gen_x_no_lambda = tensor_array_ops.TensorArray(dtype=tf.int32,

        # `lambda_values` is the loop-local name for the `topicness_values`
        # array: they occupy the same position in the signature and in
        # `loop_vars` below.
        def _gen_recurrence(i, x_t, h_tm1, gen_o, gen_x, gen_x_onehot_adv,
                            lambda_values, gen_x_no_lambda):
            mem_o_t, h_t = self.gen_mem(x_t, h_tm1)  # hidden_memory_tuple
            if self.topic_in_memory and not self.no_topic:
                mem_o_t, h_t = self.gen_mem(
                    self.g_topic_embedding(self.x_topic), h_t)
            o_t = self.g_output_unit(mem_o_t)  # batch x vocab, logits not prob

            # NOTE(review): this condition reads `self.kwargs["NoTopic"]` while
            # the one above reads `self.no_topic` -- confirm both refer to the
            # same setting. The branch below also shows two assignments each to
            # `lambda_param` and `next_token_no_lambda`; presumably an `else:`
            # line is missing, and both `tf.cast(...` calls are truncated.
            if not self.topic_in_memory and not self.kwargs["NoTopic"]:
                topic_vector = self.x_topic
                lambda_param = self.g_output_unit_lambda(mem_o_t)
                next_token_no_lambda = tf.cast(tf.argmax(o_t, axis=1),
                o_t = o_t + lambda_param * topic_vector
                lambda_param = tf.zeros(self.batch_size)
                next_token_no_lambda = tf.cast(tf.argmax(o_t, axis=1),

            gumbel_t = add_gumbel(o_t)

            next_token = tf.cast(tf.argmax(gumbel_t, axis=1), tf.int32)

            x_onehot_appr = tf.nn.softmax(
                tf.multiply(gumbel_t, self.temperature, name="gumbel_x_temp"),
            )  # one-hot-like, [batch_size x vocab_size]
            # NOTE(review): the first argument of `embedding_lookup` and most of
            # the `gen_o.write(i, ...)` expression are missing from this copy.
            x_tp1 = tf.nn.embedding_lookup(
                next_token)  # embeddings, [batch_size x emb_dim]
            gen_o = gen_o.write(i,
                                                   self.vocab_size, 1.0, 0.0),
                                    1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, [batch_size]
            gen_x_onehot_adv = gen_x_onehot_adv.write(i, x_onehot_appr)

            lambda_values = lambda_values.write(i, tf.squeeze(lambda_param))
            gen_x_no_lambda = gen_x_no_lambda.write(
                i, tf.squeeze(next_token_no_lambda))
            return i + 1, x_tp1, h_t, gen_o, gen_x, gen_x_onehot_adv, lambda_values, gen_x_no_lambda

        # NOTE(review): no `body=` argument or closing parenthesis is visible
        # for this call; presumably `_gen_recurrence` is the loop body.
        _, _, _, gen_o, gen_x, gen_x_onehot_adv, topicness_values, gen_x_no_lambda = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4, _5, _6, _7: i < self.seq_len,
            loop_vars=(tf.constant(0, dtype=tf.int32), self.first_embedding(),
                       self.init_states, gen_o, gen_x, gen_x_onehot_adv,
                       topicness_values, gen_x_no_lambda),

        gen_x = gen_x.stack()  # seq_len x batch_size
        self.gen_x = tf.transpose(gen_x, perm=[1, 0],
                                  name="gen_x_trans")  # batch_size x seq_len

        gen_o = gen_o.stack()
        self.gen_o = tf.transpose(gen_o, perm=[1, 0], name="gen_o_trans")

        gen_x_onehot_adv = gen_x_onehot_adv.stack()
        self.gen_x_onehot_adv = tf.transpose(
            gen_x_onehot_adv, perm=[1, 0, 2],
            name="gen_x_onehot_adv_trans")  # batch_size x seq_len x vocab_size

        topicness_values = topicness_values.stack()  # seq_len x batch_size
        self.topicness_values = tf.transpose(
            topicness_values, perm=[1, 0],
            name="lambda_values_trans")  # batch_size x seq_len

        gen_x_no_lambda = gen_x_no_lambda.stack()  # seq_len x batch_size
        self.gen_x_no_lambda = tf.transpose(
            gen_x_no_lambda, perm=[1, 0],
            name="gen_x_no_lambda_trans")  # batch_size x seq_len
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
           swap_memory=False, infer_shape=True, name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will
      have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `dtype` if one is provided, otherwise
      it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  if isinstance(elems, sparse_tensor.SparseTensor):
    raise TypeError(
        "To perform a map on the values of a sparse tensor use either "
        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
        " SparseTensor(input.indices, map_fn(fn, input.values), "
        "input.dense_shape)")

  # Helpers that convert between the caller's (possibly nested) structure and
  # the flat lists used internally.
  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if dtype is None:
    # Without an explicit dtype the output mirrors the input structure.
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_sequence(dtype)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    def output_pack(x):
      return (nest.pack_sequence_as(dtype, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = context.in_graph_mode()
  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]

    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
    dtype_flat = output_flatten(dtype)

    # Convert elems to tensor array.
    n = array_ops.shape(elems_flat[0])[0]

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
                                     dynamic_size=False,
                                     infer_shape=True)
        for elem in elems_flat]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]

    i = constant_op.constant(0)

    # One accumulator per flattened output; only these honour `infer_shape`.
    accs_ta = [
        tensor_array_ops.TensorArray(dtype=dt, size=n,
                                     dynamic_size=False,
                                     infer_shape=infer_shape)
        for dt in dtype_flat]

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_fn_values = fn(packed_values)
      nest.assert_same_structure(dtype or elems, packed_fn_values)
      flat_fn_values = output_flatten(packed_fn_values)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]
      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n, compute, (i, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)
    results_flat = [r.stack() for r in r_a]

    # Propagate the statically known leading dimension to every result.
    n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0]
    for elem in elems_flat[1:]:
      n_static.merge_with(elem.get_shape().with_rank_at_least(1)[0])
    for r in results_flat:
      r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
          r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      # Undo the temporary caching device installed above.
      varscope.set_caching_device(None)

    return output_pack(results_flat)
    def __init__(self, lstm, update_rate, reward_gamma):
        """Build the roll-out sampling graph from a generator's attributes.

        Args:
            lstm: generator object; its `num_emb`, `batch_size`, `emb_dim`,
                `hidden_dim`, `sequence_length`, `start_token`,
                `learning_rate` and `g_embeddings` attributes are read here.
            update_rate: stored as `self.update_rate`; not used further in
                this method.
            reward_gamma: stored as `self.reward_gamma`; not used further in
                this method.
        """
        self.lstm = lstm
        self.update_rate = update_rate

        self.num_emb = self.lstm.num_emb
        self.batch_size = self.lstm.batch_size
        self.emb_dim = self.lstm.emb_dim
        self.hidden_dim = self.lstm.hidden_dim
        self.sequence_length = self.lstm.sequence_length
        self.start_token = tf.identity(self.lstm.start_token)
        self.learning_rate = self.lstm.learning_rate
        self.reward_gamma = reward_gamma

        self.g_embeddings = tf.identity(self.lstm.g_embeddings)
        self.g_recurrent_unit = self.create_recurrent_unit(
        )  # maps h_tm1 to h_t for generator
        self.g_output_unit = self.create_output_unit(
        )  # maps h_t to o_t (output token logits)
        # self.Wi = tf.Print(self.Wi, [self.Wi[1,:10]],message='rollout.wi======',summarize=100)

        # placeholder definition
        self.x = tf.placeholder(tf.int32,
                                shape=[
                                    self.batch_size, self.sequence_length
                                ])  # sequence of tokens generated by generator
        self.given_num = tf.placeholder(tf.int32)

        # processed for batch
        with tf.device("/cpu:0"):
            self.processed_x = tf.transpose(
                tf.nn.embedding_lookup(self.g_embeddings, self.x),
                perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        # Per-time-step views of the given sequence (embeddings and raw ids).
        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        ta_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                            size=self.sequence_length)
        ta_x = ta_x.unstack(tf.transpose(self.x, perm=[1, 0]))

        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
                                             size=self.sequence_length,
                                             dynamic_size=False,
                                             infer_shape=True)

        # When current index i < given_num, use the provided tokens as the input at each time step
        def _g_recurrence_1(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            x_tp1 = ta_emb_x.read(i)
            gen_x = gen_x.write(i, ta_x.read(i))
            return i + 1, x_tp1, h_t, given_num, gen_x

        # When current index i >= given_num, start roll-out, use the output as time step t as the input at time step t+1
        def _g_recurrence_2(i, x_t, h_tm1, given_num, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings,
                                           next_token)  # batch x emb_dim
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, given_num, gen_x

        # Phase 1: replay the first `given_num` provided tokens.
        i, x_t, h_tm1, given_num, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, given_num, _4: i < given_num,
            body=_g_recurrence_1,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), self.h0,
                       self.given_num, gen_x))

        # Phase 2: continue from the phase-1 state, sampling the remainder.
        _, _, _, _, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence_2,
            loop_vars=(i, x_t, h_tm1, given_num, self.gen_x))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x,
                                  perm=[1, 0])  # batch_size x seq_length
        # generating sentences
        gen_o_old = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                 size=self.sequence_length,
                                                 dynamic_size=False,
                                                 infer_shape=True)
        gen_x_old = tensor_array_ops.TensorArray(dtype=tf.int32,
                                                 size=self.sequence_length,
                                                 dynamic_size=False,
                                                 infer_shape=True)

        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            next_token = tf.cast(
                tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                tf.int32)
            clipped_log_prob = tf.log(
                tf.clip_by_value(tf.nn.softmax(o_t), 1e-20, 1.0))
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings,
                                           next_token)  # batch x emb_dim
            # The one-hot mask keeps only the sampled token's clipped log-prob.
            gen_o = gen_o.write(
                i,
                tf.reduce_sum(
                    tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0),
                                clipped_log_prob), 1))  # [batch_size] , prob
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, gen_o, gen_x

        _, _, _, self.gen_o_old, self.gen_x_old = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), self.h0,
                       gen_o_old, gen_x_old))

        self.gen_x_old = self.gen_x_old.stack()  # seq_length x batch_size
        self.gen_x_old = tf.transpose(self.gen_x_old, perm=[1, 0])

        self.gen_o_old = tf.transpose(self.gen_o_old.stack(), perm=[1, 0])
    # Generator constructor: stores hyper-parameters, then builds the sampling
    # loop, the supervised pre-training loss and the reward-weighted loss with
    # their clipped-gradient update ops.
    # NOTE(review): this block is damaged (lines were lost). The parameter list
    # is cut off after `self,`; the body reads num_emb, batch_size, emb_dim,
    # hidden_dim, sequence_length, learning_rate, reward_gamma, fix, fix_val
    # (and apparently graph) as if they were parameters. Restore the signature
    # from the original project before using this code.
    def __init__(self,
        self.num_emb = num_emb
        self.batch_size = batch_size
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        self.start_token = tf.placeholder(tf.int32, [None], name='start_token')
        self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
        self.reward_gamma = reward_gamma
        self.g_params = []
        self.d_params = []
        self.temperature = 1.0
        self.grad_clip = 5.0
        self.fix = fix
        # NOTE(review): as written this chained assignment binds both
        # `self.graph` and `graph` to 0.2, so the `self.graph == None` test
        # further down can never be true. Looks like two statements fused
        # together -- confirm against the original.
        self.graph = graph = 0.2
        self.expected_reward = tf.Variable(tf.zeros([self.sequence_length]))
        self.fix_val = fix_val
        with tf.variable_scope('generator'):
            if self.fix == False:
                self.g_embeddings = tf.Variable(
                    self.init_matrix([self.num_emb, self.emb_dim]))
                # NOTE(review): an `else:` presumably belongs before the next
                # assignment, and its tf.Variable(...) call is left unclosed.
                self.g_embeddings = tf.Variable(initial_value=self.fix_val,
            self.g_recurrent_unit = self.create_recurrent_unit(
                self.g_params)  # maps h_tm1 to h_t for generator
            self.g_output_unit = self.create_output_unit(
                self.g_params)  # maps h_t to o_t (output token logits)
        # Parameter subsets used by the alternative pre-training updates below:
        # everything except the first entry, and only the first entry.
        self.g_params2 = self.g_params[1:]
        self.g_params3 = self.g_params[:1]
        # placeholder definition
        # NOTE(review): the opening of the shape argument (e.g. `shape=[`) is
        # missing on the next statement.
        self.x = tf.placeholder(tf.int32,
                                    self.batch_size, self.sequence_length
                                ])  # sequence of tokens generated by generator
        self.rewards = tf.placeholder(
            tf.float32, shape=[self.batch_size, self.sequence_length - 1
                               ])  # get from rollout policy and discriminator

        # processed for batch
        with tf.device("/gpu:2"):
            self.processed_x = tf.transpose(
                tf.nn.embedding_lookup(self.g_embeddings, self.x[:, 1:]),
                perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        # Initial states
        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        self.h0 = tf.stack([self.h0, self.h0])

        # NOTE(review): both TensorArray(...) calls below are unclosed; their
        # remaining arguments were lost.
        gen_o = tensor_array_ops.TensorArray(dtype=tf.float32,
        gen_x = tensor_array_ops.TensorArray(dtype=tf.int32,
        gen_x = gen_x.write(0, self.start_token)

        # Loop body for free-running sampling: one recurrent step, pick the
        # next token, record the softmax output and the chosen token.
        def _g_recurrence(i, x_t, h_tm1, gen_o, gen_x):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            log_prob = tf.log(tf.nn.softmax(o_t))
            # `pro` is drawn once in Python when the graph is built, not once
            # per loop iteration at run time.
            pro = np.random.rand()
            # NOTE(review): the branch below is badly damaged: the tf.cast call
            # is unclosed, the `else:` is missing, `if pro <` has no right-hand
            # side, and the `for i in pre:` loop has no body (it would also
            # rebind the loop counter `i`). Not recoverable from this view.
            if self.graph == None:
                next_token = tf.cast(
                    tf.reshape(tf.multinomial(log_prob, 1), [self.batch_size]),
                if pro <
                    pre = gen_x[:, i]
                    next_token = []
                    for i in pre:
                    next_token = np.array(next_token).reshape([-1])
                    next_token = tf.cast(
                        tf.reshape(tf.multinomial(log_prob, 1),
                                   [self.batch_size]), tf.int32)
            x_tp1 = tf.nn.embedding_lookup(self.g_embeddings,
                                           next_token)  # batch x emb_dim
            # gen_o = gen_o.write(i, tf.reduce_sum(tf.multiply(tf.one_hot(next_token, self.num_emb, 1.0, 0.0),
            #                                                  tf.nn.softmax(o_t)), 1))  # [batch_size] , prob
            gen_o = gen_o.write(i, tf.nn.softmax(o_t))
            gen_x = gen_x.write(i, next_token)  # indices, batch_size
            return i + 1, x_tp1, h_t, gen_o, gen_x

        # The counter starts at 1 because index 0 of gen_x was already written
        # with the start token above.
        # NOTE(review): the `body=` argument and the opening of the call that
        # wraps `self.start_token)` are missing from this while_loop.
        _, _, _, self.gen_o, self.gen_x = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3, _4: i < self.sequence_length,
            loop_vars=(tf.constant(1, dtype=tf.int32),
                                              self.start_token), self.h0,
                       gen_o, gen_x))

        self.gen_x = self.gen_x.stack()  # seq_length x batch_size
        self.gen_x = tf.transpose(self.gen_x,
                                  perm=[1, 0])  # batch_size x seq_length
        # NOTE(review): the tensor argument of this transpose is missing.
        self.gen_o = tf.transpose(
            perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        # supervised pretraining for generator
        # NOTE(review): this TensorArray(...) call is unclosed and its dtype
        # argument is not visible.
        g_predictions = tensor_array_ops.TensorArray(
            size=self.sequence_length - 1,

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length - 1)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)
        self.ta_emb_x = ta_emb_x

        # Loop body for pre-training: records the softmax prediction per step.
        def _pretrain_recurrence(i, x_t, h_tm1, g_predictions):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            g_predictions = g_predictions.write(
                i, tf.nn.softmax(o_t))  # batch x vocab_size
            # NOTE(review): right-hand side of this assignment is missing.
            x_tp1 =
            return i + 1, x_tp1, h_t, g_predictions

        # Single recurrent step returning the softmax over the output logits.
        def exp(x_t, h_tm1):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = tf.nn.softmax(self.g_output_unit(h_t))
            return o_t

        # NOTE(review): the second argument and closing parenthesis of this
        # exp(...) call are missing.
        self.o_t = exp(
            tf.nn.embedding_lookup(self.g_embeddings, self.start_token),

        # NOTE(review): `body=`, the call wrapping `self.start_token)`, and the
        # tail of loop_vars are missing from this while_loop.
        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length - 1,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                                              self.start_token), self.h0,

        # NOTE(review): the tensor argument of this transpose is missing.
        self.g_predictions = tf.transpose(
            perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        # pretraining loss
        # NOTE(review): the opening calls of this expression are missing; what
        # remains divides a negated sum by (sequence_length - 1) * batch_size.
        self.pretrain_loss = -tf.reduce_sum(
                self.x[:, 1:], [-1])), self.num_emb, 1.0, 0.0) * tf.log(
                        tf.reshape(self.g_predictions, [-1, self.num_emb]),
                        1e-20, 1.0))) / (
                            (self.sequence_length - 1) * self.batch_size)

        # training updates
        pretrain_opt = self.g_optimizer(self.learning_rate)

        # Gradients are clipped by global norm (self.grad_clip) before being
        # applied; the same pattern repeats for each parameter subset below.
        self.pretrain_grad, _ = tf.clip_by_global_norm(
            tf.gradients(self.pretrain_loss, self.g_params), self.grad_clip)
        self.pretrain_updates = pretrain_opt.apply_gradients(
            zip(self.pretrain_grad, self.g_params))
        #pretrain fix embedding
        pretrain_opt2 = self.g_optimizer(self.learning_rate)

        self.pretrain_grad2, _ = tf.clip_by_global_norm(
            tf.gradients(self.pretrain_loss, self.g_params2), self.grad_clip)
        self.pretrain_updates2 = pretrain_opt2.apply_gradients(
            zip(self.pretrain_grad2, self.g_params2))

        pretrain_opt3 = self.g_optimizer(self.learning_rate)
        self.pretrain_grad3, _ = tf.clip_by_global_norm(
            tf.gradients(self.pretrain_loss, self.g_params3), self.grad_clip)
        self.pretrain_updates3 = pretrain_opt3.apply_gradients(
            zip(self.pretrain_grad3, self.g_params3))
        #  Unsupervised Training
        # NOTE(review): the opening calls of this expression are missing; the
        # visible tail multiplies a per-position sum by the flattened rewards.
        self.g_loss = -tf.reduce_sum(
                    self.x[:, 1:], [-1])), self.num_emb, 1.0, 0.0) * tf.log(
                            tf.reshape(self.g_predictions, [-1, self.num_emb]),
                            1e-20, 1.0)), 1) * tf.reshape(self.rewards, [-1]))
        # self.g_loss = -tf.reduce_sum(
        #     tf.reduce_sum(
        #         tf.one_hot(tf.to_int32(tf.reshape(self.x[:,1:], [-1])), self.num_emb, 1.0, 0.0) * tf.log(
        #             tf.clip_by_value(tf.reshape(self.gen_o[:,1:,:], [-1, self.num_emb]), 1e-20, 1.0)
        #         ), 1) * tf.reshape(self.rewards, [-1])
        # )

        g_opt = self.g_optimizer(self.learning_rate)

        self.g_grad, _ = tf.clip_by_global_norm(
            tf.gradients(self.g_loss, self.g_params), self.grad_clip)
        self.g_updates = g_opt.apply_gradients(zip(self.g_grad, self.g_params))
def matrix_inverse_pth_root(mat_g,
                            mat_g_size,
                            alpha,
                            iter_count=100,
                            epsilon=1e-6,
                            ridge_epsilon=1e-6):
    """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.

    We use an iterative Schur-Newton method from equation 3.2 on page 9 of:

    A Schur-Newton Method for the Matrix p-th Root and its Inverse
    by Chun-Hua Guo and Nicholas J. Higham
    SIAM Journal on Matrix Analysis and Applications,
    2006, Vol. 28, No. 3 : pp. 788-804

    Args:
      mat_g: the symmetric PSD matrix whose power it to be computed
      mat_g_size: size of mat_g.
      alpha: exponent, must be -1/p for p a positive integer.
      iter_count: Maximum number of iterations.
      epsilon: accuracy indicator, useful for early termination.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

    Returns:
      mat_g^alpha
    """

    identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))

    def mat_power(mat_m, p):
        """Computes mat_m^p, for p a positive integer.

        Power p is known at graph compile time, so no need for loop and cond.

        Args:
          mat_m: a square matrix
          p: a positive integer

        Returns:
          mat_m^p
        """
        assert p == int(p) and p > 0
        # Exponentiation by squaring, unrolled in Python at graph-build time.
        power = None
        while p > 0:
            if p % 2 == 1:
                power = math_ops.matmul(mat_m,
                                        power) if power is not None else mat_m
            p //= 2
            mat_m = math_ops.matmul(mat_m, mat_m)
        return power

    def _iter_condition(i, mat_m, _):
        # Continue while under the iteration cap and mat_m is not yet within
        # epsilon (max-abs entrywise) of the identity.
        return math_ops.logical_and(
            i < iter_count,
            math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)

    def _iter_body(i, mat_m, mat_x):
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        return (i + 1, math_ops.matmul(mat_power(mat_m_i, -1.0 / alpha),
                                       mat_m), math_ops.matmul(mat_x, mat_m_i))

    if mat_g_size == 1:
        # Scalar case: a plain power of the damped value.
        mat_h = math_ops.pow(mat_g + ridge_epsilon, alpha)
    else:
        damped_mat_g = mat_g + ridge_epsilon * identity
        z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
        # The best value for z is
        # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
        #                 (c_max^{1-alpha} - c_min^{1-alpha})
        # where c_max and c_min are the largest and smallest singular values of
        # damped_mat_g.
        # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
        # Can replace above line by the one below, but it is less accurate,
        # hence needs more iterations to converge.
        # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
        # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
        # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
        # extra iterations.
        _, _, mat_h = control_flow_ops.while_loop(
            _iter_condition, _iter_body,
            [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
    return mat_h
# Example #53
def hessians(ys,
             xs,
             name="hessians",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None):
    """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

    `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
    with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
    where each tensor is the Hessian of `sum(ys)`. This function currently
    only supports evaluating the Hessian with respect to (a list of) one-
    dimensional tensors.

    The Hessian is a matrix of second-order partial derivatives of a scalar
    tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

    Args:
      ys: A `Tensor` or list of tensors to be differentiated.
      xs: A `Tensor` or list of tensors to be used for differentiation.
      name: Optional name to use for grouping all the gradient ops together.
        defaults to 'hessians'.
      colocate_gradients_with_ops: See `gradients()` documentation for details.
      gate_gradients: See `gradients()` documentation for details.
      aggregation_method: See `gradients()` documentation for details.

    Returns:
      A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

    Raises:
      LookupError: if one of the operations between `xs` and `ys` does not
        have a registered gradient function.
    """
    xs = _AsList(xs)
    kwargs = {
        'colocate_gradients_with_ops': colocate_gradients_with_ops,
        'gate_gradients': gate_gradients,
        'aggregation_method': aggregation_method
    }
    # Compute first-order derivatives and iterate for each x in xs.
    hessian_matrices = []
    _gradients = gradients(ys, xs, **kwargs)
    for i, (_gradient, x) in enumerate(zip(_gradients, xs)):
        # Ensure that x is a vector.
        check_rank = check_ops.assert_rank(
            x,
            1,
            message='Cannot compute Hessian because element %d of `xs` does '
            'not have rank one.' % i)
        with ops.control_dependencies([check_rank]):
            # Declare an iterator and tensor array loop variables for the gradients.
            n = array_ops.size(x)
            loop_vars = [
                array_ops.constant(0, dtypes.int32),
                tensor_array_ops.TensorArray(x.dtype, n)
            ]
            # Iterate over all elements of the gradient and compute second order
            # derivatives.
            # The lambdas close over this iteration's _gradient, x and n; that
            # is safe because while_loop calls them before the next iteration.
            _, hessian = control_flow_ops.while_loop(
                lambda j, _: j < n, lambda j, result:
                (j + 1, result.write(j,
                                     gradients(_gradient[j], x)[0])),
                loop_vars)

            hessian_matrices.append(hessian.stack())
    return hessian_matrices
# Example #54
def matrix_exponential(input, name=None):  # pylint: disable=redefined-builtin
    r"""Computes the matrix exponential of one or more square matrices.

    exp(A) = \sum_{n=0}^\infty A^n/n!

    The exponential is computed using a combination of the scaling and squaring
    method and the Pade approximation. Details can be found in:
    Nicholas J. Higham, "The scaling and squaring method for the matrix
    exponential revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.

    The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
    form square matrices. The output is a tensor of the same shape as the input
    containing the exponential for all input submatrices `[..., :, :]`.

    Args:
      input: A `Tensor`. Must be `float16`, `float32`, `float64`, `complex64`, or
        `complex128` with shape `[..., M, M]`.
      name:  A name to give this `Op` (optional).

    Returns:
      the matrix exponential of the input.

    Raises:
      ValueError: An unsupported type is provided as input.

    @compatibility(scipy)
    Equivalent to scipy.linalg.expm
    @end_compatibility
    """
    with ops.name_scope(name, 'matrix_exponential', [input]):
        matrix = ops.convert_to_tensor(input, name='input')
        if matrix.shape[-2:] == [0, 0]:
            return matrix
        batch_shape = matrix.shape[:-2]
        if not batch_shape.is_fully_defined():
            batch_shape = array_ops.shape(matrix)[:-2]

        # reshaping the batch makes the where statements work better
        matrix = array_ops.reshape(
            matrix,
            array_ops.concat(([-1], array_ops.shape(matrix)[-2:]), axis=0))
        # Max absolute column sum per matrix, kept with two trailing unit axes
        # so it broadcasts against the matrices in the `where` selections.
        l1_norm = math_ops.reduce_max(math_ops.reduce_sum(
            math_ops.abs(matrix),
            axis=array_ops.size(array_ops.shape(matrix)) - 2),
                                      axis=-1)[..., array_ops.newaxis,
                                               array_ops.newaxis]
        const = lambda x: constant_op.constant(x, l1_norm.dtype)

        def _nest_where(vals, cases):
            # Picks cases[k] for the first threshold vals[k] that l1_norm is
            # below, falling through to the last case otherwise.
            assert len(vals) == len(cases) - 1
            if len(vals) == 1:
                return array_ops.where_v2(
                    math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
            else:
                return array_ops.where_v2(
                    math_ops.less(l1_norm, const(vals[0])), cases[0],
                    _nest_where(vals[1:], cases[1:]))

        if matrix.dtype in [dtypes.float16, dtypes.float32, dtypes.complex64]:
            maxnorm = const(3.925724783138660)
            squarings = math_ops.maximum(
                math_ops.floor(
                    math_ops.log(l1_norm / maxnorm) /
                    math_ops.log(const(2.0))), 0)
            u3, v3 = _matrix_exp_pade3(matrix)
            u5, v5 = _matrix_exp_pade5(matrix)
            # Only the highest-order approximant sees the scaled matrix.
            u7, v7 = _matrix_exp_pade7(matrix / math_ops.cast(
                math_ops.pow(const(2.0), squarings), matrix.dtype))
            conds = (4.258730016922831e-001, 1.880152677804762e+000)
            u = _nest_where(conds, (u3, u5, u7))
            v = _nest_where(conds, (v3, v5, v7))
        elif matrix.dtype in [dtypes.float64, dtypes.complex128]:
            maxnorm = const(5.371920351148152)
            squarings = math_ops.maximum(
                math_ops.floor(
                    math_ops.log(l1_norm / maxnorm) /
                    math_ops.log(const(2.0))), 0)
            u3, v3 = _matrix_exp_pade3(matrix)
            u5, v5 = _matrix_exp_pade5(matrix)
            u7, v7 = _matrix_exp_pade7(matrix)
            u9, v9 = _matrix_exp_pade9(matrix)
            u13, v13 = _matrix_exp_pade13(matrix / math_ops.cast(
                math_ops.pow(const(2.0), squarings), matrix.dtype))
            conds = (1.495585217958292e-002, 2.539398330063230e-001,
                     9.504178996162932e-001, 2.097847961257068e+000)
            u = _nest_where(conds, (u3, u5, u7, u9, u13))
            v = _nest_where(conds, (v3, v5, v7, v9, v13))
        else:
            raise ValueError(
                'tf.linalg.expm does not support matrices of type %s' %
                matrix.dtype)
        numer = u + v
        denom = -u + v
        result = linalg_ops.matrix_solve(denom, numer)
        max_squarings = math_ops.reduce_max(squarings)

        i = const(0.0)
        c = lambda i, r: math_ops.less(i, max_squarings)

        def b(i, r):
            # Square only those batch entries that still need squarings.
            return i + 1, array_ops.where_v2(math_ops.less(i, squarings),
                                             math_ops.matmul(r, r), r)

        _, result = control_flow_ops.while_loop(c, b, [i, result])
        if not matrix.shape.is_fully_defined():
            return array_ops.reshape(
                result,
                array_ops.concat((batch_shape, array_ops.shape(result)[-2:]),
                                 axis=0))
        return array_ops.reshape(result,
                                 batch_shape.concatenate(result.shape[-2:]))
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor resulting from applying `fn` consecutively to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  in_graph_mode = context.in_graph_mode()
  with ops.name_scope(name, "foldl", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array.
    elems = ops.convert_to_tensor(elems, name="elems")
    n = array_ops.shape(elems)[0]
    elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
                                            dynamic_size=False,
                                            infer_shape=True)
    elems_ta = elems_ta.unstack(elems)

    if initializer is None:
      # Seed the accumulator with the first element and start folding at 1.
      a = elems_ta.read(0)
      i = constant_op.constant(1)
    else:
      a = ops.convert_to_tensor(initializer)
      i = constant_op.constant(0)

    def compute(i, a):
      a = fn(a, elems_ta.read(i))
      return [i + 1, a]
    _, r_a = control_flow_ops.while_loop(
        lambda i, a: i < n, compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      # Undo the temporary caching device installed above.
      varscope.set_caching_device(None)
    return r_a
def resample_at_rate(inputs, rates, scope=None, seed=None, back_prop=False):
    """Given `inputs` tensors, stochastically resamples each at a given rate.

    For example, if the inputs are `[[a1, a2], [b1, b2]]` and the rates
    tensor contains `[3, 1]`, then the return value may look like `[[a1,
    a2, a1, a1], [b1, b2, b1, b1]]`. However, many other outputs are
    possible, since this is stochastic -- averaged over many repeated
    calls, each set of inputs should appear in the output `rate` times
    the number of invocations.

    Uses Knuth's method to generate samples from the poisson
    distribution (but instead of just incrementing a count, actually
    emits the input); this is described at
    https://en.wikipedia.org/wiki/Poisson_distribution in the section on
    generating Poisson-distributed random variables.

    Note that this method is not appropriate for large rate values: with
    float16 it will stop performing correctly for rates above 9.17;
    float32, 87; and float64, 708. (These are the base-e versions of the
    minimum representable exponent for each type.)

    Args:
      inputs: A list of tensors, each of which has a shape of
        `[batch_size, ...]`
      rates: A tensor of shape `[batch_size]` containing the resampling rates
        for each input.
      scope: Scope for the op.
      seed: Random seed to use.
      back_prop: Whether to allow back-propagation through this op.

    Returns:
      Selections from the input tensors.
    """
    # TODO(shoutis): Refactor, splitting this up into a poisson draw and a repeat.

    # What this implementation does is loop, simulating the intervals
    # between events by drawing from the exponential distribution
    # (`-log(random_uniform)/rate`), and emitting another copy of the
    # corresponding input so long as sum(intervals) < 1. However, that
    # condition can be transformed into the easier-to-compute condition
    # `product(random_uniforms) > e^-rate`.
    with ops.name_scope(scope, default_name='resample_at_rate', values=inputs):
        floor_vals = math_ops.exp(-rates)

        def _body(chosen_inputs, running_products, idx, output_count):
            """Body of the resampling loop."""
            # Update the running product
            next_running_products = running_products * random_ops.random_uniform(
                shape=array_ops.shape(running_products), seed=seed)

            # Append inputs which still pass the condition:
            indexes = array_ops.reshape(
                array_ops.where(next_running_products > floor_vals), [-1])

            next_output_count = output_count + array_ops.shape(indexes)[0]

            # One TensorArray per input tensor; each iteration writes the rows
            # of that input that were selected in this round at slot `idx`.
            next_chosen_inputs = [
                chosen.write(idx, array_ops.gather(tensor, indexes))
                for chosen, tensor in zip(chosen_inputs, inputs)
            ]

            return [
                next_chosen_inputs, next_running_products, idx + 1,
                next_output_count
            ]

        def _cond(unused_chosen_inputs, running_products, unused_idx,
                  unused_count):
            """Resampling loop exit condition."""
            return math_ops.reduce_any(running_products > floor_vals)

        # The number of loop iterations is not known in advance, so the
        # accumulators start empty and grow dynamically.
        initial_chosen_inputs = [
            tensor_array_ops.TensorArray(dtype=x.dtype, size=0,
                                         dynamic_size=True) for x in inputs
        ]

        resampled_inputs, _, unused_idx, count = control_flow_ops.while_loop(
            _cond,
            _body,
            loop_vars=[
                initial_chosen_inputs,
                array_ops.ones_like(rates),  # initial running_products
                0,  # initial idx
                0
            ],  # initial count
            back_prop=back_prop)

    # Work around TensorArray "Currently only static shapes are supported when
    # concatenating zero-size TensorArrays" limitation:
    def _empty_tensor_like(t):
        """Returns a tensor like `t` but with a leading dimension of 0."""
        result = array_ops.zeros(shape=(array_ops.concat_v2(
            [[0], array_ops.shape(t)[1:]], 0)),
                                 dtype=t.dtype)
        if t.get_shape().ndims is not None:
            # preserve known shapes
            result.set_shape([0] + t.get_shape()[1:].as_list())
        return result

    return control_flow_ops.cond(
        count > 0,
        lambda: [tensor_array.concat() for tensor_array in resampled_inputs],
        lambda: [_empty_tensor_like(t) for t in inputs])
 def Foo():
   """Runs a while loop that adds the captured `x` to `i` until `i >= 10`."""
   # NOTE(review): the loop_vars argument was missing and the call was left
   # unterminated; a single integer loop variable starting at 0 is assumed
   # because both lambdas take exactly one argument -- TODO confirm.
   return control_flow_ops.while_loop(lambda i: i < 10,
                                      lambda i: i + x,
                                      [0])
def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
    """Internal implementation of Dynamic RNN.

    Args:
      cell: An instance of RNNCell.
      inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
        tuple of such elements.
      initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
        `cell.state_size` is a tuple, then this should be a tuple of
        tensors having shapes `[batch_size, s] for s in cell.state_size`.
      parallel_iterations: Positive Python int.
      swap_memory: A Python boolean
      sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
      dtype: (optional) Expected dtype of output. If not specified, inferred
        from initial_state.

    Returns:
      Tuple `(final_outputs, final_state)`.
      final_outputs:
        A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
        `cell.output_size` is a (possibly nested) tuple of ints or
        `TensorShape` objects, then this returns a (possibly nested) tuple of
        Tensors matching the corresponding shapes.
      final_state:
        A `Tensor`, or possibly nested tuple of Tensors, matching in length
        and shapes to `initial_state`.

    Raises:
      ValueError: If the input depth cannot be inferred via shape inference
        from the inputs, or if the elements of `inputs` disagree on their
        static time or batch dimensions.
    """
    state = initial_state
    assert isinstance(parallel_iterations,
                      int), "parallel_iterations must be int"

    state_size = cell.state_size

    flat_input = nest.flatten(inputs)
    flat_output_size = nest.flatten(cell.output_size)

    # Construct an initial output
    input_shape = array_ops.shape(flat_input[0])
    time_steps = input_shape[0]
    batch_size = input_shape[1]

    inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
                             for input_ in flat_input)

    const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]

    # Every flattened input must have a statically known depth and must agree
    # with the first input on its static time and batch dimensions.
    for shape in inputs_got_shape:
        if not shape[2:].is_fully_defined():
            raise ValueError(
                "Input size (depth of inputs) must be accessible via shape inference,"
                " but saw value None.")
        got_time_steps = shape[0].value
        got_batch_size = shape[1].value
        if const_time_steps != got_time_steps:
            raise ValueError(
                "Time steps is not the same for all the elements in the input in a "
                "batch.")
        if const_batch_size != got_batch_size:
            raise ValueError(
                "Batch_size is not the same for all the elements in the input.")

    # Prepare dynamic conditional copying of state & output
    def _create_zero_arrays(size):
        size = _state_size_with_prefix(size, prefix=[batch_size])
        return array_ops.zeros(array_ops.stack(size),
                               _infer_state_dtype(dtype, state))

    flat_zero_output = tuple(
        _create_zero_arrays(output) for output in flat_output_size)
    zero_output = nest.pack_sequence_as(structure=cell.output_size,
                                        flat_sequence=flat_zero_output)

    # Only needed (and only defined) when per-example lengths are supplied;
    # the sole consumer is the `_rnn_step` call in `_time_step` below.
    if sequence_length is not None:
        min_sequence_length = math_ops.reduce_min(sequence_length)
        max_sequence_length = math_ops.reduce_max(sequence_length)

    time = array_ops.constant(0, dtype=dtypes.int32, name="time")

    with ops.name_scope("dynamic_rnn") as scope:
        base_name = scope

    def _create_ta(name, dtype):
        return tensor_array_ops.TensorArray(dtype=dtype,
                                            size=time_steps,
                                            tensor_array_name=base_name + name)

    output_ta = tuple(
        _create_ta("output_%d" % i, _infer_state_dtype(dtype, state))
        for i in range(len(flat_output_size)))
    input_ta = tuple(
        _create_ta("input_%d" % i, flat_input[0].dtype)
        for i in range(len(flat_input)))

    input_ta = tuple(
        ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input))

    def _time_step(time, output_ta_t, state):
        """Take a time step of the dynamic RNN.

        Args:
          time: int32 scalar Tensor.
          output_ta_t: List of `TensorArray`s that represent the output.
          state: nested tuple of vector tensors that represent the state.

        Returns:
          The tuple (time + 1, output_ta_t with updated flow, new_state).
        """
        input_t = tuple(ta.read(time) for ta in input_ta)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            # A single time slice drops the leading (time) dimension.
            input_.set_shape(shape[1:])

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)
        call_cell = lambda: cell(input_t, state)

        if sequence_length is not None:
            (output,
             new_state) = _rnn_step(time=time,
                                    sequence_length=sequence_length,
                                    min_sequence_length=min_sequence_length,
                                    max_sequence_length=max_sequence_length,
                                    zero_output=zero_output,
                                    state=state,
                                    call_cell=call_cell,
                                    state_size=state_size,
                                    skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)

        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, output))

        return (time + 1, output_ta_t, new_state)

    _, output_final_ta, final_state = control_flow_ops.while_loop(
        cond=lambda time, *_: time < time_steps,
        body=_time_step,
        loop_vars=(time, output_ta, state),
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # Unpack final output if not using output tuples.
    final_outputs = tuple(ta.stack() for ta in output_final_ta)

    # Restore some shape information
    for output, output_size in zip(final_outputs, flat_output_size):
        shape = _state_size_with_prefix(
            output_size, prefix=[const_time_steps, const_batch_size])
        output.set_shape(shape)

    final_outputs = nest.pack_sequence_as(structure=cell.output_size,
                                          flat_sequence=final_outputs)

    return (final_outputs, final_state)
    def _testScopedExport(self, test_dir, exported_filename, ckpt_filename):
        """Builds a small inference graph and exports its "hidden1" scope.

        Args:
          test_dir: Directory the exported meta graph file is written into.
          exported_filename: File name, relative to `test_dir`, of the export.
          ckpt_filename: Not used by this method's body; kept so existing
            callers keep working.

        Returns:
          The meta graph returned by `meta_graph.export_scoped_meta_graph`.
        """
        graph = tf.Graph()
        with graph.as_default():
            # Creates an inference graph.
            # Hidden 1
            colocate_constraint = tf.constant(1.2, name="constraint")
            images = tf.constant(1.2,
                                 dtype=tf.float32,
                                 shape=[100, 28],
                                 name="images")
            with tf.name_scope("hidden1"):
                with graph.colocate_with(colocate_constraint.op):
                    weights1 = tf.Variable(tf.truncated_normal(
                        [28, 128], stddev=1.0 / math.sqrt(float(28))),
                                           name="weights")
                # The use of control_flow_ops.cond here is purely for adding test
                # coverage the save and restore of control flow context (which doesn't
                # make any sense here from a machine learning perspective).  The typical
                # biases is a simple Variable without the conditions.
                biases1 = tf.Variable(control_flow_ops.cond(
                    tf.less(random.random(), 0.5), lambda: tf.ones([128]),
                    lambda: tf.zeros([128])),
                                      name="biases")
                hidden1 = tf.nn.relu(tf.matmul(images, weights1) + biases1)

            # Hidden 2
            with tf.name_scope("hidden2"):
                weights2 = tf.Variable(tf.truncated_normal(
                    [128, 32], stddev=1.0 / math.sqrt(float(128))),
                                       name="weights")

                # The use of control_flow_ops.while_loop here is purely for adding test
                # coverage the save and restore of control flow context (which doesn't
                # make any sense here from a machine learning perspective).  The typical
                # biases is a simple Variable without the conditions.
                def loop_cond(it, _):
                    return it < 2

                def loop_body(it, biases2):
                    biases2 += tf.constant(0.1, shape=[32])
                    return it + 1, biases2

                # Loop variables: an iteration counter and the bias
                # accumulator, whose shape matches the [32] increment above.
                _, biases2 = control_flow_ops.while_loop(
                    loop_cond, loop_body,
                    [tf.constant(0), tf.Variable(tf.zeros([32]))])
                hidden2 = tf.nn.relu(tf.matmul(hidden1, weights2) + biases2)
            # Linear
            with tf.name_scope("softmax_linear"):
                weights3 = tf.Variable(tf.truncated_normal(
                    [32, 10], stddev=1.0 / math.sqrt(float(32))),
                                       name="weights")
                biases3 = tf.Variable(tf.zeros([10]), name="biases")
                logits = tf.matmul(hidden2, weights3) + biases3
                tf.add_to_collection("logits", logits)
            # Export only the "hidden1" scope; the returned var_list is keyed
            # by the scope-stripped variable names checked below.
            orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
                filename=os.path.join(test_dir, exported_filename),
                graph=tf.get_default_graph(),
                export_scope="hidden1")
            self.assertEqual(["biases:0", "weights:0"],
                             sorted(var_list.keys()))
            var_names = [v.name for v in var_list.values()]
            self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
                             sorted(var_names))

        return orig_meta_graph
  def testDumpToFileWhileLoop(self):
    with session.Session() as sess:
      num_iter = 10

      # "u" is the Variable being updated in the loop.
      u_name = "testDumpToFileWhileLoop/u"
      u_namespace = u_name.split("/")[0]

      u_init_val = np.array(11.0)
      u_init = constant_op.constant(u_init_val)
      u = variables.Variable(u_init, name=u_name)

      # "v" is the increment.
      v_name = "testDumpToFileWhileLoop/v"
      v_namespace = v_name.split("/")[0]

      v_init_val = np.array(2.0)
      v_init = constant_op.constant(v_init_val)
      v = variables.Variable(v_init, name=v_name)

      i = constant_op.constant(0, name="testDumpToFileWhileLoop/i")

      def cond(i):
        return math_ops.less(i, num_iter)

      def body(i):
        new_u = state_ops.assign_add(u, v)
        new_i = math_ops.add(i, 1)
        op =
        new_i = control_flow_ops.with_dependencies([op], new_i)
        return [new_i]

      loop = control_flow_ops.while_loop(cond, body, [i], parallel_iterations=1)

      # Create RunOptions for debug-watching tensors
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_url = "file://%s" % self._dump_root

      # Add debug tensor watch for u.
      self._addDebugTensorWatch(run_options, u_name, 0, debug_urls=[debug_url])
      # Add debug tensor watch for v.
          run_options, "%s/read" % v_name, 0, debug_urls=[debug_url])
      # Add debug tensor watch for while/Identity.
          run_options, "while/Identity", 0, debug_urls=[debug_url])
      # Add debug tensor watch for while/Add/y.
          run_options, "while/Add/y", 0, debug_urls=[debug_url])

      run_metadata = config_pb2.RunMetadata()
      r =, options=run_options, run_metadata=run_metadata)


      self.assertEqual(num_iter, r)

      u_val_final =
      self.assertAllClose(u_init_val + num_iter * v_init_val, u_val_final)

      # Verify dump files

      self.assertTrue(os.path.isdir(os.path.join(self._dump_root, u_namespace)))
          os.path.isdir(os.path.join(self._dump_root, v_namespace, "v")))

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)

      # Expected dumped tensors: u, v/read, 10 iterations of while/Identity,
      # and 10 iterations of while/Add/y.
      self.assertEqual(1 + 1 + num_iter + num_iter, dump.size)

      # Verify tensor values.
      self.assertAllClose([u_init_val], dump.get_tensors(u_name, 0,
      self.assertAllClose([v_init_val], dump.get_tensors("%s/read" % v_name, 0,

      while_id_tensors = dump.get_tensors("while/Identity", 0, "DebugIdentity")
      self.assertEqual(10, len(while_id_tensors))
      for k in xrange(len(while_id_tensors)):
        self.assertAllClose(np.array(k), while_id_tensors[k])

      # Verify ascending timestamps from the while loops.
      while_id_rel_timestamps = dump.get_rel_timestamps("while/Identity", 0,
      self.assertEqual(10, len(while_id_rel_timestamps))
      prev_rel_time = 0
      for rel_time in while_id_rel_timestamps:
        self.assertGreaterEqual(rel_time, prev_rel_time)
        prev_rel_time = rel_time