Example #1
def _check_same_outputs(true_graph, false_graph):
  """Raises an error if true_graph and false_graph have different outputs."""

  def error(error_detail):
    raise TypeError(
        "true_fn and false_fn arguments to tf.cond must have the same number, "
        "type, and overall structure of return values.\n"
        "\n"
        "true_fn output:  %s\n"
        "false_fn output: %s\n"
        "\n"
        "Error details:\n"
        "%s" % (true_graph.structured_outputs, false_graph.structured_outputs,
                error_detail))

  try:
    nest.assert_same_structure(true_graph.structured_outputs,
                               false_graph.structured_outputs,
                               expand_composites=True)
  except (ValueError, TypeError) as e:
    error(str(e))

  assert len(true_graph.outputs) == len(false_graph.outputs)
  for true_out, false_out in zip(true_graph.outputs, false_graph.outputs):
    if true_out.dtype != false_out.dtype:
      error("%s and %s have different types" % (true_out, false_out))
Example #2
    def compute(i, a_flat, tas):
      """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.

      Returns:
        (next_i, flat_a_out, tas): the updated counter (i - 1 when `reverse`
          is set, else i + 1) + new accumulator values + updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
      """
      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_a = output_pack(a_flat)
      a_out = fn(packed_a, packed_elems)
      nest.assert_same_structure(
          elems if initializer is None else initializer, a_out)
      flat_a_out = output_flatten(a_out)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
      if reverse:
        next_i = i - 1
      else:
        next_i = i + 1
      return (next_i, flat_a_out, tas)
Example #3
    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(
          *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[len_orig_loop_vars:])
Example #4
    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs)
Example #5
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    else:
      if arg != expected:
        return False
  return True
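
The per-argument check in the loop above relies on TensorSpec dtype equality plus shape compatibility. A small sketch of that test in isolation, with a made-up spec and argument:

import tensorflow as tf

expected = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
arg = tf.zeros([2, 3])
compatible = (arg.dtype == expected.dtype and
              expected.shape.is_compatible_with(arg.shape))
print(compatible)  # True: dtypes match and [2, 3] fits [None, 3]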
Example #6
  def testMapStructure(self):
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = (((7, 8), 9), 10, (11, 12))
    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual(
        [2, 3, 4, 5, 6, 7],
        nest.flatten(structure1_plus1))
    structure1_plus_structure2 = nest.map_structure(
        lambda x, y: x + y, structure1, structure2)
    self.assertEqual(
        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
        structure1_plus_structure2)

    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

    self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

    with self.assertRaisesRegexp(TypeError, "callable"):
      nest.map_structure("bad", structure1_plus1)

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, 3, (3,))

    with self.assertRaisesRegexp(TypeError, "same sequence type"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
Example #7
  def __init__(self, initial_state, mask=None, name="trainable_initial_state"):
    """Constructs the Module that introduces a trainable state in the graph.

    It receives an initial state that will be used as the initial values for the
    trainable variables that the module contains, and optionally a mask that
    indicates the parts of the initial state that should be learnable.

    Args:
      initial_state: tensor or arbitrarily nested iterables of tensors.
      mask: optional boolean mask. It should have the same nested structure as
       the given initial_state.
      name: module name.

    Raises:
      TypeError: if mask is not a list of booleans or None.
    """
    super(TrainableInitialState, self).__init__(name=name)

    # Since python 2.7, DeprecationWarning is ignored by default.
    # Turn on the warning:
    warnings.simplefilter("always", DeprecationWarning)
    warnings.warn("Use the trainable flag in initial_state instead.",
                  DeprecationWarning, stacklevel=2)

    if mask is not None:
      flat_mask = nest.flatten(mask)
      if not all([isinstance(m, bool) for m in flat_mask]):
        raise TypeError("Mask should be None or a list of boolean values.")
      nest.assert_same_structure(initial_state, mask)

    self._mask = mask
    self._initial_state = initial_state
Example #8
        def body(time, elements_finished, current_input, emit_ta, state, loop_state):
            """Internal while loop body for raw_rnn.

            Args:
              time: time scalar.
              elements_finished: batch-size vector.
              current_input: possibly nested tuple of input tensors.
              emit_ta: possibly nested tuple of output TensorArrays.
              state: possibly nested tuple of state tensors.
              loop_state: possibly nested tuple of loop state tensors.

            Returns:
              Tuple having the same size as Args but with updated values.
            """
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output, next_loop_state) = loop_fn(
                next_time, next_output, cell_state, loop_state
            )

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                # pylint: disable=g-long-lambda,cell-var-from-loop
                result_flat = [
                    _on_device(
                        lambda: array_ops.where(elements_finished, current_i, candidate_i), device=candidate_i.op.device
                    )
                    for (current_i, candidate_i) in zip(current_flat, candidate_flat)
                ]
                # pylint: enable=g-long-lambda,cell-var-from-loop
                return nest.pack_sequence_as(structure=current, flat_sequence=result_flat)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            emit_ta_flat = [ta.write(time, emit) for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]

            emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=emit_ta_flat)

            return (next_time, elements_finished, next_input, emit_ta, next_state, loop_state)
Example #9
def _is_flat(sequence):
  sequence_flat = nest.flatten(sequence)
  try:
    nest.assert_same_structure(sequence_flat, sequence)
    return True
  except ValueError:
    return False
  except TypeError:
    return False
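
In other words, a structure is "flat" when flattening it is a no-op. A self-contained version using the public tf.nest API, with a few illustrative calls (note that nest.flatten always returns a list, so only lists can compare structurally equal to their own flattening):

import tensorflow as tf

def is_flat(sequence):
    flat = tf.nest.flatten(sequence)
    try:
        tf.nest.assert_same_structure(flat, sequence)
        return True
    except (ValueError, TypeError):
        return False

print(is_flat([1, 2, 3]))   # True: flattening a flat list changes nothing
print(is_flat([[1], 2]))    # False: the inner list is removed
print(is_flat((1, 2)))      # False: flatten() yields a list, not a tuple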
Example #10
  def testNestAssertSameStructureCompositeMismatch(self, s1, s2,
                                                   error=ValueError):
    # s1 and s2 have the same structure if expand_composites=False; but
    # different structures if expand_composites=True.
    nest.assert_same_structure(s1, s2, expand_composites=False)
    nest.assert_shallow_structure(s1, s2, expand_composites=False)
    with self.assertRaises(error):  # pylint: disable=g-error-prone-assert-raises
      nest.assert_same_structure(s1, s2, expand_composites=True)
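
The expand_composites flag controls whether composite tensors such as SparseTensor are compared as opaque leaves or expanded into their components. A hedged sketch of the passing cases only (two SparseTensors of the same dtype match under both settings, even with different numbers of stored values):

import tensorflow as tf

st1 = tf.sparse.SparseTensor([[0, 0]], [1.0], [4, 4])
st2 = tf.sparse.SparseTensor([[1, 2], [2, 3]], [1.0, 2.0], [4, 4])
tf.nest.assert_same_structure(st1, st2, expand_composites=False)
tf.nest.assert_same_structure(st1, st2, expand_composites=True)
print("both checks passed")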
Example #11
  def _assert_correct_outputs(self, initial_state_):
    nest.assert_same_structure(initial_state_, self.decoder_cell.state_size)
    nest.assert_same_structure(initial_state_, self.encoder_outputs.final_state)

    encoder_state_flat = nest.flatten(self.encoder_outputs.final_state)
    with self.test_session() as sess:
      encoder_state_flat_ = sess.run(encoder_state_flat)

    initial_state_flat_ = nest.flatten(initial_state_)
    for e_dec, e_enc in zip(initial_state_flat_, encoder_state_flat_):
      np.testing.assert_array_equal(e_dec, e_enc)
Example #12
  def insert(self, keys, values):
    nest.assert_same_structure(self._hash_tables, values)
    # Avoid race conditions by requiring that all inputs are computed before
    # any inserts happen (an issue if one key's update relies on another's
    # value).
    values_flat = [array_ops.identity(value) for value in nest.flatten(values)]
    with ops.control_dependencies(values_flat):
      insert_ops = [hash_table.insert(keys, value)
                    for hash_table, value
                    in zip(nest.flatten(self._hash_tables), values_flat)]
    return control_flow_ops.group(*insert_ops)
Example #13
  def run_and_report(self, s1, s2, name):
    burn_iter, test_iter = 100, 30000

    for _ in xrange(burn_iter):
      nest.assert_same_structure(s1, s2)

    t0 = time.time()
    for _ in xrange(test_iter):
      nest.assert_same_structure(s1, s2)
    t1 = time.time()

    self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
                          name=name)
Example #14
  def _maybe_copy_some_through():
    """Run RNN step.  Pass through either no or some past state."""
    new_output, new_state = call_cell()

    nest.assert_same_structure(state, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)
    return control_flow_ops.cond(
        # if t < min_seq_len: calculate and return everything
        time < min_sequence_length, lambda: flat_new_output + flat_new_state,
        # else copy some of it through
        lambda: _copy_some_through(flat_new_output, flat_new_state))
Example #15
    def body(time, elements_finished, current_input,
             emit_ta, state, loop_state):
      """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
      (next_output, cell_state) = cell(current_input, state)

      nest.assert_same_structure(state, cell_state)
      nest.assert_same_structure(cell.output_size, next_output)

      next_time = time + 1
      (next_finished, next_input, next_state, emit_output,
       next_loop_state) = loop_fn(
           next_time, next_output, cell_state, loop_state)

      nest.assert_same_structure(state, next_state)
      nest.assert_same_structure(current_input, next_input)
      nest.assert_same_structure(emit_ta, emit_output)

      # If loop_fn returns None for next_loop_state, just reuse the
      # previous one.
      loop_state = loop_state if next_loop_state is None else next_loop_state

      def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""
        def copy_fn(cur_i, cand_i):
          return _on_device(
              lambda: array_ops.where(elements_finished, cur_i, cand_i),
              device=cand_i.op.device)
        return nest.map_structure(copy_fn, current, candidate)

      emit_output = _copy_some_through(zero_emit, emit_output)
      next_state = _copy_some_through(state, next_state)

      emit_ta = nest.map_structure(
          lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)

      elements_finished = math_ops.logical_or(elements_finished, next_finished)

      return (next_time, elements_finished, next_input,
              emit_ta, next_state, loop_state)
Example #16
def check_mutation(n1, n2):
  """Check if two list of arguments are exactly the same."""
  errmsg = ("Function to be traced should not modify structure of input "
            "arguments. Check if your function has list and dictionary "
            "operations that alter input arguments, "
            "such as `list.pop`, `list.append`")
  try:
    nest.assert_same_structure(n1, n2)
  except ValueError:
    raise ValueError(errmsg)

  for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)):
    if arg1 is not arg2:
      raise ValueError(errmsg)
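
A toy demonstration of what check_mutation catches, using the public tf.nest API: appending to a list argument changes its nested structure, which assert_same_structure detects.

import tensorflow as tf

n1 = [1, (2, 3)]
n2 = list(n1)   # snapshot of the argument structure before tracing
n1.append(4)    # simulates a `list.append` inside the traced function
try:
    tf.nest.assert_same_structure(n1, n2)
except ValueError as e:
    print("caught structural mutation:", e)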
Example #17
  def testInitialStateComputation(self, tuple_state, mask):
    if tuple_state:
      initial_state = (tf.fill([BATCH_SIZE, 6], 2),
                       (tf.fill([BATCH_SIZE, 7], 3),
                        tf.fill([BATCH_SIZE, 8], 4)))
    else:
      initial_state = tf.fill([BATCH_SIZE, 9], 10)

    trainable_state_module = snt.TrainableInitialState(initial_state, mask=mask)
    trainable_state = trainable_state_module()
    flat_trainable_state = nest.flatten(trainable_state)
    nest.assert_same_structure(initial_state, trainable_state)
    flat_initial_state = nest.flatten(initial_state)
    if mask is not None:
      flat_mask = nest.flatten(mask)
    else:
      flat_mask = (True,) * len(flat_initial_state)

    self.evaluate(tf.global_variables_initializer())

    # Check that all variables are initialized correctly and that the module
    # returns a state with the same values as the ones provided.
    for trainable_state, initial_state in zip(flat_trainable_state,
                                              flat_initial_state):
      self.assertAllEqual(
          self.evaluate(trainable_state), self.evaluate(initial_state))

    # Change the value of all the trainable variables to ones.
    for variable in tf.trainable_variables():
      self.evaluate(tf.assign(variable, tf.ones_like(variable)))

    # In eager mode to re-evaluate the module we must re-connect it.
    trainable_state = trainable_state_module()
    flat_trainable_state = nest.flatten(trainable_state)

    # Check that the values of the initial_states have changed if and only if
    # they are trainable.
    for trainable_state, initial_state, mask in zip(flat_trainable_state,
                                                    flat_initial_state,
                                                    flat_mask):
      trainable_state_value = self.evaluate(trainable_state)
      initial_state_value = self.evaluate(initial_state)
      if mask:
        expected_value = np.ones_like(initial_state_value)
      else:
        expected_value = initial_state_value

      self.assertAllEqual(trainable_state_value, expected_value)
Example #18
  def test_convert_to_generator_like(self, input_fn, inputs):
    expected_batches = 5
    data = input_fn(self, inputs, expected_batches)

    # Dataset and Iterator not supported in Legacy Graph mode.
    if (not context.executing_eagerly() and
        isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
      return

    generator, steps = training_generator.convert_to_generator_like(
        data, batch_size=2, steps_per_epoch=expected_batches)
    self.assertEqual(steps, expected_batches)

    for _ in range(expected_batches):
      outputs = next(generator)
    nest.assert_same_structure(outputs, inputs)
Example #19
  def testNestAssertSameStructure(self):
    st1 = sparse_tensor.SparseTensor([[0]], [0], [100])
    st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100])
    test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape)
    nest.assert_same_structure(st1, st2, expand_composites=False)
    nest.assert_same_structure(st1, st2, expand_composites=True)
    nest.assert_same_structure(st1, test, expand_composites=False)
    with self.assertRaises(TypeError):
      nest.assert_same_structure(st1, test, expand_composites=True)
Example #20
  def testMapStructure(self):
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = (((7, 8), 9), 10, (11, 12))
    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual(
        [2, 3, 4, 5, 6, 7],
        nest.flatten(structure1_plus1))
    structure1_plus_structure2 = nest.map_structure(
        lambda x, y: x + y, structure1, structure2)
    self.assertEqual(
        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
        structure1_plus_structure2)

    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

    self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

    with self.assertRaisesRegexp(TypeError, "callable"):
      nest.map_structure("bad", structure1_plus1)

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, 3, (3,))

    with self.assertRaisesRegexp(TypeError, "same sequence type"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

    structure1_list = [[[1, 2], 3], 4, [5, 6]]
    with self.assertRaisesRegexp(TypeError, "same sequence type"):
      nest.map_structure(lambda x, y: None, structure1, structure1_list)

    nest.map_structure(lambda x, y: None, structure1, structure1_list,
                       check_types=False)

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                         check_types=False)

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, foo="a")

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
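
The check_types behavior exercised above can be reproduced directly with the public tf.nest API: mismatched sequence types pass when check_types=False, and the result is packed using the first structure's types. Mismatched nesting still fails regardless of the flag.

import tensorflow as tf

result = tf.nest.map_structure(lambda x, y: x + y,
                               ((1, 2), 3), [[4, 5], 6],
                               check_types=False)
print(result)  # ((5, 7), 9)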
Example #21
  def __call__(self, *args):
    nest.assert_same_structure(self.shape_and_dtypes, args, check_types=False)
    if not all([
        shape.is_compatible_with(arg.shape)
        for shape, arg in zip(self.flattened_shapes, nest.flatten(args))
    ]):
      raise ValueError(
          "Declared shapes do not match argument shapes: Expected %s, found %s."
          % (self.flattened_shapes, [arg.shape for arg in nest.flatten(args)]))

    initialized = [resource_variable_ops.var_is_initialized_op(
        v.handle).numpy() for v in self._call_fn.variables]
    if all(x for x in initialized):
      return self._call_fn(*args)
    elif all(not x for x in initialized):
      return self._init_fn(*args)
    else:
      raise ValueError("Some, but not all, variables are initialized.")
Example #22
  def testMapStructureWithStrings(self):
    inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
    inp_b = NestTest.ABTuple(a=2, b=(1, 3))
    out = nest.map_structure(lambda string, repeats: string * repeats,
                             inp_a,
                             inp_b)
    self.assertEqual("foofoo", out.a)
    self.assertEqual("bar", out.b[0])
    self.assertEqual("bazbazbaz", out.b[1])

    nt = NestTest.ABTuple(a=("something", "something_else"),
                          b="yet another thing")
    rev_nt = nest.map_structure(lambda x: x[::-1], nt)
    # Check the output is the correct structure, and all strings are reversed.
    nest.assert_same_structure(nt, rev_nt)
    self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
    self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
    self.assertEqual(nt.b[::-1], rev_nt.b)
Example #23
  def _verify_structure_shapes_types(self, left, right):
    """Verify that the structure, shapes and types of left are same as right."""
    nest.assert_same_structure(left, right)
    flat_left = nest.flatten(left)
    flat_right = nest.flatten(right)
    assert len(flat_left) == len(flat_right), (
        "Length of left {} and right {} should be same.".
        format(len(flat_left), len(flat_right)))

    for o, i in zip(flat_left, flat_right):
      # TODO(priyag): Add checks for other types like IndexedSlices.
      if isinstance(o, ops.Tensor):
        assert isinstance(i, ops.Tensor)
        assert o.shape == i.shape, (
            "Shape {} of left {} doesn't match shape {} of right {}.".
            format(o.shape, o, i.shape, i))
        assert o.dtype == i.dtype, (
            "Dtype {} of left {} doesn't match dtype {} of right {}.".
            format(o.dtype, o, i.dtype, i))
Example #24
def _check_same_outputs(true_graph, false_graph):
  """Raises an error if true_graph and false_graph have different outputs."""
  true_output_types = [t.dtype for t in true_graph.outputs]
  false_output_types = [t.dtype for t in false_graph.outputs]
  if (len(true_graph.outputs) != len(false_graph.outputs) or
      true_output_types != false_output_types):
    raise TypeError(
        "true_fn() and false_fn() must return the same number and type of "
        "arguments, got:\n"
        "  true_fn: %s\n"
        "  false_fn: %s" % (true_output_types, false_output_types))

  # Make sure `structured_outputs` for both graphs have the same structure.
  try:
    nest.assert_same_structure(true_graph.structured_outputs,
                               false_graph.structured_outputs)
  except (ValueError, TypeError) as e:
    raise ValueError("Outputs of true_fn and false_fn must have the same "
                     "structure: %s" % str(e))
Example #25
  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    output, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, output)
    res_output = nest.map_structure(
        lambda inp, out: inp + out, inputs, output)
    return (res_output, new_state)
Example #26
    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_fn_values = fn(packed_values)
      nest.assert_same_structure(dtype or elems, packed_fn_values)
      flat_fn_values = output_flatten(packed_fn_values)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]
      return (i + 1, tas)
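
The structure contract enforced here is that fn's output must match the declared dtype structure (or the structure of elems when dtype is None). A small sketch with the public tf.nest API and made-up structures:

import tensorflow as tf

dtype = (tf.int32, tf.float32)                # declared output structure
fn_out = (tf.constant(1), tf.constant(2.0))   # a matching fn() result
tf.nest.assert_same_structure(dtype, fn_out)  # passes: both are pairs

bad_out = tf.constant(1)                      # a scalar instead of a pair
try:
    tf.nest.assert_same_structure(dtype, bad_out)
except ValueError as e:
    print("fn output does not match dtype:", e)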
Example #27
def assert_state_is_compatible(expected_state, state):
    """Asserts that states are compatible.

    Args:
        expected_state: The reference state.
        state: The state that must be compatible with :obj:`expected_state`.

    Raises:
      ValueError: if the states are incompatible.
    """
    # Check structure compatibility.
    nest.assert_same_structure(expected_state, state)

    # Check shape compatibility.
    expected_state_flat = nest.flatten(expected_state)
    state_flat = nest.flatten(state)

    for x, y in zip(expected_state_flat, state_flat):
        if tensor_util.is_tensor(x):
            with_same_shape(x, y)
Example #28
def gnmt_residual_fn(inputs, outputs):
  """Residual function that handles different inputs and outputs inner dims.

  Args:
    inputs: cell inputs, this is actual inputs concatenated with the attention
      vector.
    outputs: cell outputs

  Returns:
    outputs + actual inputs
  """
  def split_input(inp, out):
    out_dim = out.get_shape().as_list()[-1]
    inp_dim = inp.get_shape().as_list()[-1]
    return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)
  actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)
  def assert_shape_match(inp, out):
    inp.get_shape().assert_is_compatible_with(out.get_shape())
  nest.assert_same_structure(actual_inputs, outputs)
  nest.map_structure(assert_shape_match, actual_inputs, outputs)
  return nest.map_structure(lambda inp, out: inp + out, actual_inputs, outputs)
Example #29
    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      next_finished = math_ops.logical_or(decoder_finished, finished)
      if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)
      next_sequence_lengths = array_ops.where(
          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)
Example #30
  def body(time, outputs_ta, state, inputs, finished):
    """Internal while_loop body.

    Args:
      time: scalar int32 tensor.
      outputs_ta: structure of TensorArray.
      state: (structure of) state tensors and TensorArrays.
      inputs: (structure of) input tensors.
      finished: 1-D bool tensor.

    Returns:
      `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
    """
    (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
        time, inputs, state)
    next_finished = math_ops.logical_or(decoder_finished, finished)

    nest.assert_same_structure(state, decoder_state)
    nest.assert_same_structure(outputs_ta, next_outputs)
    nest.assert_same_structure(inputs, next_inputs)

    # Zero out output values past finish
    emit = nest.map_structure(
        lambda out, zero: array_ops.where(finished, zero, out), next_outputs,
        zero_outputs)

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
      return (new if isinstance(cur, tensor_array_ops.TensorArray) else
              array_ops.where(finished, cur, new))

    next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    outputs_ta, emit)
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
Example #31
  def testInitialStateTuple(self, trainable, use_custom_initial_value,
                            state_size):
    batch_size = 6

    # Set the attribute on the class since we can't set properties of
    # abstract classes.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    core = snt.RNNCore(name="dummy_core")
    if use_custom_initial_value:
      flat_initializer = [tf.constant_initializer(2)] * len(flat_state_size)
      trainable_initializers = nest.pack_sequence_as(
          structure=state_size, flat_sequence=flat_initializer)
    else:
      trainable_initializers = None
    initial_state = core.initial_state(
        batch_size, dtype=tf.float32, trainable=trainable,
        trainable_initializers=trainable_initializers)

    nest.assert_same_structure(initial_state, state_size)
    flat_initial_state = nest.flatten(initial_state)

    for state, size in zip(flat_initial_state, flat_state_size):
      self.assertEqual(state.get_shape(), [batch_size, size])

    with self.test_session() as sess:
      tf.global_variables_initializer().run()
      flat_initial_state_value = sess.run(flat_initial_state)
      for value, size in zip(flat_initial_state_value, flat_state_size):
        expected_initial_state = np.empty([batch_size, size])
        if not trainable:
          expected_initial_state.fill(0)
        elif use_custom_initial_value:
          expected_initial_state.fill(2)
        else:
          value_row = value[0]
          expected_initial_state = np.tile(value_row, (batch_size, 1))
        self.assertAllClose(value, expected_initial_state)
Example #32
    def testMapStructureOverPlaceholders(self):
        inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
                 array_ops.placeholder(dtypes.float32, shape=[3, 7]))
        inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
                 array_ops.placeholder(dtypes.float32, shape=[3, 7]))

        output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)

        nest.assert_same_structure(output, inp_a)
        self.assertShapeEqual(np.zeros((3, 4)), output[0])
        self.assertShapeEqual(np.zeros((3, 7)), output[1])

        feed_dict = {
            inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
            inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
        }

        with self.cached_session() as sess:
            output_np = sess.run(output, feed_dict=feed_dict)
        self.assertAllClose(output_np[0],
                            feed_dict[inp_a][0] + feed_dict[inp_b][0])
        self.assertAllClose(output_np[1],
                            feed_dict[inp_a][1] + feed_dict[inp_b][1])
Example #33
    def _build(self, inputs):
        """Passes inputs to the initial states of decoder.

        :attr:`inputs` must either have the same structure, or the same number
        of elements with the decoder state.

        Args:
            inputs: The input (structure of) tensors to pass forward.

        Returns:
            The input (structure of) tensors that might be re-packed to have
            the same structure with decoder state.
        """
        output = inputs
        try:
            nest.assert_same_structure(inputs, self._output_size)
        except (ValueError, TypeError):
            flat_input = nest.flatten(inputs)
            output = nest.pack_sequence_as(self._output_size, flat_input)

        self._built = True

        return output
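
The fallback branch above repacks the input leaves into the target structure. A minimal sketch of that idiom with the public tf.nest API, where output_size and inputs are illustrative stand-ins for self._output_size and the module input:

import tensorflow as tf

output_size = (4, (5, 6))    # target structure
inputs = [1, 2, 3]           # same leaf count, different structure
try:
    tf.nest.assert_same_structure(inputs, output_size)
except (ValueError, TypeError):
    inputs = tf.nest.pack_sequence_as(output_size, tf.nest.flatten(inputs))
print(inputs)  # (1, (2, 3))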
Example #34
        def compute(i, tas):
            """The loop body of map_fn.

            Args:
              i: the loop counter
              tas: the flat TensorArray accumulator list

            Returns:
              (i + 1, tas): the updated counter + updated TensorArrays

            Raises:
              TypeError: if dtype and packed_fn_values structure do not match
              ValueError: if dtype and packed_fn_values lengths do not match
            """
            packed_values = input_pack(
                [elem_ta.read(i) for elem_ta in elems_ta])
            packed_fn_values = fn(packed_values)
            nest.assert_same_structure(dtype or elems, packed_fn_values)
            flat_fn_values = output_flatten(packed_fn_values)
            tas = [
                ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)
            ]
            return (i + 1, tas)
Example #35
def gnmt_residual_fn(inputs, outputs):
    """Residual function that handles different inputs and outputs inner dims.
    Args:
    inputs: cell inputs, this is actual inputs concatenated with the attention
        vector. refer to GNMT RNN cell
    outputs: cell outputs
    Returns:
    outputs + actual inputs
    """
    def split_input(inp, out):
        out_dim = out.get_shape().as_list()[-1]
        inp_dim = inp.get_shape().as_list()[-1]
        return tf.split(inp, [out_dim, inp_dim - out_dim], axis=-1)

    actual_inputs, _ = nest.map_structure(split_input, inputs, outputs)

    def assert_shape_match(inp, out):
        inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.assert_same_structure(actual_inputs, outputs)
    nest.map_structure(assert_shape_match, actual_inputs, outputs)
    return nest.map_structure(lambda inp, out: inp + out, actual_inputs,
                              outputs)
Example #36
    def __call__(self, inputs, state, scope=None):
        """Run the cell and add its inputs to its outputs.
        Args:
          inputs: cell inputs.
          state: cell state.
          scope: optional cell scope.

        Returns:
          Tuple of cell outputs and new state.

        Raises:
          TypeError: If cell inputs and outputs have different structure (type).
          ValueError: If cell inputs and outputs have different structure (value).
        """
        outputs, new_state = self._cell(inputs, state, scope=scope)
        nest.assert_same_structure(inputs, outputs)

        # Ensure shapes match
        def assert_shape_match(inp, out):
            inp.get_shape().assert_is_compatible_with(out.get_shape())

        nest.map_structure(assert_shape_match, inputs, outputs)
        res_outputs = nest.map_structure(lambda inp, out: inp + out, inputs,
                                         outputs)
        return (res_outputs, new_state)
Example #37
def handle_partial_sample_weights(outputs,
                                  sample_weights,
                                  sample_weight_modes,
                                  check_all_flat=False):
    """Adds 1.0 as sample weights for the outputs for which there is no weight.

    Args:
      outputs: List of model outputs.
      sample_weights: List of sample weight inputs.
      sample_weight_modes: List of sample weight modes or None.
      check_all_flat: Ensure that inputs are not nested structures. This is not
        a free check, so we may not want to run it eagerly every iteration.

    Returns:
      Tuple of sample weights, one sample weight for every output, and booleans
      describing the raw sample weights.
    """
    any_sample_weight = sample_weights is not None and any(
        w is not None for w in sample_weights)
    partial_sample_weight = any_sample_weight and any(w is None
                                                      for w in sample_weights)

    if not any_sample_weight:
        return None, any_sample_weight, partial_sample_weight

    if not partial_sample_weight:
        return sample_weights, any_sample_weight, partial_sample_weight

    if check_all_flat:
        nest.assert_same_structure(list_to_tuple(sample_weights),
                                   list_to_tuple(nest.flatten(sample_weights)))
        nest.assert_same_structure(list_to_tuple(outputs),
                                   list_to_tuple(nest.flatten(outputs)))
        if sample_weight_modes is not None:
            nest.assert_same_structure(sample_weight_modes,
                                       nest.flatten(sample_weight_modes))

    new_sample_weights = []
    for i, sw in enumerate(sample_weights):
        if sw is None:
            as_numpy = isinstance(outputs[i], np.ndarray)
            output = outputs[i]
            output_shape = output.shape if as_numpy else array_ops.shape(
                output)

            is_temporal = (sample_weight_modes is not None
                           and sample_weight_modes[i] == 'temporal')
            sw_shape = (output_shape[0], output_shape[1]) if is_temporal else (
                output_shape[0], )

            new_sample_weights.append(
                np.ones(sw_shape) if as_numpy else array_ops.ones(sw_shape))

        else:
            new_sample_weights.append(sw)
    return (list_to_tuple(new_sample_weights), any_sample_weight,
            partial_sample_weight)
Example #38
def _concrete_function_callable_with(function, inputs, allow_conversion):
    """Returns whether concrete `function` can be called with `inputs`."""
    expected_structure = function.graph.structured_input_signature
    try:
        flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
    except (TypeError, ValueError):
        return False
    try:
        # Verify that no input elements were dropped during flattening.
        repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
        # TODO(b/129422719): Namedtuple subclasses re-created through
        # saved_model.load don't compare equal in type to the original in
        # assert_same_structure. Fix that and we can take out check_types=False
        # here.
        nest.assert_same_structure(inputs, repacked, check_types=False)
    except (TypeError, ValueError):
        return False

    for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
        if isinstance(expected, tensor_spec.TensorSpec):
            if allow_conversion:
                arg = _try_convert_to_tensor_spec(arg,
                                                  dtype_hint=expected.dtype)
            if not _is_tensor(arg) and not isinstance(arg,
                                                      tensor_spec.TensorSpec):
                return False
            if arg.dtype != expected.dtype:
                return False
            if not expected.shape.is_compatible_with(arg.shape):
                return False
        elif isinstance(expected, type_spec.TypeSpec):
            # Keep checking the remaining arguments rather than returning on
            # the first compatible TypeSpec.
            if not expected.is_compatible_with(arg):
                return False
        elif (_is_tensor(arg)
              and id(arg) != id(expected)) or (not _is_tensor(arg)
                                               and arg != expected):
            return False
    return True
Example #39
    def __init__(self,
                 initial_state,
                 mask=None,
                 name="trainable_initial_state"):
        """Constructs the Module that introduces a trainable state in the graph.

        It receives an initial state that will be used as the initial values
        for the trainable variables that the module contains, and optionally
        a mask that indicates the parts of the initial state that should be
        learnable.

        Args:
          initial_state: tensor or arbitrarily nested iterables of tensors.
          mask: optional boolean mask. It should have the same nested
            structure as the given initial_state.
          name: module name.

        Raises:
          TypeError: if mask is not a list of booleans or None.
        """
        super(TrainableInitialState, self).__init__(name=name)

        # Since python 2.7, DeprecationWarning is ignored by default.
        # Turn on the warning:
        warnings.simplefilter("always", DeprecationWarning)
        warnings.warn("Use the trainable flag in initial_state instead.",
                      DeprecationWarning,
                      stacklevel=2)

        if mask is not None:
            flat_mask = nest.flatten(mask)
            if not all([isinstance(m, bool) for m in flat_mask]):
                raise TypeError(
                    "Mask should be None or a list of boolean values.")
            nest.assert_same_structure(initial_state, mask)

        self._mask = mask
        self._initial_state = initial_state
Example #40
  def test_state_override(self):
    test_start_state = (numpy.array([[2, 3, 4]]),
                        (numpy.array([2]), numpy.array([[3., 5.]])))
    data = {
        feature_keys.FilteringFeatures.TIMES: numpy.arange(5),
        feature_keys.FilteringFeatures.VALUES: numpy.zeros(shape=[5, 3])
    }
    features, _ = input_pipeline.WholeDatasetInputFn(
        input_pipeline.NumpyReader(data))()
    features[feature_keys.FilteringFeatures.STATE_TUPLE] = test_start_state
    stub_model = _StateOverrideModel()
    chainer = state_management.ChainingStateManager()
    stub_model.initialize_graph()
    chainer.initialize_graph(model=stub_model)
    model_outputs = chainer.define_loss(
        model=stub_model, features=features, mode=estimator_lib.ModeKeys.EVAL)
    with train.MonitoredSession() as session:
      end_state = session.run(model_outputs.end_state)
    nest.assert_same_structure(test_start_state, end_state)
    for expected, received in zip(nest.flatten(test_start_state),
                                  nest.flatten(end_state)):
      self.assertAllEqual(expected, received)
Example #41
    def compute(i, a_flat, tas):
      """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.

      Returns:
        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
          updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
      """
      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_a = output_pack(a_flat)
      a_out = fn(packed_a, packed_elems)
      nest.assert_same_structure(
          elems if initializer is None else initializer, a_out)
      flat_a_out = output_flatten(a_out)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
      return (i + 1, flat_a_out, tas)
Example #42
    def testFromFunctionInputSignatureForPerReplicaValues(
            self, distribution, enable_get_next_as_optional, drop_remainder):
        # Create files that produce partial/empty batches at different batch
        # indices. Note that some workers will get empty batches even when
        # drop_remainder=True.
        fname1 = os.path.join(self.get_temp_dir(), "1.txt")
        _create_text_file(fname1, 5)
        fname2 = os.path.join(self.get_temp_dir(), "2.txt")
        _create_text_file(fname2, 9)

        def dataset_fn(input_context):
            dataset = dataset_ops.DatasetV2.from_tensor_slices(
                [fname1, fname2])
            dataset = dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)
            return readers.TextLineDatasetV2(dataset).map(
                string_ops.string_to_number).batch(
                    input_context.get_per_replica_batch_size(4),
                    drop_remainder=drop_remainder)

        distribution.extended.experimental_enable_get_next_as_optional = (
            enable_get_next_as_optional)
        ds = distribution.experimental_distribute_datasets_from_function(
            dataset_fn)
        _check_type_spec_structure(iter(ds))
        element_spec = ds.element_spec
        iter_element_spec = iter(ds).element_spec
        nest.assert_same_structure(element_spec, iter_element_spec)
        self.assertAllEqual(nest.flatten(element_spec),
                            nest.flatten(iter_element_spec))

        @def_function.function(input_signature=[element_spec])
        def process_inputs(inputs):
            distribution.run(lambda inputs: inputs, args=(inputs, ))

        for x in ds:
            process_inputs(x)
Example #43
  def _make_bridging_callable(
      generator_fn, wrap_in_tuple, peek, elements_to_keep,
      partial_sample_weight, sample_weight_modes):
    """Optional compatibility layer between user's data and Dataset."""
    must_prune_nones = (elements_to_keep != len(peek))
    try:
      nest.assert_same_structure(peek, nest._list_to_tuple(peek))  # pylint: disable=protected-access
      must_extract_lists = False
    except TypeError:
      must_extract_lists = True

    # No additional transformations are needed.
    if not (wrap_in_tuple or must_extract_lists or must_prune_nones or
            partial_sample_weight):
      return generator_fn

    def wrapped_generator():
      """Remove Nones and lists before invoking Dataset.from_generator."""
      for batch in generator_fn():
        if wrap_in_tuple:
          batch = (batch,)

        if must_extract_lists:
          batch = nest._list_to_tuple(batch)  # pylint: disable=protected-access

        if must_prune_nones:
          batch = batch[:elements_to_keep]

        if partial_sample_weight:
          sample_weights, _, _ = training_utils.handle_partial_sample_weights(
              batch[1], batch[2], sample_weight_modes, check_all_flat=False)
          batch = batch[:2] + (sample_weights,)

        yield batch

    return wrapped_generator
Example #44
    def _build(self, inputs):
        """Transforms inputs to have the same structure as with
        :attr:`output_size`. Values of the inputs are not changed.

        :attr:`inputs` must either have the same structure, or have the same
        number of elements with :attr:`output_size`.

        Args:
            inputs: The input (structure of) tensor to pass forward.

        Returns:
            A (structure of) tensors that re-packs `inputs` to have
            the specified structure of `output_size`.
        """
        output = inputs
        try:
            nest.assert_same_structure(inputs, self._output_size)
        except (ValueError, TypeError):
            flat_input = nest.flatten(inputs)
            output = nest.pack_sequence_as(self._output_size, flat_input)

        self._built = True

        return output
Example #45
        def body(time, outputs_ta, state, inputs, finished):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: 1-D bool tensor.

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
            """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            next_finished = math_ops.logical_or(decoder_finished, finished)
            if maximum_iterations is not None:
                next_finished = math_ops.logical_or(
                    next_finished, time + 1 >= maximum_iterations)

            # Check that the structures are the same.
            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished)
Example #46
        def body(time, elements_finished, current_input, state, beam_seq,
                 beam_prob, emit_ta, loop_state):
            """Internal while loop body for raw_rnn.

            Args:
              time: time scalar.
              elements_finished: batch-size vector.
              current_input: possibly nested tuple of input tensors.
              emit_ta: possibly nested tuple of output TensorArrays.
              state: possibly nested tuple of state tensors.
              loop_state: possibly nested tuple of loop state tensors.

            Returns:
              Tuple having the same size as Args but with updated values.
            """
            dummy = array_ops.zeros(
                shape=[tf.shape(beam_seq)[0],
                       tf.shape(beam_seq)[1], 20],
                dtype=tf.int32)

            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)
            # cell_output, cell_state, beam_seq, beam_prob, finished, emit_ta,
            # loop_state

            (next_input, elements_finished, next_state, beam_seq, beam_prob,
             emit_ta, next_loop_state) = loop_fn(next_output, cell_state,
                                                 beam_seq, beam_prob,
                                                 elements_finished, emit_ta,
                                                 loop_state)
            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            next_time = time + 1
            return (next_time, elements_finished, next_input, next_state,
                    beam_seq, beam_prob, emit_ta, loop_state)
Example #47
        def body(
            time,
            outputs_ta,
            state,
            inputs,
            finished,
            sequence_lengths
            ):
            next_outputs, next_state, next_inputs, decoder_finished = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = tf.logical_or(decoder_finished, finished)
                # Reshape because: (1) inside the helper, entering a cond
                # merges the value, and (2) a 2-D value comes out at
                # inference time.
                next_finished = tf.reshape(next_finished, [-1])

            next_sequence_lengths = tf.where(
                tf.logical_not(finished),
                x=tf.fill(tf.shape(sequence_lengths), time + 1),
                y=sequence_lengths)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            if impute_finished:
                new_linear = nest.map_structure(
                    lambda out, zero: tf.where(finished, zero, out),
                    next_outputs.linear,
                    tf.zeros_like(next_outputs.linear))
                # _replace returns a new namedtuple, so the result must be
                # reassigned or the zeroed outputs are silently discarded.
                next_outputs = next_outputs._replace(linear=new_linear)

                def _maybe_copy_state(new, cur):
                    if isinstance(cur, tf.TensorArray):
                        pass_through = True
                    else:
                        new.set_shape(cur.shape)
                        pass_through = (new.shape.ndims == 0)
                    return new if pass_through else tf.where(finished, cur, new)

                next_state = nest.map_structure(_maybe_copy_state, next_state, state)

            outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, next_outputs)

            return time + 1, outputs_ta, next_state, next_inputs, next_finished, next_sequence_lengths
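
One pitfall worth calling out in the body above: namedtuple._replace returns a new instance instead of mutating in place, so its result must be assigned back. A minimal standalone demonstration (hypothetical field names):

import collections

Outputs = collections.namedtuple("Outputs", ["linear", "sample_id"])
out = Outputs(linear=1.0, sample_id=7)
out._replace(linear=0.0)        # result discarded; out is unchanged
assert out.linear == 1.0
out = out._replace(linear=0.0)  # reassign to pick up the new value
assert out.linear == 0.0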
Exemplo n.º 48
0
def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
    """Verifies loop variables for consistency."""
    if check_shapes and 'shape_invariants' in opts:
        shape_invariants = opts['shape_invariants']
    else:
        shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

    assert len(symbol_names) == len(shape_invariants)
    assert len(symbol_names) == len(init_vars)
    assert len(symbol_names) == len(iter_entry_vars)
    assert len(symbol_names) == len(iter_exit_vars)

    for i in range(len(symbol_names)):
        name = symbol_names[i]
        init = init_vars[i]
        entry = iter_entry_vars[i]
        exit_ = iter_exit_vars[i]
        invariant = shape_invariants[i]

        try:
            nest.assert_same_structure(init, entry, expand_composites=True)
            nest.assert_same_structure(entry, exit_, expand_composites=True)
        except (ValueError, TypeError) as e:
            raise TypeError(
                "'{}' does not have the same nested structure after one"
                ' iteration.\n\n{}'.format(name, e))
        if invariant is not None:
            try:
                nest.assert_same_structure(init,
                                           invariant,
                                           expand_composites=False)
            except (ValueError, TypeError) as e:
                raise TypeError(
                    "'{}' does not have the same nested structure as its"
                    ' corresponding shape invariant.\n\n{}'.format(name, e))

        nest.map_structure(
            functools.partial(_verify_single_loop_var, name, check_shapes),
            init, entry, exit_, invariant)
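
The verifier hinges on structure equality between loop entry and exit. A standalone sketch of the kind of mismatch it reports (illustrative tensors, not the original call site):

import tensorflow as tf

entry = tf.constant(0.0)
exit_ = (tf.constant(0.0), tf.constant(1.0))  # body turned a tensor into a tuple
try:
    tf.nest.assert_same_structure(entry, exit_, expand_composites=True)
except (ValueError, TypeError) as e:
    print("structure changed after one iteration:", e)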
Exemplo n.º 49
0
 def testSameStructure(self):
     d = {1: "a"}
     nest.assert_same_structure(d, data_structures._DictWrapper(d.copy()))
Exemplo n.º 50
0
 def testSameStructure(self):
     l = [1]
     nest.assert_same_structure(l,
                                data_structures._ListWrapper(copy.copy(l)))
Exemplo n.º 51
0
 def testNestAssertSameStructure(self, s1, s2, expand_composites=True):
     nest.assert_same_structure(s1, s2, expand_composites=expand_composites)
     nest.assert_shallow_structure(s1,
                                   s2,
                                   expand_composites=expand_composites)
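
For context on the expand_composites flag exercised above: it controls whether composite tensors such as SparseTensor are unpacked into their component tensors before the comparison. A small illustrative sketch (assuming TF 2.x behavior):

import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2])
dense = tf.zeros([2, 2])

# Both values are single leaves when composites are not expanded.
tf.nest.assert_same_structure(sp, dense, expand_composites=False)

# Expanded, the SparseTensor decomposes into indices/values/dense_shape
# components, so the structures no longer match and this raises.
try:
    tf.nest.assert_same_structure(sp, dense, expand_composites=True)
except (ValueError, TypeError) as e:
    print(e)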
Exemplo n.º 52
0
    def decoder(self, encoder_state, attn_features, prediction_inputs,
                previous_y):
        """
        :param encoder_state: shape [batch_size, encoder_rnn_depth]
        :param prediction_inputs: features for prediction days, tensor[batch_size, time, input_depth]
        :param previous_y: Last day pageviews, shape [batch_size]
        :param attn_features: Additional features from attention layer, shape [batch, predict_window, readout_depth*n_heads]
        :return: decoder rnn output
        """
        hparams = self.hparams

        def build_cell(idx):
            with tf.variable_scope('decoder_cell',
                                   initializer=self.default_init(idx)):
                cell = rnn.GRUBlockCell(self.hparams.rnn_depth)
                has_dropout = (hparams.decoder_input_dropout[idx] < 1
                               or hparams.decoder_state_dropout[idx] < 1
                               or hparams.decoder_output_dropout[idx] < 1)

                if self.is_train and has_dropout:
                    attn_depth = (attn_features.shape[-1].value
                                  if attn_features is not None else 0)
                    input_size = (attn_depth + prediction_inputs.shape[-1].value + 1
                                  if idx == 0 else self.hparams.rnn_depth)
                    cell = rnn.DropoutWrapper(
                        cell,
                        dtype=tf.float32,
                        input_size=input_size,
                        variational_recurrent=hparams.decoder_variational_dropout[idx],
                        input_keep_prob=hparams.decoder_input_dropout[idx],
                        output_keep_prob=hparams.decoder_output_dropout[idx],
                        state_keep_prob=hparams.decoder_state_dropout[idx],
                        seed=self.seed + idx)
                return cell

        if hparams.decoder_rnn_layers > 1:
            cells = [
                build_cell(idx) for idx in range(hparams.decoder_rnn_layers)
            ]
            cell = rnn.MultiRNNCell(cells)
        else:
            cell = build_cell(0)

        nest.assert_same_structure(encoder_state, cell.state_size)
        predict_days = self.inp.predict_window
        assert prediction_inputs.shape[1] == predict_days

        # [batch_size, time, input_depth] -> [time, batch_size, input_depth]
        inputs_by_time = tf.transpose(prediction_inputs, [1, 0, 2])

        # Return raw outputs for RNN losses calculation
        return_raw_outputs = self.hparams.decoder_stability_loss > 0.0 or self.hparams.decoder_activation_loss > 0.0

        # Stop condition for decoding loop
        def cond_fn(time, prev_output, prev_state,
                    array_targets: tf.TensorArray,
                    array_outputs: tf.TensorArray):
            return time < predict_days

        # FC projecting layer to get single predicted value from RNN output
        def project_output(tensor):
            return tf.layers.dense(tensor,
                                   1,
                                   name='decoder_output_proj',
                                   kernel_initializer=self.default_init())

        def loop_fn(time, prev_output, prev_state,
                    array_targets: tf.TensorArray,
                    array_outputs: tf.TensorArray):
            """
            Main decoder loop
            :param time: Day number
            :param prev_output: Output(prediction) from previous step
            :param prev_state: RNN state tensor from previous step
            :param array_targets: Predictions, each step will append new value to this array
            :param array_outputs: Raw RNN outputs (for regularization losses)
            :return:
            """
            # RNN inputs for current step
            features = inputs_by_time[time]

            # [batch, predict_window, readout_depth * n_heads] -> [batch, readout_depth * n_heads]
            if attn_features is not None:
                #  [batch_size, 1] + [batch_size, input_depth]
                attn = attn_features[:, time, :]
                # Append previous predicted value + attention vector to input features
                next_input = tf.concat([prev_output, features, attn], axis=1)
            else:
                # Append previous predicted value to input features
                next_input = tf.concat([prev_output, features], axis=1)

            # Run RNN cell
            output, state = cell(next_input, prev_state)
            # Make prediction from RNN outputs
            projected_output = project_output(output)
            # Append step results to the buffer arrays
            if return_raw_outputs:
                array_outputs = array_outputs.write(time, output)
            array_targets = array_targets.write(time, projected_output)
            # Increment time and return
            return time + 1, projected_output, state, array_targets, array_outputs

        # Initial values for loop
        loop_init = [
            tf.constant(0, dtype=tf.int32),
            tf.expand_dims(previous_y, -1), encoder_state,
            tf.TensorArray(dtype=tf.float32, size=predict_days),
            tf.TensorArray(dtype=tf.float32, size=predict_days)
            if return_raw_outputs else tf.constant(0)
        ]
        # Run the loop
        _, _, _, targets_ta, outputs_ta = tf.while_loop(
            cond_fn, loop_fn, loop_init)

        # Get final tensors from buffer arrays
        targets = targets_ta.stack()
        # [time, batch_size, 1] -> [time, batch_size]
        targets = tf.squeeze(targets, axis=-1)
        raw_outputs = outputs_ta.stack() if return_raw_outputs else None
        return targets, raw_outputs
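
The decoding loop above follows a common TensorFlow pattern: accumulate per-step results in a TensorArray inside tf.while_loop, then stack once after the loop. A minimal self-contained sketch of just that pattern (simplified, not the decoder itself):

import tensorflow as tf

def cond_fn(time, ta):
    return time < 5

def loop_fn(time, ta):
    ta = ta.write(time, tf.cast(time, tf.float32) * 2.0)
    return time + 1, ta

_, targets_ta = tf.while_loop(
    cond_fn, loop_fn,
    [tf.constant(0), tf.TensorArray(tf.float32, size=5)])
print(targets_ta.stack())  # [0. 2. 4. 6. 8.]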
Exemplo n.º 53
0
def raw_rnn(cell,
            loop_fn,
            parallel_iterations=None,
            swap_memory=False,
            scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the time axis)
    in addition to the outputs at each timestep and the final state

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )
    """
    if not _like_rnncell(cell):
        raise TypeError("cell must be an instance of RNNCell")
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if not context.executing_eagerly():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None else
                      constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [
                emit.shape
                if emit.shape.is_fully_defined() else array_ops.shape(emit)
                for emit in flat_emit_structure
            ]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_state_size = [
            s.shape if s.shape.is_fully_defined() else array_ops.shape(s)
            for s in flat_state
        ]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([
                    const_batch_size
                ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i)
            for i, (dtype_i,
                    size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                        flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
        ]

        zero_emit = nest.pack_sequence_as(structure=emit_structure,
                                          flat_sequence=flat_zero_emit)

        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([
                    const_batch_size
                ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i)
            for i, (
                dtype_i,
                size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state,
                                         flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta,
                 state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i,
                                               cand_i)

                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                                         emit_ta, emit_output)
            state_ta = nest.map_structure(
                lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                time, elements_finished, next_input, state_ta, emit_ta, state,
                loop_state
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        flat_states = nest.flatten(state_ta)
        flat_states = [
            array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states
        ]
        states = nest.pack_sequence_as(structure=state_ta,
                                       flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [
            array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs
        ]
        outputs = nest.pack_sequence_as(structure=emit_ta,
                                        flat_sequence=flat_outputs)

        return (states, outputs, final_state)
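
The _copy_some_through helper above keeps finished rows frozen: where elements_finished is True the previous value is carried forward, otherwise the candidate is used. A standalone illustration of the same select (using TF 2.x broadcasting tf.where rather than the v1 array_ops.where in the code):

import tensorflow as tf

elements_finished = tf.constant([True, False])
current = tf.constant([[1.0, 1.0], [1.0, 1.0]])
candidate = tf.constant([[9.0, 9.0], [9.0, 9.0]])
out = tf.where(elements_finished[:, None], current, candidate)
print(out.numpy())  # [[1. 1.], [9. 9.]]: finished row keeps its old value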
Exemplo n.º 54
0
        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
        ```
      """
            (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
                time, inputs, state
            )
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = math_ops.logical_or(decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                math_ops.logical_not(finished),
                array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
                sequence_lengths,
            )

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs,
                    zero_outputs,
                )
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = new.shape.ndims == 0
                return new if pass_through else array_ops.where(finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit
            )
            return (
                time + 1,
                outputs_ta,
                next_state,
                next_inputs,
                next_finished,
                next_sequence_lengths,
            )
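
Because outputs_ta here is a structure of TensorArrays (one per field of the output namedtuple), the write has to be mapped over the structure rather than called once. A standalone illustration with a plain tuple:

import tensorflow as tf

ta_pair = (tf.TensorArray(tf.float32, size=2),
           tf.TensorArray(tf.float32, size=2))
emit = (tf.constant(1.0), tf.constant(2.0))
ta_pair = tf.nest.map_structure(lambda ta, out: ta.write(0, out), ta_pair, emit)
print(ta_pair[0].read(0).numpy(), ta_pair[1].read(0).numpy())  # 1.0 2.0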
Exemplo n.º 55
0
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
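
For reference, shape_invariants (validated above with nest.assert_same_structure) are what allow a loop variable to change shape between iterations. A minimal public-API sketch of their effect (assuming TF 2.x tf.while_loop):

import tensorflow as tf

x = tf.constant([1.0])
result = tf.while_loop(
    lambda x: tf.shape(x)[0] < 4,
    lambda x: tf.concat([x, x], axis=0),
    [x],
    shape_invariants=[tf.TensorShape([None])])  # shape may grow along axis 0
print(result[0].shape)  # (4,)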
Exemplo n.º 56
0
def rnn_step(
    time, sequence_length, min_sequence_length, max_sequence_length,
    zero_output, state, call_cell, state_size, skip_conditionals=False):
  """Calculate one step of a dynamic RNN minibatch.

  Returns an (output, state) pair conditioned on the sequence_lengths.
  When skip_conditionals=False, the pseudocode is something like:

  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on if we've finished calculating each row.
  new_output, new_state = call_cell()
  final_output = np.vstack([
    zero_output if time >= sequence_lengths[r] else new_output_r
    for r, new_output_r in enumerate(new_output)
  ])
  final_state = np.vstack([
    state[r] if time >= sequence_lengths[r] else new_state_r
    for r, new_state_r in enumerate(new_state)
  ])
  return (final_output, final_state)

  Args:
    time: Python int, the current time step
    sequence_length: int32 `Tensor` vector of size [batch_size]
    min_sequence_length: int32 `Tensor` scalar, min of sequence_length
    max_sequence_length: int32 `Tensor` scalar, max of sequence_length
    zero_output: `Tensor` vector of shape [output_size]
    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
      or a list/tuple of such tensors.
    call_cell: lambda returning tuple of (new_output, new_state) where
      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
    state_size: The `cell.state_size` associated with the state.
    skip_conditionals: Python bool, whether to skip using the conditional
      calculations.  This is useful for `dynamic_rnn`, where the input tensor
      matches `max_sequence_length`, and using conditionals just slows
      everything down.

  Returns:
    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
      final_output is a `Tensor` matrix of shape [batch_size, output_size]
      final_state is either a single `Tensor` matrix, or a tuple of such
        matrices (matching length and shapes of input `state`).

  Raises:
    ValueError: If the cell returns a state tuple whose length does not match
      that returned by `state_size`.
  """

  # Convert state to a list for ease of use
  flat_state = nest.flatten(state)
  flat_zero_output = nest.flatten(zero_output)

  def _copy_one_through(output, new_output):
    # If the state contains a scalar value we simply pass it through.
    if output.shape.ndims == 0:
      return new_output
    copy_cond = (time >= sequence_length)
    with ops.colocate_with(new_output):
      return array_ops.where(copy_cond, output, new_output)

  def _copy_some_through(flat_new_output, flat_new_state):
    # Use broadcasting select to determine which values should get
    # the previous state & zero output, and which values should get
    # a calculated state & output.
    flat_new_output = [
        _copy_one_through(zero_output, new_output)
        for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
    flat_new_state = [
        _copy_one_through(state, new_state)
        for state, new_state in zip(flat_state, flat_new_state)]
    return flat_new_output + flat_new_state

  def _maybe_copy_some_through():
    """Run RNN step.  Pass through either no or some past state."""
    new_output, new_state = call_cell()

    nest.assert_same_structure(state, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)
    return control_flow_ops.cond(
        # if t < min_seq_len: calculate and return everything
        time < min_sequence_length, lambda: flat_new_output + flat_new_state,
        # else copy some of it through
        lambda: _copy_some_through(flat_new_output, flat_new_state))

  # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
  # but benefits from removing cond() and its gradient.  We should
  # profile with and without this switch here.
  if skip_conditionals:
    # Instead of using conditionals, perform the selective copy at all time
    # steps.  This is faster when max_seq_len is equal to the number of unrolls
    # (which is typical for dynamic_rnn).
    new_output, new_state = call_cell()
    nest.assert_same_structure(state, new_state)
    new_state = nest.flatten(new_state)
    new_output = nest.flatten(new_output)
    final_output_and_state = _copy_some_through(new_output, new_state)
  else:
    empty_update = lambda: flat_zero_output + flat_state
    final_output_and_state = control_flow_ops.cond(
        # if t >= max_seq_len: copy all state through, output zeros
        time >= max_sequence_length, empty_update,
        # otherwise calculation is required: copy some or all of it through
        _maybe_copy_some_through)

  if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
    raise ValueError("Internal error: state and output were not concatenated "
                     "correctly.")
  final_output = final_output_and_state[:len(flat_zero_output)]
  final_state = final_output_and_state[len(flat_zero_output):]

  for output, flat_output in zip(final_output, flat_zero_output):
    output.set_shape(flat_output.get_shape())
  for substate, flat_substate in zip(final_state, flat_state):
    substate.set_shape(flat_substate.get_shape())

  final_output = nest.pack_sequence_as(
      structure=zero_output, flat_sequence=final_output)
  final_state = nest.pack_sequence_as(
      structure=state, flat_sequence=final_state)

  return final_output, final_state
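
The selective copy in _copy_one_through reduces to a single vectorized select: rows whose sequence has already ended (time >= sequence_length) keep their old value, while the rest take the newly computed one. A standalone illustration:

import tensorflow as tf

time = tf.constant(2)
sequence_length = tf.constant([1, 3])
old = tf.constant([10.0, 10.0])
new = tf.constant([99.0, 99.0])
out = tf.where(time >= sequence_length, old, new)
print(out.numpy())  # [10. 99.]: row 0 already finished, row 1 still running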
Exemplo n.º 57
0
 def _assert_structures_equal(self, struct1, struct2):
     tf_nest.assert_same_structure(struct1, struct2)
     for a, b in zip(tf_nest.flatten(struct1), tf_nest.flatten(struct2)):
         np.testing.assert_array_equal(a, b)
Exemplo n.º 58
0
    def testAssertSameStructure(self):
        structure1 = (((1, 2), 3), 4, (5, 6))
        structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
        structure_different_num_elements = ("spam", "eggs")
        structure_different_nesting = (((1, 2), 3), 4, 5, (6, ))
        nest.assert_same_structure(structure1, structure2)
        nest.assert_same_structure("abc", 1.0)
        nest.assert_same_structure("abc", np.array([0, 1]))
        nest.assert_same_structure("abc", constant_op.constant([0, 1]))

        with self.assertRaisesRegexp(ValueError,
                                     "don't have the same number of elements"):
            nest.assert_same_structure(structure1,
                                       structure_different_num_elements)

        with self.assertRaisesRegexp(ValueError,
                                     "don't have the same number of elements"):
            nest.assert_same_structure([0, 1], np.array([0, 1]))

        with self.assertRaisesRegexp(ValueError,
                                     "don't have the same number of elements"):
            nest.assert_same_structure(0, [0, 1])

        self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                          [0, 1])

        with self.assertRaisesRegexp(ValueError,
                                     "don't have the same nested structure"):
            nest.assert_same_structure(structure1, structure_different_nesting)

        named_type_0 = collections.namedtuple("named_0", ("a", "b"))
        named_type_1 = collections.namedtuple("named_1", ("a", "b"))
        self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                          named_type_0("a", "b"))

        nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))

        self.assertRaises(TypeError, nest.assert_same_structure,
                          named_type_0(3, 4), named_type_1(3, 4))

        with self.assertRaisesRegexp(ValueError,
                                     "don't have the same nested structure"):
            nest.assert_same_structure(named_type_0(3, 4), named_type_0([3],
                                                                        4))

        with self.assertRaisesRegexp(ValueError,
                                     "don't have the same nested structure"):
            nest.assert_same_structure([[3], 4], [3, [4]])
Exemplo n.º 59
0
    def _testWithMaybeMultiAttention(self,
                                     is_multi,
                                     create_attention_mechanisms,
                                     expected_final_output,
                                     expected_final_state,
                                     attention_mechanism_depths,
                                     alignment_history=False,
                                     expected_final_alignment_history=None,
                                     attention_layer_sizes=None,
                                     attention_layers=None,
                                     create_query_layer=False,
                                     create_memory_layer=True,
                                     create_attention_kwargs=None):
        # Allow is_multi to be True with a single mechanism to enable test for
        # passing in a single mechanism in a list.
        assert len(create_attention_mechanisms) == 1 or is_multi
        encoder_sequence_length = [3, 2, 3, 1, 1]
        decoder_sequence_length = [2, 0, 1, 2, 3]
        batch_size = 5
        encoder_max_time = 8
        decoder_max_time = 4
        input_depth = 7
        encoder_output_depth = 10
        cell_depth = 9
        create_attention_kwargs = create_attention_kwargs or {}

        if attention_layer_sizes is not None:
            # Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
            attention_depth = sum(
                attention_layer_size or encoder_output_depth
                for attention_layer_size in attention_layer_sizes)
        elif attention_layers is not None:
            # Compute sum of attention_layers output depth.
            attention_depth = sum(
                attention_layer.compute_output_shape(
                    [batch_size, cell_depth +
                     encoder_output_depth]).dims[-1].value
                for attention_layer in attention_layers)
        else:
            attention_depth = encoder_output_depth * len(
                create_attention_mechanisms)

        decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                         input_depth).astype(np.float32)
        encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                          encoder_output_depth).astype(
                                              np.float32)

        attention_mechanisms = []
        for creator, depth in zip(create_attention_mechanisms,
                                  attention_mechanism_depths):
            # Create a memory layer with deterministic initializer to avoid randomness
            # in the test between graph and eager.
            if create_query_layer:
                create_attention_kwargs["query_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)
            if create_memory_layer:
                create_attention_kwargs["memory_layer"] = keras.layers.Dense(
                    depth, kernel_initializer="ones", use_bias=False)

            attention_mechanisms.append(
                creator(units=depth,
                        memory=encoder_outputs,
                        memory_sequence_length=encoder_sequence_length,
                        **create_attention_kwargs))

        with self.cached_session(use_gpu=True):
            attention_layer_size = attention_layer_sizes
            attention_layer = attention_layers
            if not is_multi:
                if attention_layer_size is not None:
                    attention_layer_size = attention_layer_size[0]
                if attention_layer is not None:
                    attention_layer = attention_layer[0]
            cell = keras.layers.LSTMCell(cell_depth,
                                         recurrent_activation="sigmoid",
                                         kernel_initializer="ones",
                                         recurrent_initializer="ones")
            cell = wrapper.AttentionWrapper(
                cell,
                attention_mechanisms if is_multi else attention_mechanisms[0],
                attention_layer_size=attention_layer_size,
                alignment_history=alignment_history,
                attention_layer=attention_layer)
            if cell._attention_layers is not None:
                for layer in cell._attention_layers:
                    if getattr(layer, "kernel_initializer") is None:
                        layer.kernel_initializer = initializers.glorot_uniform(
                            seed=1337)

            sampler = sampler_py.TrainingSampler()
            my_decoder = basic_decoder.BasicDecoderV2(cell=cell,
                                                      sampler=sampler)
            initial_state = cell.get_initial_state(dtype=dtypes.float32,
                                                   batch_size=batch_size)
            final_outputs, final_state, _ = my_decoder(
                decoder_inputs,
                initial_state=initial_state,
                sequence_length=decoder_sequence_length)

            self.assertIsInstance(final_outputs,
                                  basic_decoder.BasicDecoderOutput)
            self.assertIsInstance(final_state, wrapper.AttentionWrapperState)

            expected_time = (expected_final_state.time
                             if context.executing_eagerly() else None)
            self.assertEqual(
                (batch_size, expected_time, attention_depth),
                tuple(final_outputs.rnn_output.get_shape().as_list()))
            self.assertEqual(
                (batch_size, expected_time),
                tuple(final_outputs.sample_id.get_shape().as_list()))

            self.assertEqual(
                (batch_size, attention_depth),
                tuple(final_state.attention.get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[0].get_shape().as_list()))
            self.assertEqual(
                (batch_size, cell_depth),
                tuple(final_state.cell_state[1].get_shape().as_list()))

            if alignment_history:
                if is_multi:
                    state_alignment_history = []
                    for history_array in final_state.alignment_history:
                        history = history_array.stack()
                        self.assertEqual(
                            (expected_time, batch_size, encoder_max_time),
                            tuple(history.get_shape().as_list()))
                        state_alignment_history.append(history)
                    state_alignment_history = tuple(state_alignment_history)
                else:
                    state_alignment_history = final_state.alignment_history.stack()
                    self.assertEqual(
                        (expected_time, batch_size, encoder_max_time),
                        tuple(state_alignment_history.get_shape().as_list()))
                nest.assert_same_structure(
                    cell.state_size, cell.zero_state(batch_size,
                                                     dtypes.float32))
                # Remove the history from final_state for purposes of the
                # remainder of the tests.
                final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
            else:
                state_alignment_history = ()

            self.evaluate(variables.global_variables_initializer())
            eval_result = self.evaluate({
                "final_outputs":
                final_outputs,
                "final_state":
                final_state,
                "state_alignment_history":
                state_alignment_history,
            })

            final_output_info = nest.map_structure(
                get_result_summary, eval_result["final_outputs"])
            final_state_info = nest.map_structure(get_result_summary,
                                                  eval_result["final_state"])
            print("final_output_info: ", final_output_info)
            print("final_state_info: ", final_state_info)

            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_output, final_output_info)
            nest.map_structure(self.assertAllCloseOrEqual,
                               expected_final_state, final_state_info)
            if alignment_history:  # by default, the wrapper emits attention as output
                final_alignment_history_info = nest.map_structure(
                    get_result_summary, eval_result["state_alignment_history"])
                print("final_alignment_history_info: ",
                      final_alignment_history_info)
                nest.map_structure(
                    self.assertAllCloseOrEqual,
                    # outputs are batch major but the stacked TensorArray is time major
                    expected_final_alignment_history,
                    final_alignment_history_info)
Exemplo n.º 60
0
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
    """Like tf.while_loop, except emits a single While op."""
    maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
    # Keep the original loop_vars around to know which args were TensorArrays.
    orig_loop_vars = loop_vars
    # Cache its length since we use it at multiple places below.
    len_orig_loop_vars = len(orig_loop_vars)

    # Convert TensorArrays to their flow variables. These get converted back to
    # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
    # `wrapped_body` below.
    loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
    loop_vars = nest.map_structure(
        ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
    if shape_invariants is not None:
        nest.assert_same_structure(orig_loop_vars, shape_invariants)
    else:
        shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

    if not name:
        name = "while"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            cond_name = util.unique_fn_name(scope, "cond")
            body_name = util.unique_fn_name(scope, "body")

        loop_counter = constant_op.constant(
            0,
            dtype=maximum_iterations.dtype
            if maximum_iterations is not None else None,
            name="loop_counter")
        # Add loop counter needed for computing gradients.
        loop_vars = [loop_counter] + loop_vars

        shape_invariants = type(shape_invariants)(
            [tensor_shape.scalar()]) + shape_invariants

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = ops.get_default_graph()._add_control_dependencies

        # Build a `cond` wrapper that can handle the extra counter loop_var.
        def wrapped_cond(loop_counter, *args):
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            if maximum_iterations is None:
                return cond(*_pack_sequence_as(orig_loop_vars, args))
            else:
                return math_ops.logical_and(
                    loop_counter < maximum_iterations,
                    cond(*_pack_sequence_as(orig_loop_vars, args)))

        cond_graph = func_graph_module.func_graph_from_py_func(
            cond_name,
            wrapped_cond,
            loop_vars, {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileCondFuncGraph(cond_name),
            add_control_dependencies=add_control_dependencies)

        # Add external_captures of cond to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + cond_graph.external_captures
        shape_invariants = shape_invariants + type(shape_invariants)(
            [t.shape for t in cond_graph.external_captures])

        def wrapped_body(loop_counter, *args):
            """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            outputs = body(
                *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
            if not nest.is_sequence(outputs):
                outputs = [outputs]
            # Compare the structure of input and output of body converting the
            # top-level tuples to list to be compatible with legacy while_loop.
            nest.assert_same_structure(list(outputs), list(orig_loop_vars))

            outputs = _tensor_array_to_flow(outputs)

            # Return the external_captures of cond_graph as is, i.e., treat them as
            # loop invariants.
            # TODO(srbs): Update lowering code to create _Enter nodes with
            # is_constant=True for inputs that are directly passed to outputs.
            return [loop_counter + 1] + list(outputs) + list(
                args[len_orig_loop_vars:])

        body_graph = func_graph_module.func_graph_from_py_func(
            body_name,
            wrapped_body,
            loop_vars, {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileBodyFuncGraph(body_name),
            add_control_dependencies=add_control_dependencies)
        # Add external captures of body to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + body_graph.external_captures
        # TODO(srbs): Update lowering code to create _Enter nodes with
        # is_constant=True for inputs that are directly passed to outputs.
        body_graph.outputs.extend(body_graph.internal_captures)

        # Capture `external_captures` of `body_graph` in `cond_graph` so that it
        # expects to receive those as arguments.
        # TODO(b/118457764): Dedup tensors that are captured in both the cond and
        # body. This logic already exists in cond_v2.
        with cond_graph.as_default():
            for external_capture in body_graph.external_captures:
                assert external_capture not in cond_graph.captures, (
                    "Looks like both cond and body are capturing the same tensor %s. "
                    "This is not supported yet. For now consider passing,"
                    " this as a loop variable." % str(external_capture))
                cond_graph.capture(external_capture)

        # Make sure that the shapes of the loop outputs are compatible with the
        # shape invariants, or the shapes of the loop vars if the invariants are not
        # specified.
        num_flattened_outputs = len(nest.flatten(orig_loop_vars))
        _check_shapes_compat(
            body_graph.outputs[1:1 + num_flattened_outputs],
            nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
            nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
        flattened_loop_vars = nest.flatten(loop_vars)
        _check_num_inputs_outputs(cond_graph, body_graph,
                                  len(flattened_loop_vars))

        outputs = gen_functional_ops._while(
            flattened_loop_vars,
            util.create_new_tf_function(cond_graph),
            util.create_new_tf_function(body_graph),
            output_shapes=[t.shape for t in body_graph.outputs],
            name=scope)

        _copy_handle_data(body_graph.outputs, outputs)
        util.maybe_set_lowering_attr(outputs[0].op)
        _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

        # Return identities for each output of the While op, rather than the output
        # of the While op directly. This makes pruning work if the output of
        # while_loop() is fetched: the lowering pass converts the While outputs into
        # IdentityN outputs, which if fetched will cause all ops in the body to be
        # run (since it takes all exit ops as input). After lowering, each output
        # identity op will end up with only the appropriate exit op as input.
        outputs = tuple(array_ops.identity(t) for t in outputs)

    # First var is loop counter.
    outputs = _pack_sequence_as(orig_loop_vars,
                                outputs[1:1 + num_flattened_outputs])

    if return_same_structure:
        return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
        return flattened_outputs[0]
    else:
        return outputs
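
A closing note on the external-capture handling above: tensors the body closes over but does not receive as loop variables are treated as loop invariants, i.e. every iteration observes the same value. A minimal sketch with the public API (in graph mode such a tensor becomes an external capture; eagerly the closure simply reads it):

import tensorflow as tf

k = tf.constant(2.0)  # closed over, not a loop variable
result = tf.while_loop(lambda x: x < 10.0, lambda x: x * k, [tf.constant(1.0)])
print(result[0].numpy())  # 16.0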