Example #1
  def testFlattenAndPack(self):
    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
                                                 ("d", "e", ("f", "g"), "h")))
    point = collections.namedtuple("Point", ["x", "y"])
    structure = (point(x=4, y=2), ((point(x=1, y=0),),))
    flat = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(structure), flat)
    restructured_from_flat = nest.pack_sequence_as(structure, flat)
    self.assertEqual(restructured_from_flat, structure)
    self.assertEqual(restructured_from_flat[0].x, 4)
    self.assertEqual(restructured_from_flat[0].y, 2)
    self.assertEqual(restructured_from_flat[1][0][0].x, 1)
    self.assertEqual(restructured_from_flat[1][0][0].y, 0)

    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
      nest.pack_sequence_as("scalar", [4, 5])

    with self.assertRaisesRegexp(TypeError, "flat_sequence"):
      nest.pack_sequence_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
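The same flatten/pack round trip can be reproduced with the public tf.nest API; a minimal sketch (the namedtuple structure below is arbitrary):

import collections
import tensorflow as tf

Point = collections.namedtuple("Point", ["x", "y"])
structure = (Point(x=4, y=2), ((Point(x=1, y=0),),))

flat = tf.nest.flatten(structure)                   # [4, 2, 1, 0]
rebuilt = tf.nest.pack_sequence_as(structure, flat)
assert rebuilt == structure                         # namedtuples compare by value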
Example #2
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2."""
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we can support dynamic
  # shapes using the dynamic padder.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      maximum_shapes.append(input_tensor.get_shape())
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no-ops that may have been added during `tpu.replicate()`.
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
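The maximum_shapes construction in the snippet above is a plain flatten/pack round trip over one replica's inputs; a rough sketch with the public tf.nest API and made-up inputs:

import tensorflow as tf

# Hypothetical (replica_id, args, kwargs)-style nest for a single replica.
replica_inputs = (tf.constant(0),
                  (tf.constant([[1.0, 2.0]]),),
                  {"mask": tf.constant([True])})

flat = tf.nest.flatten(replica_inputs)
maximum_shapes = tf.nest.pack_sequence_as(
    replica_inputs, [t.get_shape() for t in flat])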
Example #3
 def testPackDictOrder(self):
   """Packing orders dicts by key, including OrderedDicts."""
   ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
   plain = {"d": 0, "b": 0, "a": 0, "c": 0}
   seq = [0, 1, 2, 3]
   ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
   plain_reconstruction = nest.pack_sequence_as(plain, seq)
   self.assertEqual(
       collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
       ordered_reconstruction)
   self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
Example #4
 def testPackDictOrder(self, mapping_type):
   """Packing orders dicts by key, including OrderedDicts."""
   custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
   plain = {"d": 0, "b": 0, "a": 0, "c": 0}
   seq = [0, 1, 2, 3]
   custom_reconstruction = nest.pack_sequence_as(custom, seq)
   plain_reconstruction = nest.pack_sequence_as(plain, seq)
   self.assertIsInstance(custom_reconstruction, mapping_type)
   self.assertIsInstance(plain_reconstruction, dict)
   self.assertEqual(
       mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
       custom_reconstruction)
   self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
Example #5
def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
                       total_time, inputs_lengths, is_reversed):
  """Post-process output of recurrent.

  This function takes the accumulated extended state and extracts the requested
  state and output.

  When `inputs_lengths` has been set, it extracts the output from the
  accumulated state and zeroes out the outputs past each sequence's length.

  When `is_reversed` is true, the output will be reversed in this function.

  It also sets the static shape information.

  Args:
    extended_acc_state: A structure containing the accumulated state at each
      time. It may contain the output at each time as well.
    extended_final_state: A structure containing the final state. It may
      contain the output at the final time.
    func_cell: The functional wrapper around the cell.
    total_time: A scalar integer tensor.
    inputs_lengths: An integer tensor with one entry per input.
    is_reversed: A boolean to indicate if the sequence is reversed.

  Returns:
    A tuple with the outputs at each time, and the final state.
  """
  if inputs_lengths is None or is_reversed:
    flat_final_state = func_cell.MaybeRemoveOutputFromState(
        nest.flatten(extended_final_state))
    tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
  else:
    # The accumulated state is over the entire sequence, so we pick it
    # out from the acc_state sequence.
    flat_acc_state = func_cell.MaybeRemoveOutputFromState(
        nest.flatten(extended_acc_state))
    acc_state = nest.pack_sequence_as(
        func_cell.state_template, flat_acc_state)
    tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)

  output_from_state = func_cell.GetOutputFromState(extended_acc_state)
  if is_reversed:
    output_from_state = array_ops.reverse(output_from_state, [0])
  tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
  tf_output.set_shape(
      [func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
  if inputs_lengths is not None:
    # Need to set the outputs past the input lengths to zero.
    tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
  _SetShapeFromTemplate(tf_state, func_cell.state_template)
  return tf_output, tf_state
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    # TODO(priyag): Use max_iterations instead of an explicit counter.
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
Example #7
  def decorated(*args, **kwds):
    """Computes the value and gradient of the decorated function."""
    parameter_positions = _get_arg_spec(f, params, args)
    assert not kwds, "The gradient function can't take keyword arguments."
    this_tape = tape.push_new_tape(persistent=persistent)
    try:
      sources = []
      args = [
          ops.convert_to_tensor(args[i])
          if i in parameter_positions else args[i]
          for i in range(len(args))
      ]
      args = _ensure_unique_tensor_objects(parameter_positions, args)
      for i in parameter_positions:
        sources.append(args[i])
        tape.watch(this_tape, args[i])
      result = f(*args)
      if result is None:
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             f.__name__))
      flat_result = nest.flatten(result)
      flat_result = [gen_array_ops.identity(x) for x in flat_result]
      result = nest.pack_sequence_as(result, flat_result)
    finally:
      tape.pop_tape(this_tape)
    def vjp(dy=None):
      if dy is not None:
        dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
      return imperative_grad.imperative_grad(
          this_tape, nest.flatten(result), sources, output_gradients=dy)

    return result, vjp
Example #8
  def _build_call_outputs(self, result):
    """Maps the fdef output list to actual output structure.

    Args:
      result: Output lists defined by FunctionDef.
    Returns:
      The actual call output.
    """
    if self._func_outputs is None:
      return None
    # Use `nest.flatten` instead of `_flatten` in order to preserve any
    # IndexedSlices in `self._func_outputs`.
    outputs_list = nest.flatten(self._func_outputs)
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        if isinstance(o, ops.IndexedSlices):
          # Repack Tensors for IndexedSlices.
          if o.dense_shape is not None:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j],
                indices=result[j + 1],
                dense_shape=result[j + 2])
            j += 3
          else:
            outputs_list[i] = ops.IndexedSlices(
                values=result[j],
                indices=result[j + 1])
            j += 2
        else:
          outputs_list[i] = result[j]
          j += 1
    ret = nest.pack_sequence_as(self._func_outputs, outputs_list)
    return ret
Example #9
  def testRegularizers(self, trainable, state_size):
    batch_size = 6

    # Set the attribute on the class since we can't set properties of abstract
    # classes.
    snt.RNNCore.state_size = state_size
    flat_state_size = nest.flatten(state_size)
    core = snt.RNNCore(name="dummy_core")
    flat_regularizer = ([tf.contrib.layers.l1_regularizer(scale=0.5)] *
                        len(flat_state_size))
    trainable_regularizers = nest.pack_sequence_as(
        structure=state_size, flat_sequence=flat_regularizer)

    core.initial_state(batch_size, dtype=tf.float32, trainable=trainable,
                       trainable_regularizers=trainable_regularizers)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if not trainable:
      self.assertFalse(graph_regularizers)
    else:
      self.assertEqual(len(graph_regularizers), len(flat_state_size))
      if not tf.executing_eagerly():
        for i in range(len(flat_state_size)):
          self.assertRegexpMatches(
              graph_regularizers[i].name, ".*l1_regularizer.*")
Example #10
def _Pack(elements, struct_template):
  """Packs the list of tensors according to the structure.

  In the event that `elements` should be a scalar, `struct_template` must
  contain exactly one non-trivial element (for instance, `[[], {'x':elt}]`).

  Args:
    elements: Elements to be packed. A list of tensor, or a single tensor.
    struct_template: The container structure in which to pack them.
  Returns:
    A python structure of the same type as `struct_template`, containing
    `elements` as its contained elements.
  """
  if not nest.is_sequence(elements):
    return nest.pack_sequence_as(struct_template, [elements])
  return nest.pack_sequence_as(struct_template, elements)
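A small illustration of the scalar case described in the docstring, using the public tf.nest API with an invented template:

import tensorflow as tf

struct_template = [[], {"x": "leaf"}]     # exactly one non-trivial element
element = tf.constant(7.0)

packed = tf.nest.pack_sequence_as(struct_template, [element])
# packed == [[], {"x": <tf.Tensor 7.0>}]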
Example #11
    def BackwardLoopBody(*args):
      """Backward loop body function."""
      t, dev_t = args[0], args[1]
      (theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state1,
       d_inputs, d_acc_state) = _Pack(args[2:], bakloop_sig)

      # The input recurrent state for time step t is the previous time step's
      # output, or the original state0 when on time step 0.
      state_from_acc = _Index(acc_state, math_ops.maximum(0, t - 1))
      state0 = functional_ops.If(
          math_ops.equal(t, array_ops.constant(0, dtypes.int32)),
          _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
          ReturnAccState)
      state0 = nest.pack_sequence_as(orig_state0, state0)

      # The external inputs for time step t.
      inputs_t = _Index(inputs, t)
      # The extras for time step t.
      extras_t = _Index(acc_extras, t)

      d_state1 = _Add(_Index(d_acc_state, t), d_state1)
      (d_theta_t, d_state0, d_inputs_t) = _Pack(
          Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
          [self._theta, self._state, self._inputs])
      d_theta = _Add(d_theta, d_theta_t)
      d_inputs = _Update(d_inputs, d_inputs_t, dev_t)
      return [math_ops.subtract(dev_t, 1)] + _Flatten([
          theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state0,
          d_inputs, d_acc_state
      ])
Example #12
def _pack_dict(dct, lst):
  d, keys = Struct(), sorted(dct)
  for k in keys:
    n = len(nest.flatten(dct[k]))
    d[k] = nest.pack_sequence_as(dct[k], lst[:n])
    lst = lst[n:]
  return d
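A hedged sketch of how _pack_dict walks the flat list key by key; Struct is project-specific, so an ordinary dict stands in here:

import tensorflow as tf

def pack_dict(dct, lst):
    # Same idea as _pack_dict above, but returning a plain dict.
    out = {}
    for k in sorted(dct):
        n = len(tf.nest.flatten(dct[k]))
        out[k] = tf.nest.pack_sequence_as(dct[k], lst[:n])
        lst = lst[n:]
    return out

templates = {"b": (0, 0), "a": 0}
print(pack_dict(templates, ["a0", "b0", "b1"]))   # {'a': 'a0', 'b': ('b0', 'b1')}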
Example #13
def _Update(struct_acc, struct_x, t):
  """Updates t-th row in accumulators.

  Args:
    struct_acc: The accumulators. A structure of tensors.
    struct_x: The new values. A structure of tensors congruent to `struct_acc`.
    t: A scalar integer. Performance is better if `t` is on the device
      memory.

  Returns:
    A structure of tensors. Say, ret is a returned dictionary. Then, for
    each key, we have:
      ret[key] = struct_acc[key];
      ret[key][t, :] = struct_x[key]
  """
  to_skip_update = set()
  acc_lst = nest.flatten(struct_acc)
  x_lst = nest.flatten(struct_x)
  t = math_ops.to_int32([t])  # tf.to_int32 casts on-device tensors.
  lst = []
  for acc, x in zip(acc_lst, x_lst):
    if acc in to_skip_update:
      # Until b/62105730 is fixed, we need to avoid inplace update for tensors
      # of rank 1.  could reshape to handle it, but we don't really need the
      # values applied to these, so just skip their modification.
      lst += [acc]
    else:
      lst += [alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))]
  return nest.pack_sequence_as(struct_acc, lst)
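The same "write row t into every accumulator" step can be sketched with public ops; tf.tensor_scatter_nd_update stands in here for the internal alias_inplace_update, and the structures are invented:

import tensorflow as tf

struct_acc = {"h": tf.zeros([5, 3]), "c": tf.zeros([5, 2])}   # 5 time steps
struct_x = {"h": tf.ones([3]), "c": tf.ones([2])}             # values for step t
t = 2

updated = tf.nest.pack_sequence_as(
    struct_acc,
    [tf.tensor_scatter_nd_update(acc, [[t]], tf.expand_dims(x, 0))
     for acc, x in zip(tf.nest.flatten(struct_acc), tf.nest.flatten(struct_x))])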
Example #14
  def _time_step(time, next_input, output_ta_t, state):
    for input_, shape in zip(next_input, inputs_got_shape):
      input_.set_shape(shape)

    
    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=next_input)
    # import tensorflow as tf
    # input_t = tf.Print(input_t, [input_t, time], "input_t, time")

    call_cell = lambda: cell(input_t, state)

    if sequence_length is not None:
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()

    # Pack state if using state tuples
    output = nest.flatten(output)

    output_ta_t = tuple(
        ta.write(time, out) for ta, out in zip(output_ta_t, output))

    return (time + 1, output, output_ta_t, new_state)
  def loop_fn(i):
    sequence_length_i = array_ops.gather(sequence_length, i)

    def body_fn(t, state, ta):
      inputs_t = array_ops.expand_dims(
          array_ops.gather(inputs_ta.read(t), i), 0)
      output, new_state = cell(inputs_t, state)
      output = array_ops.reshape(output, [-1])
      # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
      # array_ops.where when t < min(sequence_length). Doing that requires
      # supporting tf.cond pfor conversion.
      done = t >= sequence_length_i
      output = array_ops.where(done, zeros, output)
      ta = ta.write(t, output)
      new_state = [array_ops.where(done, s, ns) for s, ns in
                   zip(nest.flatten(state), nest.flatten(new_state))]
      new_state = nest.pack_sequence_as(state, new_state)
      return t + 1, new_state, ta

    def condition_fn(t, _, unused):
      del unused
      return t < max_steps

    initial_state = cell.zero_state(1, dtypes.float32)
    _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
        0, initial_state,
        tensor_array_ops.TensorArray(dtypes.float32, max_steps)
    ])

    new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
    new_state = nest.pack_sequence_as(initial_state, new_state)
    return ta.stack(), new_state
Example #16
def _deep_copy_state(complex_state_tuple, name):
    flat_state = nest.flatten(complex_state_tuple)
    flat_state_copy = [tf.identity(s, name="{}_{}".format(name, i))
                       for i, s in enumerate(flat_state)]
    state_copy = nest.pack_sequence_as(complex_state_tuple,
                                       flat_state_copy)
    return state_copy
  def restore(self, file_prefix):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    restore_specs = []
    tensor_structure = []
    for saveable in self._saveable_objects:
      saveable_tensor_structure = []
      tensor_structure.append(saveable_tensor_structure)
      for spec in saveable.specs:
        saveable_tensor_structure.append(spec.name)
        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
    with ops.device("cpu:0"):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, tensor_slices, tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    restore_ops = {}
    for saveable, restored_tensors in zip(self._saveable_objects,
                                          structured_restored_tensors):
      restore_ops[saveable.name] = saveable.restore(
          restored_tensors, restored_shapes=None)
    return restore_ops
  def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`."""
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    flat_result = []
    # TODO(priyag): This will fail if the input size (typically number of
    # batches) is not divisible by number of devices.
    # How do we handle that more gracefully / let the user know?
    for buffer_resource in self._buffering_resources:
      flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
          buffer_resource,
          output_types=data_nest.flatten(sparse.as_dense_types(
              self.output_types, self.output_classes)), name=name)

      ret = sparse.deserialize_sparse_tensors(
          data_nest.pack_sequence_as(self.output_types, flat_ret),
          self.output_types, self.output_shapes, self.output_classes)

      for tensor, shape in zip(
          data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
        if isinstance(tensor, ops.Tensor):
          tensor.set_shape(shape)
      flat_result.append(ret)

    return nest.pack_sequence_as(self._devices, flat_result)
def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False
  try:
    # Verify that no input elements were dropped during flattening.
    repacked = nest.pack_sequence_as(expected_structure, flatten_inputs)
    # TODO(b/129422719): Namedtuple subclasses re-created through
    # saved_model.load don't compare equal in type to the original in
    # assert_same_structure. Fix that and we can take out check_types=False
    # here.
    nest.assert_same_structure(inputs, repacked, check_types=False)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    else:
      if arg != expected:
        return False
  return True
Example #20
 def _get_grads_lists_empirical(self, tensors):
   grads_flat = gradients_impl.gradients(
       self._layers.total_loss(),
       nest.flatten(tensors),
       colocate_gradients_with_ops=self._colocate_gradients_with_ops)
   grads_all = nest.pack_sequence_as(tensors, grads_flat)
   return tuple((grad,) for grad in grads_all)
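The flatten-gradients-then-repack pattern above, sketched in eager mode with tf.GradientTape over a made-up nested parameter structure:

import tensorflow as tf

params = {"w": tf.Variable(tf.ones([2, 2])), "b": (tf.Variable(0.5),)}

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(params["w"]) * params["b"][0]

grads_flat = tape.gradient(loss, tf.nest.flatten(params))
grads = tf.nest.pack_sequence_as(params, grads_flat)   # same dict/tuple nesting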
Example #21
    def _create(self, encoder_output, decoder_state_size, **kwargs):
        """ Creates decoder's initial RNN states according to
        `decoder_state_size`.

        Creates a tf variable and passes to decoder.
        Args:
            encoder_output: An instance of `collections.namedtuple`
              from `Encoder.encode()`.
            decoder_state_size: RNN decoder state size.
            **kwargs:

        Returns: The decoder states with the structure determined
          by `decoder_state_size`.
        """
        batch_size = tf.shape(encoder_output.attention_length)[0]
        name = kwargs["name"] if "name" in kwargs else None
        state_size_splits = nest.flatten(decoder_state_size)
        total_decoder_state_size = sum(state_size_splits)
        with tf.variable_scope(name or "init_state"):
            init_state_total = tf.get_variable(
                name="init_states", shape=(total_decoder_state_size,),
                dtype=tf.float32, initializer=tf.zeros_initializer)
        init_state_total = tf.tile([init_state_total], [batch_size, 1])
        init_state = nest.pack_sequence_as(
            decoder_state_size,
            tf.split(init_state_total, state_size_splits, axis=1))
        return init_state
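What the tf.split / pack_sequence_as pair does above, shown with the public API; the nested state sizes and batch size are illustrative:

import tensorflow as tf

decoder_state_size = (4, (3, 3))               # hypothetical nested sizes
splits = tf.nest.flatten(decoder_state_size)   # [4, 3, 3]
batch_size = 2

init_state_total = tf.zeros([batch_size, sum(splits)])
init_state = tf.nest.pack_sequence_as(
    decoder_state_size, tf.split(init_state_total, splits, axis=1))
# init_state is (Tensor[2, 4], (Tensor[2, 3], Tensor[2, 3])).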
Example #22
  def _get_cached_states(self, times):
    """Retrieve cached states for a batch of times."""
    read_chunk_numbers = self._get_chunk_number(times)
    looked_up_state = list(self._cached_states.lookup(
        math_ops.cast(read_chunk_numbers, dtypes.int64)))
    looked_up_state = tuple(looked_up_state)
    # We need to special-case the first chunk in a series to explicitly rely on
    # the model's starting state so that gradients flow back to it. Otherwise it
    # would affect only initialization, and would not be read from or updated
    # during training. Not doing this also isolates that part of the graph,
    # leading to errors on model reload if there are trainable variables
    # affecting a model's start state.
    if self._input_statistics is not None:
      start_time = self._input_statistics.start_time
    else:
      start_time = 0
    set_to_start_state = math_ops.equal(read_chunk_numbers,
                                        self._get_chunk_number(start_time))
    new_states = []
    for start_state_value, cache_variable in zip(
        nest.flatten(
            math_utils.replicate_state(self._start_state,
                                       array_ops.shape(times)[0])),
        nest.flatten(looked_up_state)):

      new_states.append(
          array_ops.where(set_to_start_state, start_state_value,
                          cache_variable))
    looked_up_state = nest.pack_sequence_as(looked_up_state, new_states)
    return looked_up_state
Example #23
 def _apply_exogenous_update(
     self, current_times, step_number, state, raw_features,
     embedded_exogenous_regressors):
   """Performs a conditional state update based on exogenous features."""
   if embedded_exogenous_regressors is None:
     return state
   else:
     current_exogenous_regressors = embedded_exogenous_regressors[
         :, step_number, :]
     exogenous_updated_state = self._exogenous_input_step(
         current_times=current_times,
         current_exogenous_regressors=current_exogenous_regressors,
         state=state)
     if self._exogenous_update_condition is not None:
       current_raw_exogenous_features = {
           key: value[:, step_number] for key, value in raw_features.items()
           if key not in [PredictionFeatures.STATE_TUPLE,
                          TrainEvalFeatures.TIMES,
                          TrainEvalFeatures.VALUES]}
       conditionally_updated_state_flat = []
       for updated_state_element, original_state_element in zip(
           nest.flatten(exogenous_updated_state),
           nest.flatten(state)):
         conditionally_updated_state_flat.append(
             array_ops.where(
                 self._exogenous_update_condition(
                     times=current_times,
                     features=current_raw_exogenous_features),
                 updated_state_element,
                 original_state_element))
       return nest.pack_sequence_as(state, conditionally_updated_state_flat)
     else:
       return exogenous_updated_state
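The conditional update reduces to a flatten / tf.where / pack pattern; a minimal sketch with the public API, where the state structure and update condition are made up:

import tensorflow as tf

state = {"level": tf.zeros([4, 3]), "trend": tf.zeros([4, 1])}
updated = tf.nest.map_structure(lambda t: t + 1.0, state)
should_update = tf.constant([True, False, True, False])   # one flag per batch row

conditionally_updated = tf.nest.pack_sequence_as(
    state,
    [tf.where(should_update[:, None], new, old)   # broadcast flag over features
     for new, old in zip(tf.nest.flatten(updated), tf.nest.flatten(state))])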
Example #24
  def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, then the return value is a
      `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.

      If `state_size` is a nested list or tuple, then the return value is
      a nested list or tuple (of the same structure) of `2-D` tensors with
      the shapes `[batch_size x s]` for each s in `state_size`.
    """
    state_size = self.state_size
    if nest.is_sequence(state_size):
      state_size_flat = nest.flatten(state_size)
      zeros_flat = [
          array_ops.zeros(
              array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
              dtype=dtype)
          for s in state_size_flat]
      for s, z in zip(state_size_flat, zeros_flat):
        z.set_shape(_state_size_with_prefix(s, prefix=[None]))
      zeros = nest.pack_sequence_as(structure=state_size,
                                    flat_sequence=zeros_flat)
    else:
      zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
      zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
      zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))

    return zeros
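The nested branch of zero_state amounts to "one zero tensor per leaf of state_size, repacked"; a sketch with current public ops and an invented state size:

import tensorflow as tf

state_size = (16, (8, 8))          # hypothetical nested cell state size
batch_size = 4

zeros = tf.nest.pack_sequence_as(
    state_size,
    [tf.zeros([batch_size, s]) for s in tf.nest.flatten(state_size)])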
Example #25
        def body(time, elements_finished, current_input, emit_ta, state, loop_state):
            """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output, next_loop_state) = loop_fn(
                next_time, next_output, cell_state, loop_state
            )

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                # pylint: disable=g-long-lambda,cell-var-from-loop
                result_flat = [
                    _on_device(
                        lambda: array_ops.where(elements_finished, current_i, candidate_i), device=candidate_i.op.device
                    )
                    for (current_i, candidate_i) in zip(current_flat, candidate_flat)
                ]
                # pylint: enable=g-long-lambda,cell-var-from-loop
                return nest.pack_sequence_as(structure=current, flat_sequence=result_flat)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            emit_ta_flat = [ta.write(time, emit) for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]

            emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=emit_ta_flat)

            return (next_time, elements_finished, next_input, emit_ta, next_state, loop_state)
def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Runs `loop_fn` `iters` times and stacks the outputs.


  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks corresponding outputs of the different runs.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of tensor
      objects. The shape of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes for the outputs of loop_fn.
    iters: Number of iterations for which to run loop_fn.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked output tensor objects with the same
    nested structure as the output of `loop_fn`.
  """

  flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
  is_none_list = []

  def while_body(i, *ta_list):
    """Body of while loop."""
    fn_output = nest.flatten(loop_fn(i))
    if len(fn_output) != len(flat_loop_fn_dtypes):
      raise ValueError(
          "Number of expected outputs, %d, does not match the number of "
          "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
                                                len(fn_output)))
    outputs = []
    del is_none_list[:]
    is_none_list.extend([x is None for x in fn_output])
    for out, ta in zip(fn_output, ta_list):
      # TODO(agarwal): support returning Operation objects from loop_fn.
      if out is not None:
        # out may be a ref tensor, wrap it in identity to get a non-ref tensor.
        ta = ta.write(i, array_ops.expand_dims(out, 0))
      outputs.append(ta)
    return tuple([i + 1] + outputs)

  if parallel_iterations is not None:
    extra_args = {"parallel_iterations": parallel_iterations}
  else:
    extra_args = {}
  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters,
      while_body,
      [0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
             for dtype in flat_loop_fn_dtypes],
      **extra_args)[1:]

  # TODO(rachelim): enable this for sparse tensors

  output = [None if is_none else ta.concat()
            for ta, is_none in zip(ta_list, is_none_list)]
  return nest.pack_sequence_as(loop_fn_dtypes, output)
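The final pack_sequence_as(loop_fn_dtypes, output) call only restores the caller's nesting over the stacked per-output tensors; a toy illustration of that last step with an invented dtype structure:

import tensorflow as tf

loop_fn_dtypes = {"logits": tf.float32, "state": (tf.float32, tf.int32)}
stacked_flat = [tf.zeros([5, 3]), tf.zeros([5, 2]), tf.zeros([5], tf.int32)]

outputs = tf.nest.pack_sequence_as(loop_fn_dtypes, stacked_flat)
# outputs keeps the dict/tuple layout of loop_fn_dtypes, now holding tensors.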
 def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
   all_ops = ops.get_collection("iterator_ops")
   if sparse_tensors:
     init_op, indices, values, dense_shape = all_ops
     return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
   else:
     return all_ops[0], nest.pack_sequence_as(
         self._get_output_types(ds_fn), all_ops[1:])
 def testNestPackSequenceAs(self,
                            structure,
                            sequence,
                            expected,
                            expand_composites=True):
   result = nest.pack_sequence_as(
       structure, sequence, expand_composites=expand_composites)
   self.assertEqual(result, expected)
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = values.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_inputs = iterator.get_next()
      if not isinstance(fn_inputs, tuple):
        fn_inputs = (fn_inputs,)
      fn_result = fn(ctx, *fn_inputs)
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self.unwrap(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for (name, aggregation) in ctx._last_step_outputs_aggregations.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been aggregated, wrap them in a Mirrored
      # container, else in a PerDevice container.
      if aggregation is variables_lib.VariableAggregation.NONE:
        last_step_tensor_outputs_dict[name] = values.regroup(
            {d: t for d, t in zip(self._devices, output)}, values.PerDevice)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx
Example #30
 def _get_grads_lists_empirical(self, tensors):
   # Passing in a list of loss values is better than passing in the sum as
   # the latter creates unnecessary ops on the default device
   grads_flat = gradients_impl.gradients(
       self._layers.eval_losses(),
       nest.flatten(tensors),
       colocate_gradients_with_ops=self._colocate_gradients_with_ops)
   grads_all = nest.pack_sequence_as(tensors, grads_flat)
   return tuple((grad,) for grad in grads_all)
Example #31
 def input_pack(x):
     return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
Example #32
def static_rnn(cell, inputs, initial_state=None, dtype=None,
               sequence_length=None, scope=None, fg=None):
  """Creates a recurrent neural network specified by RNNCell `cell`.
  The simplest form of RNN network generated is:
  ```python
    state = cell.zero_state(...)
    outputs = []
    for input_ in inputs:
      output, state = cell(input_, state)
      outputs.append(output)
    return (outputs, state)
  ```
  However, a few other options are available:
  An initial state can be provided.
  If the sequence_length vector is provided, dynamic calculation is performed.
  This method of calculation does not compute the RNN steps past the maximum
  sequence length of the minibatch (thus saving computational time),
  and properly propagates the state at an example's sequence length
  to the final state output.
  The dynamic calculation performed is, at time `t` for batch row `b`,
  ```python
    (output, state)(b, t) =
      (t >= sequence_length(b))
        ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
        : cell(input(b, t), state(b, t - 1))
  ```
  Args:
    cell: An instance of RNNCell.
    inputs: A length T list of inputs, each a `Tensor` of shape
      `[batch_size, input_size]`, or a nested tuple of such elements.
    initial_state: (optional) An initial state for the RNN.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    sequence_length: Specifies the length of each sequence in inputs.
      An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
    scope: VariableScope for the created subgraph; defaults to "rnn".
  Returns:
    A pair (outputs, state) where:
    - outputs is a length T list of outputs (one for each input), or a nested
      tuple of such elements.
    - state is the final state
  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If `inputs` is `None` or an empty list, or if the input depth
      (column size) cannot be inferred from inputs via shape inference.
  """

  # if not isinstance(cell, core_rnn_cell.RNNCell):
  #   raise TypeError("cell must be an instance of RNNCell")
  # if not nest.is_sequence(inputs):
  #   raise TypeError("inputs must be a sequence")
  # if not inputs:
  #   raise ValueError("inputs must not be empty")

  outputs = []
  # Create a new scope in which the caching device is either
  # determined by the parent scope, or is set to place the cached
  # Variable using the same placement as for the rest of the RNN.
  with vs.variable_scope(scope or "rnn") as varscope:
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)

    # Obtain the first sequence of the input
    first_input = inputs
    while nest.is_sequence(first_input):
      first_input = first_input[0]

    # Temporarily avoid EmbeddingWrapper and seq2seq badness
    # TODO(lukaszkaiser): remove EmbeddingWrapper
    if first_input.get_shape().ndims != 1:

      input_shape = first_input.get_shape().with_rank_at_least(2)
      fixed_batch_size = input_shape[0]

      flat_inputs = nest.flatten(inputs)
      for flat_input in flat_inputs:
        input_shape = flat_input.get_shape().with_rank_at_least(2)
        batch_size, input_size = input_shape[0], input_shape[1:]
        fixed_batch_size.merge_with(batch_size)
        for i, size in enumerate(input_size):
          if size.value is None:
            raise ValueError(
                "Input size (dimension %d of inputs) must be accessible via "
                "shape inference, but saw value None." % i)
    else:
      fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0]

    if fixed_batch_size.value:
      batch_size = fixed_batch_size.value
    else:
      batch_size = array_ops.shape(first_input)[0]
    if initial_state is not None:
      state = initial_state
    else:
      if not dtype:
        raise ValueError("If no initial_state is provided, "
                         "dtype must be specified")
      state = cell.zero_state(batch_size, dtype)

    if sequence_length is not None:  # Prepare variables
      sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if sequence_length.get_shape().ndims not in (None, 1):
        raise ValueError(
            "sequence_length must be a vector of length batch_size")
      def _create_zero_output(output_size):
        # convert int to TensorShape if necessary
        size = _state_size_with_prefix(output_size, prefix=[batch_size])
        output = array_ops.zeros(
            array_ops.stack(size), _infer_state_dtype(dtype, state))
        shape = _state_size_with_prefix(
            output_size, prefix=[fixed_batch_size.value])
        output.set_shape(tensor_shape.TensorShape(shape))
        return output

      output_size = cell.output_size
      flat_output_size = nest.flatten(output_size)
      flat_zero_output = tuple(
          _create_zero_output(size) for size in flat_output_size)
      zero_output = nest.pack_sequence_as(structure=output_size,
                                          flat_sequence=flat_zero_output)

      sequence_length = math_ops.to_int32(sequence_length)
      min_sequence_length = math_ops.reduce_min(sequence_length)
      max_sequence_length = math_ops.reduce_max(sequence_length)

    for time, input_ in enumerate(inputs):
      if time > 0: varscope.reuse_variables()
      # pylint: disable=cell-var-from-loop
      #if fg: call_cell = lambda: cell(input_, state)#.apply(fg, state)
      #else : call_cell = lambda: cell(input_, state)
      call_cell = lambda: cell(input_, state)
      # pylint: enable=cell-var-from-loop
      if sequence_length is not None:
        (output, state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            min_sequence_length=min_sequence_length,
            max_sequence_length=max_sequence_length,
            zero_output=zero_output,
            state=state,
            call_cell=call_cell,
            state_size=cell.state_size)
      else:
        (output, state) = call_cell()

      outputs.append(output)

    return (outputs, state)
Example #33
def raw_rnn(cell,
            loop_fn,
            parallel_iterations=None,
            swap_memory=False,
            scope=None):
    """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.

  **NOTE: This method is still in testing, and the API may change.**

  This function is a more primitive version of `dynamic_rnn` that provides
  more direct access to the inputs each iteration.  It also provides more
  control over when to start and finish reading the sequence, and
  what to emit for the output.

  For example, it can be used to implement the dynamic decoder of a seq2seq
  model.

  Instead of working with `Tensor` objects, most operations work with
  `TensorArray` objects directly.

  The operation of `raw_rnn`, in pseudo-code, is basically the following:

  ```python
  time = tf.constant(0, dtype=tf.int32)
  (finished, next_input, initial_state, _, loop_state) = loop_fn(
      time=time, cell_output=None, cell_state=None, loop_state=None)
  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
  state = initial_state
  while not all(finished):
    (output, cell_state) = cell(next_input, state)
    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
        time=time + 1, cell_output=output, cell_state=cell_state,
        loop_state=loop_state)
    # Emit zeros and copy forward state for minibatch entries that are finished.
    state = tf.where(finished, state, next_state)
    emit = tf.where(finished, tf.zeros_like(emit), emit)
    emit_ta = emit_ta.write(time, emit)
    # If any new minibatch entries are marked as finished, mark these.
    finished = tf.logical_or(finished, next_finished)
    time += 1
  return (emit_ta, state, loop_state)
  ```

  with the additional properties that output and state may be (possibly nested)
  tuples, as determined by `cell.output_size` and `cell.state_size`, and
  as a result the final `state` and `emit_ta` may themselves be tuples.

  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:

  ```python
  inputs = tf.placeholder(shape=(max_time, batch_size, input_depth),
                          dtype=tf.float32)
  sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
  inputs_ta = inputs_ta.unstack(inputs)

  cell = tf.contrib.rnn.LSTMCell(num_units)

  def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output  # == None for time == 0
    if cell_output is None:  # time == 0
      next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
      next_cell_state = cell_state
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
  outputs = outputs_ta.stack()
  ```

  Args:
    cell: An instance of RNNCell.
    loop_fn: A callable that takes inputs
      `(time, cell_output, cell_state, loop_state)`
      and returns the tuple
      `(finished, next_input, next_cell_state, emit_output, next_loop_state)`.
      Here `time` is an int32 scalar `Tensor`, `cell_output` is a
      `Tensor` or (possibly nested) tuple of tensors as determined by
      `cell.output_size`, and `cell_state` is a `Tensor`
      or (possibly nested) tuple of tensors, as determined by the `loop_fn`
      on its first call (and should match `cell.state_size`).
      The outputs are: `finished`, a boolean `Tensor` of
      shape `[batch_size]`, `next_input`: the next input to feed to `cell`,
      `next_cell_state`: the next state to feed to `cell`,
      and `emit_output`: the output to store for this iteration.

      Note that `emit_output` should be a `Tensor` or (possibly nested)
      tuple of tensors with shapes and structure matching `cell.output_size`
      and `cell_output` above.  The parameter `cell_state` and output
      `next_cell_state` may be either a single or (possibly nested) tuple
      of tensors.  The parameter `loop_state` and
      output `next_loop_state` may be either a single or (possibly nested) tuple
      of `Tensor` and `TensorArray` objects.  This last parameter
      may be ignored by `loop_fn` and the return value may be `None`.  If it
      is not `None`, then the `loop_state` will be propagated through the RNN
      loop, for use purely by `loop_fn` to keep track of its own state.
      The `next_loop_state` parameter returned may be `None`.

      The first call to `loop_fn` will be `time = 0`, `cell_output = None`,
      `cell_state = None`, and `loop_state = None`.  For this call:
      The `next_cell_state` value should be the value with which to initialize
      the cell's state.  It may be a final state from a previous RNN or it
      may be the output of `cell.zero_state()`.  It should be a
      (possibly nested) tuple structure of tensors.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a `TensorShape`, this must be a `Tensor` of
      appropriate type and shape `[batch_size] + cell.state_size`.
      If `cell.state_size` is a (possibly nested) tuple of ints or
      `TensorShape`, this will be a tuple having the corresponding shapes.
      The `emit_output` value may be  either `None` or a (possibly nested)
      tuple structure of tensors, e.g.,
      `(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`.
      If this first `emit_output` return value is `None`,
      then the `emit_ta` result of `raw_rnn` will have the same structure and
      dtypes as `cell.output_size`.  Otherwise `emit_ta` will have the same
      structure, shapes (prepended with a `batch_size` dimension), and dtypes
      as `emit_output`.  The actual values returned for `emit_output` at this
      initializing call are ignored.  Note, this emit structure must be
      consistent across all time steps.

    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency
      and can be run in parallel, will be.  This parameter trades off
      time for space.  Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A tuple `(emit_ta, final_state, final_loop_state)` where:

    `emit_ta`: The RNN output `TensorArray`.
       If `loop_fn` returns a (possibly nested) set of Tensors for
       `emit_output` during initialization, (inputs `time = 0`,
       `cell_output = None`, and `loop_state = None`), then `emit_ta` will
       have the same structure, dtypes, and shapes as `emit_output` instead.
       If `loop_fn` returns `emit_output = None` during this call,
       the structure of `cell.output_size` is used:
       If `cell.output_size` is a (possibly nested) tuple of integers
       or `TensorShape` objects, then `emit_ta` will be a tuple having the
       same structure as `cell.output_size`, containing TensorArrays whose
       elements' shapes correspond to the shape data in `cell.output_size`.

    `final_state`: The final cell state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes.

    `final_loop_state`: The final loop state as returned by `loop_fn`.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not
      a `callable`.
  """

    # pylint: disable=protected-access
    if not isinstance(cell, rnn_cell_impl._RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    # pylint: enable=protected-access
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input,
         initial_state, emit_structure, init_loop_state) = loop_fn(
             time, None, None,
             None)  # time, cell_output, cell_state, loop_state
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None else
                      constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [emit.get_shape() for emit in flat_emit_structure]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        flat_emit_ta = [
            tensor_array_ops.TensorArray(dtype=dtype_i,
                                         dynamic_size=True,
                                         size=0,
                                         name="rnn_output_%d" % i)
            for i, dtype_i in enumerate(flat_emit_dtypes)
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                        flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(
                _state_size_with_prefix(size_i, prefix=[batch_size]), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
        ]
        zero_emit = nest.pack_sequence_as(structure=emit_structure,
                                          flat_sequence=flat_zero_emit)

        def condition(unused_time, elements_finished, *_):
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, emit_ta, state,
                 loop_state):
            """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                # pylint: disable=g-long-lambda,cell-var-from-loop
                result_flat = [
                    _on_device(lambda: array_ops.where(elements_finished,
                                                       current_i, candidate_i),
                               device=candidate_i.op.device)
                    for (current_i,
                         candidate_i) in zip(current_flat, candidate_flat)
                ]
                # pylint: enable=g-long-lambda,cell-var-from-loop
                return nest.pack_sequence_as(structure=current,
                                             flat_sequence=result_flat)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            emit_ta_flat = [
                ta.write(time, emit)
                for (ta, emit) in zip(emit_ta_flat, emit_output_flat)
            ]

            emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                            flat_sequence=emit_ta_flat)

            return (next_time, elements_finished, next_input, emit_ta,
                    next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                time, elements_finished, next_input, emit_ta, state, loop_state
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        (emit_ta, final_state, final_loop_state) = returned[-3:]

        if init_loop_state is None:
            final_loop_state = None

        return (emit_ta, final_state, final_loop_state)
def nest_map(func, nested):
    if not nest.is_sequence(nested):
        return func(nested)
    flat = nest.flatten(nested)
    return nest.pack_sequence_as(nested, list(map(func, flat)))
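# A minimal sketch (not part of the example above) of the same flatten/pack
# round trip via the public tf.nest helpers, assuming TensorFlow 2.x;
# `square_all` is a hypothetical name used only for illustration.
import tensorflow as tf

def square_all(nested):
    # Mirrors nest_map(lambda x: x * x, nested) from the snippet above.
    if not tf.nest.is_nested(nested):
        return nested * nested
    flat = tf.nest.flatten(nested)
    return tf.nest.pack_sequence_as(nested, [x * x for x in flat])

# square_all({"a": 2, "b": (3, 4)}) == {"a": 4, "b": (9, 16)}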
Example #35
0
def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
    """Runs `loop_fn` `iters` times and stacks the outputs.


  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks corresponding outputs of the different runs.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of tensor
      objects. The shape of these outputs should not depend on the input.
    loop_fn_dtypes: dtypes for the outputs of `loop_fn`.
    iters: Number of iterations for which to run `loop_fn`.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked output tensor objects with the same
    nested structure as the output of `loop_fn`.
  """

    flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
    is_none_list = []

    def while_body(i, *ta_list):
        """Body of while loop."""
        fn_output = nest.flatten(loop_fn(i))
        if len(fn_output) != len(flat_loop_fn_dtypes):
            raise ValueError(
                "Number of expected outputs, %d, does not match the number of "
                "actual outputs, %d, from loop_fn" %
                (len(flat_loop_fn_dtypes), len(fn_output)))
        outputs = []
        del is_none_list[:]
        is_none_list.extend(x is None for x in fn_output)
        for out, ta in zip(fn_output, ta_list):
            # TODO(agarwal): support returning Operation objects from loop_fn.
            if out is not None:
                # out may be a ref tensor, wrap it in identity to get a non-ref tensor.
                ta = ta.write(i, array_ops.expand_dims(out, 0))
            outputs.append(ta)
        return tuple([i + 1] + outputs)

    if parallel_iterations is not None:
        extra_args = {"parallel_iterations": parallel_iterations}
    else:
        extra_args = {}
    ta_list = control_flow_ops.while_loop(
        lambda i, *ta: i < iters, while_body, [0] + [
            tensor_array_ops.TensorArray(dtype.base_dtype, iters)
            for dtype in flat_loop_fn_dtypes
        ], **extra_args)[1:]

    # TODO(rachelim): enable this for sparse tensors

    output = [
        None if is_none else ta.concat()
        for ta, is_none in zip(ta_list, is_none_list)
    ]
    assert len(output) in (0, len(flat_loop_fn_dtypes))
    if not output:
        # This may happen for the case where iters == 0.
        return None
    else:
        return nest.pack_sequence_as(loop_fn_dtypes, output)
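# A hedged sketch of the for_loop contract using only public TF 2.x APIs:
# call loop_fn for each i eagerly and stack the outputs leaf by leaf.
# `naive_for_loop` is a hypothetical helper, not part of the example above.
import tensorflow as tf

def naive_for_loop(loop_fn, iters):
    per_step = [loop_fn(tf.constant(i, dtype=tf.int32)) for i in range(iters)]
    # Stack corresponding leaves so the nested output structure is preserved.
    return tf.nest.map_structure(lambda *xs: tf.stack(xs), *per_step)

# naive_for_loop(lambda i: {"i": i, "sq": i * i}, 3)
# -> {"i": [0, 1, 2], "sq": [0, 1, 4]}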
Example #36
0
    def lookup(self, keys):
        return nest.pack_sequence_as(self._hash_tables, [
            hash_table.lookup(keys)
            for hash_table in nest.flatten(self._hash_tables)
        ])
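# A hedged sketch of the same pattern with public APIs: map a lookup over a
# nested structure of tf.lookup tables. The table construction below is
# illustrative, not taken from the example above; assumes TensorFlow 2.x.
import tensorflow as tf

def make_table():
    init = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(["a", "b"]), values=tf.constant([1, 2], tf.int64))
    return tf.lookup.StaticHashTable(init, default_value=-1)

tables = {"first": make_table(), "second": make_table()}

def lookup_all(tables, query):
    # Equivalent to flatten + per-table lookup + pack_sequence_as.
    return tf.nest.map_structure(lambda t: t.lookup(query), tables)

# lookup_all(tables, tf.constant(["a", "c"])) == {"first": [1, -1], "second": [1, -1]}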
Example #37
0
def _compile_internal(computation, inputs=None):
  """Builds graph operators that compiles and symbolically executes computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that are convertible to
      tensors. Note that passing an N-dimensional list of compatible values will
      result in an N-dimensional list of scalar tensors rather than a single
      rank-N tensor. If you need different behavior, convert parts of `inputs`
      to tensors with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly, with some
    exceptions for correctness. Exceptions include: 1) None output, 2) single
    value output, and 3) operation-only outputs.

  Raises:
    ValueError: If any element in computation outputs is neither an operation
      nor a value that can be converted to a tensor.
    ValueError: If computation outputs is non-flat and contains any Operations.
    TypeError: If `inputs` is not a list or tuple.
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections.Sequence):
    raise TypeError('inputs must be a list or tuple')

  # Flatten inputs.
  flat_inputs = nest.flatten(inputs)
  # Converts inputs to Tensors.
  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)

    with _disable_summary_context():
      outputs = computation(*computation_inputs)

    # Restore variable scope after computation.
    vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()

  # When XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise final
  # outputs would be empty and there is no way to trigger returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wraps the outputs in identity operators that carries control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned non-flat output structure, pack output tensors
  # back into same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors
Example #38
0
def _graph_mode_decorator(f, args, kwargs):
    """Implement custom gradient decorator for graph mode."""
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()
    args = nest.map_structure(ops.convert_to_tensor, args)

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    with tape_lib.VariableWatcher() as variable_watcher:
        result, grad_fn = f(*args)
    args = nest.flatten(args)
    after_vars = set([
        v.ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    ])
    new_vars = after_vars - before_vars
    new_vars_list = [v.deref() for v in new_vars]
    for v in new_vars_list:
        if not resource_variable_ops.is_resource_variable(v):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")
    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    inputs = args
    variables_in_tape = frozenset([
        v.ref() for v in variable_watcher.watched_variables()
    ]) - frozenset(v.ref() for v in inputs)
    variables_in_subgraph = frozenset([
        v.ref()
        for v in get_dependent_variables(input_ops=inputs, output_ops=result)
    ])
    variables = list(
        [v.deref() for v in variables_in_subgraph.union(variables_in_tape)])

    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or grad_argspec.varkw)
    if variables and not variables_in_signature:
        raise TypeError("If using @custom_gradient with a function that "
                        "uses variables, then grad_fn must accept a keyword "
                        "argument 'variables'.")
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        logging.warn(
            "@custom_gradient grad_fn has 'variables' in signature, but "
            "no ResourceVariables were used on the forward pass.")
    flat_result = nest.flatten(result)
    flat_result_len = len(flat_result)

    all_tensors = flat_result + args + variables

    def tape_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        result_grads = result_grads[:flat_result_len]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * flat_result_len) + input_grads + variable_grads

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        return tape_grad_fn(*result_grads)

    original_tensors = all_tensors
    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)

    original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

    # Propagate handle data for happier shape inference for resource variables.
    for i, t in enumerate(original_tensors):
        if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
            all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
    tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                              tape_grad_fn)
    for ot, t in zip(original_tensors, all_tensors):
        copy_handle_data(ot, t)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:flat_result_len])
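# A minimal, hedged illustration of the contract enforced above: when the
# forward pass reads a ResourceVariable, grad_fn must accept a `variables`
# keyword and return (input_grads, variable_grads). Only the public
# tf.custom_gradient API is used; the names below are illustrative.
import tensorflow as tf

w = tf.Variable(2.0)

@tf.custom_gradient
def scale(x):
    y = w * x

    def grad(dy, variables=None):
        dx = dy * w          # gradient for the input x
        dvars = [dy * x]     # one gradient per captured variable (here just w)
        return dx, dvars

    return y, grad

with tf.GradientTape() as tape:
    out = scale(tf.constant(3.0))
# tape.gradient(out, w) == 3.0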
Example #39
0
def slicing_where(condition, full_input, true_branch, false_branch):
    """Split 'full_input' between 'true_branch' and 'false_branch' on 'condition'.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between 'true_branch' and
      'false_branch' based on 'condition'.
    true_branch: A function taking a single argument, that argument having the
      same structure and number of batch dimensions as 'full_input'. Receives
      slices of 'full_input' corresponding to the True entries of
      'condition'. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like 'true_branch', but receives inputs corresponding to the
      false elements of 'condition'. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of 'true_branch'), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from 'true_branch' and 'false_branch', each Tensor
    having shape [B_1, ..., B_N, ...].
  """
    full_input_flat = nest.flatten(full_input)
    true_indices = tf.where(condition)
    false_indices = tf.where(tf.logical_not(condition))
    true_branch_inputs = nest.pack_sequence_as(
        structure=full_input,
        flat_sequence=[
            tf.gather_nd(params=input_tensor, indices=true_indices)
            for input_tensor in full_input_flat
        ])
    false_branch_inputs = nest.pack_sequence_as(
        structure=full_input,
        flat_sequence=[
            tf.gather_nd(params=input_tensor, indices=false_indices)
            for input_tensor in full_input_flat
        ])
    true_outputs = true_branch(true_branch_inputs)
    false_outputs = false_branch(false_branch_inputs)
    nest.assert_same_structure(true_outputs, false_outputs)

    def scatter_outputs(true_output, false_output):
        batch_shape = tf.shape(condition)
        scattered_shape = tf.concat(
            [batch_shape,
             tf.shape(true_output)[tf.rank(batch_shape):]], 0)
        true_scatter = tf.scatter_nd(indices=tf.cast(true_indices, tf.int32),
                                     updates=true_output,
                                     shape=scattered_shape)
        false_scatter = tf.scatter_nd(indices=tf.cast(false_indices, tf.int32),
                                      updates=false_output,
                                      shape=scattered_shape)
        return true_scatter + false_scatter

    result = nest.pack_sequence_as(
        structure=true_outputs,
        flat_sequence=[
            scatter_outputs(true_single_output, false_single_output)
            for true_single_output, false_single_output in zip(
                nest.flatten(true_outputs), nest.flatten(false_outputs))
        ])
    return result
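# A hedged usage sketch for slicing_where (values are illustrative; assumes
# `tf` is imported as in the function above): rows where `condition` is True
# go through true_branch, the rest through false_branch, and the results are
# scattered back into their original positions.
x = tf.constant([1.0, -2.0, 3.0, -4.0])
y = slicing_where(
    condition=x > 0,
    full_input=x,
    true_branch=lambda v: v * 10.0,   # sees [1.0, 3.0]
    false_branch=lambda v: -v)        # sees [-2.0, -4.0]
# y == [10.0, 2.0, 30.0, 4.0]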
Example #40
0
def _rnn_step(time,
              sequence_length,
              min_sequence_length,
              max_sequence_length,
              zero_output,
              state,
              call_cell,
              state_size,
              skip_conditionals=False):
    """Calculate one step of a dynamic RNN minibatch.

  Returns an (output, state) pair conditioned on the sequence_lengths.
  When skip_conditionals=False, the pseudocode is something like:

  if t >= max_sequence_length:
    return (zero_output, state)
  if t < min_sequence_length:
    return call_cell()

  # Selectively output zeros or output, old state or new state depending
  # on if we've finished calculating each row.
  new_output, new_state = call_cell()
  final_output = np.vstack([
    zero_output if time >= sequence_lengths[r] else new_output_r
    for r, new_output_r in enumerate(new_output)
  ])
  final_state = np.vstack([
    state[r] if time >= sequence_lengths[r] else new_state_r
    for r, new_state_r in enumerate(new_state)
  ])
  return (final_output, final_state)

  Args:
    time: Python int, the current time step
    sequence_length: int32 `Tensor` vector of size [batch_size]
    min_sequence_length: int32 `Tensor` scalar, min of sequence_length
    max_sequence_length: int32 `Tensor` scalar, max of sequence_length
    zero_output: `Tensor` vector of shape [output_size]
    state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
      or a list/tuple of such tensors.
    call_cell: lambda returning tuple of (new_output, new_state) where
      new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
      new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
    state_size: The `cell.state_size` associated with the state.
    skip_conditionals: Python bool, whether to skip using the conditional
      calculations.  This is useful for `dynamic_rnn`, where the input tensor
      matches `max_sequence_length`, and using conditionals just slows
      everything down.

  Returns:
    A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
      final_output is a `Tensor` matrix of shape [batch_size, output_size]
      final_state is either a single `Tensor` matrix, or a tuple of such
        matrices (matching length and shapes of input `state`).

  Raises:
    ValueError: If the cell returns a state tuple whose length does not match
      that returned by `state_size`.
  """

    # Convert state to a list for ease of use
    flat_state = nest.flatten(state)
    flat_zero_output = nest.flatten(zero_output)

    def _copy_one_through(output, new_output):
        copy_cond = (time >= sequence_length)
        return _on_device(
            lambda: array_ops.where(copy_cond, output, new_output),
            device=new_output.op.device)

    def _copy_some_through(flat_new_output, flat_new_state):
        # Use broadcasting select to determine which values should get
        # the previous state & zero output, and which values should get
        # a calculated state & output.
        flat_new_output = [
            _copy_one_through(zero_output,
                              new_output) for zero_output, new_output in zip(
                                  flat_zero_output, flat_new_output)
        ]
        flat_new_state = [
            _copy_one_through(state, new_state)
            for state, new_state in zip(flat_state, flat_new_state)
        ]
        return flat_new_output + flat_new_state

    def _maybe_copy_some_through():
        """Run RNN step.  Pass through either no or some past state."""
        new_output, new_state = call_cell()

        nest.assert_same_structure(state, new_state)

        flat_new_state = nest.flatten(new_state)
        flat_new_output = nest.flatten(new_output)
        return control_flow_ops.cond(
            # if t < min_seq_len: calculate and return everything
            time < min_sequence_length,
            lambda: flat_new_output + flat_new_state,
            # else copy some of it through
            lambda: _copy_some_through(flat_new_output, flat_new_state))

    # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
    # but benefits from removing cond() and its gradient.  We should
    # profile with and without this switch here.
    if skip_conditionals:
        # Instead of using conditionals, perform the selective copy at all time
        # steps.  This is faster when max_seq_len is equal to the number of unrolls
        # (which is typical for dynamic_rnn).
        new_output, new_state = call_cell()
        nest.assert_same_structure(state, new_state)
        new_state = nest.flatten(new_state)
        new_output = nest.flatten(new_output)
        final_output_and_state = _copy_some_through(new_output, new_state)
    else:
        empty_update = lambda: flat_zero_output + flat_state
        final_output_and_state = control_flow_ops.cond(
            # if t >= max_seq_len: copy all state through, output zeros
            time >= max_sequence_length,
            empty_update,
            # otherwise calculation is required: copy some or all of it through
            _maybe_copy_some_through)

    if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
        raise ValueError(
            "Internal error: state and output were not concatenated "
            "correctly.")
    final_output = final_output_and_state[:len(flat_zero_output)]
    final_state = final_output_and_state[len(flat_zero_output):]

    for output, flat_output in zip(final_output, flat_zero_output):
        output.set_shape(flat_output.get_shape())
    for substate, flat_substate in zip(final_state, flat_state):
        substate.set_shape(flat_substate.get_shape())

    final_output = nest.pack_sequence_as(structure=zero_output,
                                         flat_sequence=final_output)
    final_state = nest.pack_sequence_as(structure=state,
                                        flat_sequence=final_state)

    return final_output, final_state
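# A small, hedged sketch of the copy-through selection used above: once a
# batch row is past its sequence length, keep the previous value instead of
# the freshly computed one. Plain tf.where on toy values, TensorFlow 2.x assumed.
import tensorflow as tf

time = tf.constant(2)
sequence_length = tf.constant([3, 2, 1])       # per-row lengths
previous = tf.zeros([3, 1])
computed = tf.constant([[1.0], [2.0], [3.0]])

finished = time >= sequence_length             # [False, True, True]
stepped = tf.where(finished[:, None], previous, computed)
# stepped == [[1.0], [0.0], [0.0]]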
Example #41
0
    def dequeue_fn():
        dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
        return nest.pack_sequence_as(iterator.output_shapes, dequeued)
Example #42
0
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
  """Like tf.while_loop, except emits a single While op."""
  flattened_loop_vars = nest.flatten(loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(loop_vars, shape_invariants)
    flattened_shapes = nest.flatten(shape_invariants)
  else:
    flattened_shapes = [t.shape for t in flattened_loop_vars]

  del shape_invariants

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
      body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))

    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                          ] + flattened_loop_vars

    flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    cond_graph = function.func_graph_from_py_func(
        cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
    flattened_shapes = flattened_shapes + [
        t.shape for t in cond_graph.external_captures
    ]

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list of tensors the same length as args.
      """
      outputs = body(*args[:num_outputs])
      if not isinstance(outputs, collections.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    body_graph = function.func_graph_from_py_func(
        body_name, wrapped_body, flattened_loop_vars, {}, signature=signature)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list,
            intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                         flattened_shapes[1:1 + num_outputs],
                         flattened_loop_vars[1:1 + num_outputs])
    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        cond_v2._create_new_tf_function(cond_graph),
        cond_v2._create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
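# A hedged sketch of the same calling convention through the public
# tf.while_loop: nested loop_vars are flattened and repacked internally, so
# cond and body see the original structure. Values are illustrative.
import tensorflow as tf

result = tf.while_loop(
    cond=lambda i, acc: i < 5,
    body=lambda i, acc: (i + 1, {"total": acc["total"] + tf.cast(i, tf.float32)}),
    loop_vars=(tf.constant(0), {"total": tf.constant(0.0)}))
# result == (5, {"total": 10.0})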
Example #43
0
    def pack_args(self, args):
        """Packs the given list of arguments into the internal args structure."""

        return nest.pack_sequence_as(self._args, args)
Example #44
0
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending the exitence of `handle`."""
    def _feed_fn(feed, feed_val):
      for tensor_type, _, feed_fn, _ in BaseSession._REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r'
                      % (feed, type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Flatten/unflatten fetched values.
    if isinstance(fetches, (list, tuple)):
      # fetches is already a list or tuple; flatten it and repack on return.
      orig_fetches, fetches = fetches, nest.flatten(fetches)
      unflatten = lambda fetched: nest.pack_sequence_as(orig_fetches, fetched)
    elif isinstance(fetches, dict):
      # fetches is a dictionary; flatten the values and map fetched
      # values back into a dictionary.
      # nest.flatten does not accept iterators, next line is for python3
      # compatibility.
      fetches_values = list(fetches.values())
      orig_fetches, fetches = fetches, nest.flatten(fetches_values)
      unflatten = lambda fetched: _unflatten_fetches(orig_fetches, fetched)
    else:
      # fetches is a singleton.
      fetches = [fetches]
      unflatten = lambda fetched: fetched[0]

    # Validate and process fetches.
    processed_fetches = self._process_fetches(fetches)
    unique_fetches = processed_fetches[0]
    target_list = processed_fetches[1]
    fetch_info = processed_fetches[2]
    unique_handles = processed_fetches[3]

    # Create request.
    feed_dict_string = {}
    feed_map = {}

    # Validate and process feed_dict.
    if feed_dict:
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
                                                    allow_operation=False)
          except Exception as e:
            raise TypeError('Cannot interpret feed_dict key as Tensor: '
                            + e.args[0])

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, or numpy ndarrays.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          if isinstance(subfeed_val,
                        int) and subfeed_dtype(subfeed_val) != subfeed_val:
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' is not'
                ' compatible with Tensor type ' + str(subfeed_dtype) + '.'
                ' Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          np_val = np.array(subfeed_val, dtype=subfeed_dtype)

          if not subfeed_t.get_shape().is_compatible_with(np_val.shape):
            raise ValueError(
                'Cannot feed value of shape %r for Tensor %r, '
                'which has shape %r'
                % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)
          subfeed_name = compat.as_bytes(subfeed_t.name)
          feed_dict_string[subfeed_name] = np_val
          feed_map[subfeed_name] = (subfeed_t, subfeed_val)

    # Run request and get response.
    # We need to keep the movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    movers = self._update_with_movers(feed_dict_string, feed_map)
    results = self._do_run(handle, target_list, unique_fetches,
                           feed_dict_string, options, run_metadata)

    # User may have fetched the same tensor multiple times, but we
    # only fetch them from the runtime once.  Furthermore, they may
    # be wrapped as a tuple of tensors.  Here we map the results back
    # to what the client asked for.
    # TODO(yuanbyu): Use the contraction_fn in _REGISTERED_EXPANSIONS.
    fetched_results = {}
    for fetch, result in zip(unique_fetches, results):
      dtype = unique_handles.get(fetch)
      if dtype:
        result = session_ops.TensorHandle(result, dtype, self)
      fetched_results[fetch] = result
    ret = []
    for fetch_names, fetch_contraction_fn in fetch_info:
      if fetch_names:
        fetched_vals = [fetched_results[name] for name in fetch_names]
        ret.append(fetch_contraction_fn(fetched_vals))
      else:
        ret.append(None)

    return unflatten(ret)
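# A hedged TF1-style illustration of the flatten/unflatten behaviour above:
# Session.run accepts nested fetches and returns results with the same
# structure. Assumes tf.compat.v1 graph mode; values are illustrative.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
a = tf.constant(1.0)
b = tf.constant(2.0)
with tf.compat.v1.Session() as sess:
    out = sess.run({"sum": a + b, "pair": (a, b)})
# out == {"sum": 3.0, "pair": (1.0, 2.0)}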
Example #45
0
def _build_signature(loop_vars, shape_invariants):
    return nest.pack_sequence_as(loop_vars, [
        tensor_spec.TensorSpec(s, t.dtype, name=t.op.name) for s, t in zip(
            nest.flatten(shape_invariants), nest.flatten(loop_vars))
    ])
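# A hedged sketch of the same idea with public APIs: build one TensorSpec per
# leaf and keep the structure of the loop variables. Values are illustrative.
import tensorflow as tf

loop_vars = {"step": tf.constant(0), "state": tf.zeros([2, 3])}
signature = tf.nest.map_structure(
    lambda t: tf.TensorSpec(t.shape, t.dtype), loop_vars)
# signature: {"step": TensorSpec(shape=(), dtype=tf.int32),
#             "state": TensorSpec(shape=(2, 3), dtype=tf.float32)}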
Example #46
0
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=True,
                scope=None):
    """Creates a recurrent neural network specified by RNNCell `cell`.

  Performs fully dynamic unrolling of `inputs`.

  Example:

  ```python
  # create a BasicRNNCell
  rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

  # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

  # defining initial state
  initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

  # 'state' is a tensor of shape [batch_size, cell_state_size]
  outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                     initial_state=initial_state,
                                     dtype=tf.float32)
  ```

  ```python
  # create 2 LSTMCells
  rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

  # create a RNN cell composed sequentially of a number of RNNCells
  multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

  # 'outputs' is a tensor of shape [batch_size, max_time, 256]
  # 'state' is a N-tuple where N is the number of LSTMCells containing a
  # tf.contrib.rnn.LSTMStateTuple for each cell
  outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                     inputs=data,
                                     dtype=tf.float32)
  ```


  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.
      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such elements.
      If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
        batch_size, ...]`, or a nested tuple of such elements. This may also be
        a (possibly nested) tuple of Tensors satisfying this property.  The
        first two dimensions must match across all the inputs, but otherwise the
        ranks and other shape components may differ. In this case, input to
        `cell` at each time-step will replicate the structure of these tuples,
        except for the time dimension (from which the time is taken). The input
        to `cell` at each time step will be a `Tensor` or (possibly nested)
        tuple of Tensors each with dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
      to copy-through state and zero-out outputs when past a batch element's
      sequence length.  So it's more for performance than correctness.
    initial_state: (optional) An initial state for the RNN. If `cell.state_size`
      is an integer, this must be a `Tensor` of appropriate type and shape
      `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
      should be a tuple of tensors having shapes `[batch_size, s] for s in
      cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency and
      can be run in parallel, will be.  This parameter trades off time for
      space.  Values >> 1 use more memory but take less time, while smaller
      values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs which
      would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
      these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
      these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
      `time_major = True` is a bit more efficient because it avoids transposes
      at the beginning and end of the RNN calculation.  However, most TensorFlow
      data is batch-major, so by default this function accepts input and emits
      output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

    outputs: The RNN output `Tensor`.

      If time_major == False (default), this will be a `Tensor` shaped:
        `[batch_size, max_time, cell.output_size]`.

      If time_major == True, this will be a `Tensor` shaped:
        `[max_time, batch_size, cell.output_size]`.

      Note, if `cell.output_size` is a (possibly nested) tuple of integers
      or `TensorShape` objects, then `outputs` will be a tuple having the
      same structure as `cell.output_size`, containing Tensors having shapes
      corresponding to the shape data in `cell.output_size`.

    state: The final state.  If `cell.state_size` is an int, this
      will be shaped `[batch_size, cell.state_size]`.  If it is a
      `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
      If it is a (possibly nested) tuple of ints or `TensorShape`, this will
      be a tuple having the corresponding shapes. If cells are `LSTMCells`
      `state` will be a tuple containing a `LSTMStateTuple` for each cell.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
    RuntimeError: If not using control flow v2.
  """

    # Currently, only the time_major == True case is supported.
    assert time_major

    # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
    # TfLiteRNNCells.
    rnn_cell_impl.assert_like_rnncell("cell", cell)

    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")

    parent_first_child_input = [{
        "parent_ophint_input_index": 0,
        "first_child_ophint_input_index": 0
    }]
    parent_last_child_output = [{
        "parent_output_index": 0,
        # For LstmCell, the index is 2.
        # For RnnCell, the index is 1.
        # So we use -1 meaning it's the last one.
        "child_output_index": -1
    }]
    internal_children_input_output = [{
        "child_input_index": 0,
        # For LstmCell, the index is 2.
        # For RnnCell, the index is 1.
        # So we use -1 meaning it's the last one.
        "child_output_index": -1
    }]
    inputs_outputs_mappings = {
        "parent_first_child_input": parent_first_child_input,
        "parent_last_child_output": parent_last_child_output,
        "internal_children_input_output": internal_children_input_output
    }
    tflite_wrapper = tf.lite.OpHint(
        "TfLiteDynamicRnn",
        level=2,
        children_inputs_mappings=inputs_outputs_mappings)
    with vs.variable_scope(scope or "rnn") as varscope:
        # Create a new scope in which the caching device is either
        # determined by the parent scope, or is set to place the cached
        # Variable using the same placement as for the rest of the RNN.
        if _should_cache():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        inputs = tflite_wrapper.add_input(inputs,
                                          name="input",
                                          index_override=0)

        # By default, time_major==False and inputs are batch-major: shaped
        #   [batch, time, depth]
        # For internal calculations, we transpose to [time, batch, depth]
        flat_input = nest.flatten(inputs)

        if not time_major:
            # (batch, time, depth) => (time, batch, depth)
            flat_input = [
                ops.convert_to_tensor(input_) for input_ in flat_input
            ]
            flat_input = tuple(
                _transpose_batch_time(input_) for input_ in flat_input)

        parallel_iterations = parallel_iterations or 32
        if sequence_length is not None:
            sequence_length = math_ops.to_int32(sequence_length)
            if sequence_length.get_shape().rank not in (None, 1):
                raise ValueError(
                    "sequence_length must be a vector of length batch_size, "
                    "but saw shape: %s" % sequence_length.get_shape())
            sequence_length = array_ops.identity(  # Just to find it in the graph.
                sequence_length,
                name="sequence_length")

        batch_size = _best_effort_input_batch_size(flat_input)

        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError(
                    "If there is no initial_state, you must give a dtype.")
            if getattr(cell, "get_initial_state", None) is not None:
                state = cell.get_initial_state(inputs=None,
                                               batch_size=batch_size,
                                               dtype=dtype)
            else:
                state = cell.zero_state(batch_size, dtype)

        def _assert_has_shape(x, shape):
            x_shape = array_ops.shape(x)
            packed_shape = array_ops.stack(shape)
            return control_flow_ops.Assert(
                math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
                    "Expected shape for Tensor %s is " % x.name, packed_shape,
                    " but saw shape: ", x_shape
                ])

        if not context.executing_eagerly() and sequence_length is not None:
            # Perform some shape validation
            with ops.control_dependencies(
                [_assert_has_shape(sequence_length, [batch_size])]):
                sequence_length = array_ops.identity(sequence_length,
                                                     name="CheckSeqLen")

        inputs = nest.pack_sequence_as(structure=inputs,
                                       flat_sequence=flat_input)

        outputs, final_state = _dynamic_rnn_loop(
            cell,
            inputs,
            state,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
            sequence_length=sequence_length,
            dtype=dtype)

        # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
        # If we are performing batch-major calculations, transpose output back
        # to shape [batch, time, depth]
        if not time_major:
            # (time, batch, depth) => (batch, time, depth)
            outputs = nest.map_structure(_transpose_batch_time, outputs)
        outputs = tflite_wrapper.add_output(outputs, name="outputs")

        return outputs, final_state
Example #47
0
    def output_pack(x):
        return (nest.pack_sequence_as(dtype, x)
                if output_is_sequence else x[0])
Example #48
0
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            multi_worker_iterator,
                                            iterations,
                                            initial_loop_values=None):
        output_shapes = multi_worker_iterator.output_shapes
        shapes = nest.flatten(output_shapes)
        if any(not s.is_fully_defined() for s in shapes):
            raise ValueError(
                "TPU currently requires fully defined shapes. Either use "
                "set_shape() on the input tensors or use "
                "dataset.batch(..., drop_remainder=True).")
        types = nest.flatten(multi_worker_iterator.output_types)

        enqueue_ops = [
            self._get_enqueue_op_per_host(host_id, multi_worker_iterator,
                                          shapes, iterations)
            for host_id in range(self.num_hosts)
        ]

        def dequeue_fn():
            dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types,
                                                    shapes=shapes)
            return nest.pack_sequence_as(output_shapes, dequeued)

        # Wrap `fn` for repeat.
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)
        ctx = input_lib.MultiStepContext()

        def run_fn(*args, **kwargs):
            """Single step on the TPU device."""
            del args, kwargs
            fn_result = fn(ctx, dequeue_fn())
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            if flat_last_step_outputs:
                with ops.control_dependencies([fn_result]):
                    return [
                        array_ops.identity(f) for f in flat_last_step_outputs
                    ]
            else:
                return fn_result

        def iterate_on_tpu():
            return training_loop.repeat(iterations, run_fn,
                                        initial_loop_values)

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop and TPU replicate context. This is useful in cases
        # where we might need to exit these contexts and get back to the outer
        # context to do some things, e.g. to create an op which should be
        # evaluated only once at the end of the loop on the host. One such usage
        # is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        # pylint: disable=protected-access
        if self._container_strategy()._disable_training_loop_on_host:
            replicate_inputs = [[]] * self._num_replicas_in_sync
            replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
        else:

            def rewrite_fn(*args):
                """The rewritten step fn running on TPU."""
                del args
                replicate_inputs = [[]] * self._num_replicas_in_sync
                replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

                # If run_fn has tensor outputs, tpu.replicate returns a list of
                # lists, which we flatten in this case. If run_fn has no tensor
                # outputs, tpu.replicate returns a list of no_ops, and we keep
                # the output as it is.
                if isinstance(replicate_outputs[0], list):
                    replicate_outputs = nest.flatten(replicate_outputs)

                return replicate_outputs

            # TODO(sourabhbajaj): The input to while loop should be based on the
            # output type of the step_fn
            assert isinstance(initial_loop_values, list)
            initial_loop_values = initial_loop_values * self._num_replicas_in_sync

            # Put the while loop op on host 0.
            with ops.device(self.get_host_cpu_device(0)):
                replicate_outputs = training_loop.repeat(
                    iterations, rewrite_fn, initial_loop_values)

        del self._outer_control_flow_context
        ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

        if self._container_strategy()._disable_training_loop_on_host:
            # Filter out any ops from the outputs, typically this would be the case
            # when there were no tensor outputs.
            last_step_tensor_outputs = [
                x for x in replicate_outputs
                if not isinstance(x, ops.Operation)
            ]

            # Outputs are currently of the structure (grouped by device)
            # [[output0_device0, output1_device0, output2_device0],
            #  [output0_device1, output1_device1, output2_device1]]
            # Convert this to the following structure instead: (grouped by output)
            # [[output0_device0, output0_device1],
            #  [output1_device0, output1_device1],
            #  [output2_device0, output2_device1]]
            last_step_tensor_outputs = [
                list(x) for x in zip(*last_step_tensor_outputs)
            ]
        else:
            if isinstance(replicate_outputs, list):
                # Filter out any ops from the outputs, typically this would be the case
                # when there were no tensor outputs.
                last_step_tensor_outputs = [
                    x for x in replicate_outputs
                    if not isinstance(x, ops.Operation)
                ]

                # Outputs are currently of the structure (flattened)
                # [output0_device0, output1_device0, output2_device0,
                #  output0_device1, output1_device1, output2_device1,
                #  ...]
                # Convert this to the following structure instead: (grouped by output)
                # [[output0_device0, output0_device1],
                #  [output1_device0, output1_device1],
                #  [output2_device0, output2_device1]]
                output_num = len(
                    last_step_tensor_outputs) // self._num_replicas_in_sync
                last_step_tensor_outputs = [
                    last_step_tensor_outputs[i::output_num]
                    for i in range(output_num)
                ]
            else:
                # no tensors returned.
                last_step_tensor_outputs = []

        # Convert replicate_outputs to the original dict structure of
        # last_step_outputs.
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, take the first value
            # from the list as each value should be the same. Else return the full
            # list of values.
            # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
            # value.
            if reduce_op is not None:
                # TODO(priyag): Should this return the element or a list with 1 element
                last_step_tensor_outputs_dict[name] = output[0]
        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

        return ctx
Example #49
0
def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args)`; and locations where `args` are
      instances of `Tensor` must have a corresponding `TensorShape` in
      `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d).  args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):
    flattened = nest.flatten(arg_value)
    tensor_specs = [
        arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec)
    ]
    specified_names = [arg.name for arg in tensor_specs if arg.name]
    if specified_names and len(specified_names) < len(tensor_specs):
      raise ValueError("If specifying TensorSpec names for nested structures, "
                       "either zero or all names have to be specified.")

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, resource_variable_ops.ResourceVariable):
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs)
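The helper above follows the usual flatten -> transform-each-leaf -> pack_sequence_as round trip. A minimal sketch of that pattern against the public tf.nest API (the structure and the per-leaf transform below are illustrative, not the private helper itself):

import tensorflow as tf

# Illustrative structure with one tensor leaf and one non-tensor leaf.
args = {"x": tf.constant([[1.0, 2.0]]), "label": "not-a-tensor"}
flat = tf.nest.flatten(args)          # dict leaves come out in sorted-key order
transformed = [
    tf.identity(leaf) if tf.is_tensor(leaf) else leaf  # stand-in per-leaf work
    for leaf in flat
]
packed = tf.nest.pack_sequence_as(args, transformed)   # same dict structure back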
Example #50
0
 def output_pack(x):
     return (nest.pack_sequence_as(initializer, x)
             if output_is_sequence else x[0])
Example #51
0
def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None):
    """Implementation of pfor."""
    assert not context.executing_eagerly()
    loop_fn_has_config = _loop_fn_has_config(loop_fn)
    existing_ops = set(ops.get_default_graph().get_operations())
    # Run the loop body
    with ops.name_scope("loop_body"):
        loop_var = array_ops.placeholder_with_default(0, shape=[])
        if loop_fn_has_config:
            if pfor_config is None:
                pfor_config = PForConfig()
                pfor_config._set_iters(iters)  # pylint: disable=protected-access
            loop_fn_outputs = loop_fn(loop_var,
                                      **{PFOR_CONFIG_ARG: pfor_config})
        else:
            assert pfor_config is None
            loop_fn_outputs = loop_fn(loop_var)

    # Convert outputs to Tensor if needed.
    rewrap_as_ndarray = False
    tmp_loop_fn_outputs = []
    for loop_fn_output in nest.flatten(loop_fn_outputs):
        if (loop_fn_output is not None and not isinstance(
                loop_fn_output,
            (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
            if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
                logging.warn(
                    "Converting %s to a dense representation may make it slow."
                    " Alternatively, output the indices and values of the"
                    " IndexedSlices separately, and handle the vectorized"
                    " outputs directly." % loop_fn_output)
                loop_fn_output = ops.convert_to_tensor(loop_fn_output)
            elif isinstance(loop_fn_output, np_arrays.ndarray):
                loop_fn_output = loop_fn_output.data
                rewrap_as_ndarray = True
            else:
                loop_fn_output = ops.convert_to_tensor(loop_fn_output)
        tmp_loop_fn_outputs.append(loop_fn_output)
    loop_fn_outputs = nest.pack_sequence_as(loop_fn_outputs,
                                            tmp_loop_fn_outputs)

    new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
    iters = ops.convert_to_tensor(iters)
    if parallel_iterations is not None:
        if parallel_iterations < 1:
            raise ValueError(
                "parallel_iterations must be None or a positive integer")
        if parallel_iterations == 1:
            raise ValueError(
                "Found parallel_iterations == 1. Use for_loop instead.")
        iters_value = tensor_util.constant_value(iters)
        if iters_value is not None and iters_value < parallel_iterations:
            parallel_iterations = None
    if parallel_iterations is None:
        with ops.name_scope("pfor"):
            converter = PFor(loop_var,
                             iters,
                             new_ops,
                             fallback_to_while_loop=fallback_to_while_loop,
                             pfor_config=pfor_config)
            outputs = []
            for loop_fn_output in nest.flatten(loop_fn_outputs):
                output = converter.convert(loop_fn_output)
                if rewrap_as_ndarray:
                    output = np_arrays.tensor_to_ndarray(output)
                outputs.append(output)
            return nest.pack_sequence_as(loop_fn_outputs, outputs)
    else:
        if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
            raise ValueError(
                "Setting parallel_iterations currently unsupported if"
                " reductions across iterations are performed.")
        num_tiled_iterations = iters // parallel_iterations
        num_remaining_iterations = iters % parallel_iterations
        # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
        # a tf.function and extract the graph from there to vectorize it.
        with ops.name_scope("pfor_untiled"):
            converter = PFor(loop_var,
                             num_remaining_iterations,
                             new_ops,
                             fallback_to_while_loop=fallback_to_while_loop,
                             pfor_config=pfor_config)
            remaining_outputs = []
            flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
            for loop_fn_output in flattened_loop_fn_outputs:
                output = converter.convert(loop_fn_output)
                if rewrap_as_ndarray:
                    output = np_arrays.tensor_to_ndarray(output)
                remaining_outputs.append(output)

        with ops.name_scope("pfor_tiled"):
            loop_fn_dtypes = [
                ops.convert_to_tensor(x).dtype
                for x in flattened_loop_fn_outputs
            ]

            def tiled_loop_body(j):
                offset = j * parallel_iterations + num_remaining_iterations

                def tiled_loop_fn(i, pfor_config=None):
                    if loop_fn_has_config:
                        return nest.flatten(
                            loop_fn(i + offset, pfor_config=pfor_config))
                    else:
                        return nest.flatten(loop_fn(i + offset))

                return _pfor_impl(
                    tiled_loop_fn,
                    parallel_iterations,
                    fallback_to_while_loop=fallback_to_while_loop,
                    pfor_config=pfor_config)

            tiled_outputs = for_loop(tiled_loop_body,
                                     loop_fn_dtypes,
                                     num_tiled_iterations,
                                     parallel_iterations=1)
            tiled_outputs = [_flatten_first_two_dims(y) for y in tiled_outputs]

        with ops.name_scope("pfor"):
            iters_value = tensor_util.constant_value(iters)
            if iters_value is None or iters_value % parallel_iterations:
                outputs = control_flow_ops.cond(
                    math_ops.equal(num_remaining_iterations, 0),
                    lambda: tiled_outputs, lambda: [
                        array_ops.concat([x, y], axis=0)
                        for x, y in zip(remaining_outputs, tiled_outputs)
                    ])
            else:
                outputs = tiled_outputs
            flattened_outputs = nest.flatten(outputs)
            if rewrap_as_ndarray:
                flattened_outputs = [
                    np_arrays.tensor_to_ndarray(x) for x in flattened_outputs
                ]
            return nest.pack_sequence_as(loop_fn_outputs, flattened_outputs)
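For reference, the public entry point built on this pfor machinery is tf.vectorized_map; a small hedged usage sketch (the loop function and shapes are illustrative):

import tensorflow as tf

x = tf.random.normal([8, 3])
# Vectorizes the per-row function across the leading dimension, shape [8].
y = tf.vectorized_map(lambda row: tf.reduce_sum(row * row), x)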
Example #52
0
        def tpu_function(args, kwargs):
            """TF Function used to replicate the user computation."""
            if kwargs is None:
                kwargs = {}

            # Remove trailing `None`s from args, since they cannot be
            # replicated. If a `None` appears in the middle we can't do
            # anything about it, so let those cases fail.
            # For example when Keras model predict is used they pass the targets as
            # None. We want to handle it here so all client libraries don't have to
            # do this as other strategies can handle None values better.
            while args and args[-1] is None:
                args = args[:-1]

            # Used to re-structure flattened output tensors from `tpu.replicate()`
            # into a structured format.
            result = [[]]

            def replicated_fn(replica_id, replica_args, replica_kwargs):
                """Wraps user function to provide replica ID and `Tensor` inputs."""
                with _TPUReplicaContext(strategy,
                                        replica_id_in_sync_group=replica_id):
                    result[0] = fn(*replica_args, **replica_kwargs)
                return result[0]

            replicate_inputs = []  # By replica.
            for i in range(strategy.num_replicas_in_sync):
                replicate_inputs.append([
                    constant_op.constant(i, dtype=dtypes.int32),
                    values.select_replica(i, args),
                    values.select_replica(i, kwargs)
                ])

            # Construct and pass `maximum_shapes` so that we could support dynamic
            # shapes using dynamic padder.
            if options.experimental_enable_dynamic_batch_size and replicate_inputs:
                maximum_shapes = []
                flattened_list = nest.flatten(replicate_inputs[0])
                for input_tensor in flattened_list:
                    if tensor_util.is_tensor(input_tensor):
                        rank = input_tensor.get_shape().rank
                    else:
                        rank = np.ndim(input_tensor)
                    maximum_shape = tensor_shape.TensorShape([None] * rank)
                    maximum_shapes.append(maximum_shape)
                maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                                       maximum_shapes)
            else:
                maximum_shapes = None

            if options.experimental_bucketizing_dynamic_shape:
                padding_spec = tpu.PaddingSpec.POWER_OF_TWO
            else:
                padding_spec = None

            with strategy.scope():
                replicate_outputs = tpu.replicate(
                    replicated_fn,
                    replicate_inputs,
                    device_assignment=self._device_assignment,
                    maximum_shapes=maximum_shapes,
                    padding_spec=padding_spec)

            # Remove all no-ops that may have been added during tpu.replicate().
            if isinstance(result[0], list):
                result[0] = [
                    output for output in result[0]
                    if not isinstance(output, ops.Operation)
                ]

            # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
            if result[0] is None or isinstance(result[0], ops.Operation):
                replicate_outputs = [None] * len(replicate_outputs)
            else:
                replicate_outputs = [
                    nest.pack_sequence_as(result[0],
                                          nest.flatten(replica_output))
                    for replica_output in replicate_outputs
                ]
            return values.regroup(replicate_outputs)
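The maximum_shapes construction above builds a fully dynamic TensorShape per tensor leaf and re-packs it into the input structure. A small sketch of that step with the public tf.nest API (the replica input below is illustrative):

import tensorflow as tf

replica_input = (tf.constant(0), {"x": tf.zeros([2, 5]), "mask": tf.zeros([2])})
flat = tf.nest.flatten(replica_input)
# One [None] * rank shape per tensor leaf, then restore the nesting.
max_shapes = [tf.TensorShape([None] * t.shape.rank) for t in flat]
max_shapes = tf.nest.pack_sequence_as(replica_input, max_shapes)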
Example #53
0
 def _get_grads_lists_empirical(self, tensors):
     grads_flat = gradients_impl.gradients(self._layers.total_loss(),
                                           nest.flatten(tensors))
     grads_all = nest.pack_sequence_as(tensors, grads_flat)
     return tuple((grad, ) for grad in grads_all)
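The same structure-restoring idea works with tf.GradientTape: compute gradients for the flattened parameters, then pack them back into the parameter structure. A hedged sketch with illustrative parameters:

import tensorflow as tf

params = {"w": tf.Variable([[1.0, 2.0]]), "b": tf.Variable([0.5])}
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(params["w"]) + tf.reduce_sum(params["b"])
# Gradients come back as a flat list matching tf.nest.flatten(params).
flat_grads = tape.gradient(loss, tf.nest.flatten(params))
grads = tf.nest.pack_sequence_as(params, flat_grads)   # same dict structure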
Example #54
0
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            iterator,
                                            iterations,
                                            initial_loop_values=None):
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = input_lib.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            del args
            fn_result = fn(ctx, iterator.get_next())
            for (name, output) in ctx.last_step_outputs.items():
                # Convert all outputs to tensors, potentially from `DistributedValues`.
                ctx.last_step_outputs[name] = self._unwrap(output)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop. This is useful in cases where we might need to exit
        # these contexts and get back to the outer context to do some things, for
        # e.g. create an op which should be evaluated only once at the end of the
        # loop on the host. One such usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        loop_result = control_flow_ops.while_loop(cond,
                                                  body,
                                                  [i] + initial_loop_values,
                                                  name="",
                                                  parallel_iterations=1,
                                                  back_prop=False,
                                                  swap_memory=False,
                                                  return_same_structure=True)
        del self._outer_control_flow_context

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, wrap them in a Mirrored
            # container, else in a PerReplica container.
            if reduce_op is None:
                last_step_tensor_outputs_dict[name] = values.regroup(
                    self._device_map, output)
            else:
                assert len(output) == 1
                last_step_tensor_outputs_dict[name] = output[0]

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
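The core pattern above is: flatten the per-step outputs, thread them through a while loop as flat loop variables, then use pack_sequence_as to restore the original dict. A minimal stand-alone sketch (the loop body and initial values are illustrative):

import tensorflow as tf

last_step_outputs = {"count": tf.constant(0.0), "loss": tf.constant(0.0)}
flat_initial = tf.nest.flatten(last_step_outputs)   # sorted-key order

def body(i, *flat_vals):
  # Stand-in per-step update of every flat loop variable.
  return [i + 1] + [v + 1.0 for v in flat_vals]

cond = lambda i, *flat_vals: i < 3
result = tf.while_loop(cond, body, [tf.constant(0)] + flat_initial)
packed = tf.nest.pack_sequence_as(last_step_outputs, result[1:])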
Example #55
0
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=False,
                scope=None):
    """Creates a recurrent neural network specified by RNNCell `cell`.

  This function is functionally identical to the function `rnn` above, but
  performs fully dynamic unrolling of `inputs`.

  Unlike `rnn`, the input `inputs` is not a Python list of `Tensors`, one for
  each frame.  Instead, `inputs` may be a single `Tensor` where
  the maximum time is either the first or second dimension (see the parameter
  `time_major`).  Alternatively, it may be a (possibly nested) tuple of
  Tensors, each of them having matching batch and time dimensions.
  The corresponding output is either a single `Tensor` having the same number
  of time steps and batch size, or a (possibly nested) tuple of such tensors,
  matching the nested structure of `cell.output_size`.

  The parameter `sequence_length` is optional and is used to copy-through state
  and zero-out outputs when past a batch element's sequence length. So it's more
  for correctness than performance, unlike in rnn().

  Args:
    cell: An instance of RNNCell.
    inputs: The RNN inputs.

      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`, or a nested tuple of such
        elements.

      If `time_major == True`, this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`, or a nested tuple of such
        elements.

      This may also be a (possibly nested) tuple of Tensors satisfying
      this property.  The first two dimensions must match across all the inputs,
      but otherwise the ranks and other shape components may differ.
      In this case, input to `cell` at each time-step will replicate the
      structure of these tuples, except for the time dimension (from which the
      time is taken).

      The input to `cell` at each time step will be a `Tensor` or (possibly
      nested) tuple of Tensors each with dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
    initial_state: (optional) An initial state for the RNN.
      If `cell.state_size` is an integer, this must be
      a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
      If `cell.state_size` is a tuple, this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    dtype: (optional) The data type for the initial state and expected output.
      Required if initial_state is not provided or RNN state has a heterogeneous
      dtype.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency
      and can be run in parallel, will be.  This parameter trades off
      time for space.  Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors.
      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation.  However,
      most TensorFlow data is batch-major, so by default this function
      accepts input and emits output in batch-major form.
    scope: VariableScope for the created subgraph; defaults to "rnn".

  Returns:
    A pair (outputs, state) where:

      outputs: The RNN output `Tensor`.

        If time_major == False (default), this will be a `Tensor` shaped:
          `[batch_size, max_time, cell.output_size]`.

        If time_major == True, this will be a `Tensor` shaped:
          `[max_time, batch_size, cell.output_size]`.

        Note, if `cell.output_size` is a (possibly nested) tuple of integers
        or `TensorShape` objects, then `outputs` will be a tuple having the
        same structure as `cell.output_size`, containing Tensors having shapes
        corresponding to the shape data in `cell.output_size`.

      state: The final state.  If `cell.state_size` is an int, this
        will be shaped `[batch_size, cell.state_size]`.  If it is a
        `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
        If it is a (possibly nested) tuple of ints or `TensorShape`, this will
        be a tuple having the corresponding shapes.

  Raises:
    TypeError: If `cell` is not an instance of RNNCell.
    ValueError: If inputs is None or an empty list.
  """

    # pylint: disable=protected-access
    if not isinstance(cell, rnn_cell_impl._RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    # pylint: enable=protected-access

    # By default, time_major==False and inputs are batch-major: shaped
    #   [batch, time, depth]
    # For internal calculations, we transpose to [time, batch, depth]
    flat_input = nest.flatten(inputs)

    if not time_major:
        # (B,T,D) => (T,B,D)
        flat_input = tuple(
            array_ops.transpose(input_, [1, 0, 2]) for input_ in flat_input)

    parallel_iterations = parallel_iterations or 32
    if sequence_length is not None:
        sequence_length = math_ops.to_int32(sequence_length)
        if sequence_length.get_shape().ndims not in (None, 1):
            raise ValueError(
                "sequence_length must be a vector of length batch_size, "
                "but saw shape: %s" % sequence_length.get_shape())
        sequence_length = array_ops.identity(  # Just to find it in the graph.
            sequence_length,
            name="sequence_length")

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)
        input_shape = tuple(array_ops.shape(input_) for input_ in flat_input)
        batch_size = input_shape[0][1]

        for input_ in input_shape:
            if input_[1].get_shape() != batch_size.get_shape():
                raise ValueError("All inputs should have the same batch size")

        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError(
                    "If no initial_state is provided, dtype must be.")
            state = cell.zero_state(batch_size, dtype)

        def _assert_has_shape(x, shape):
            x_shape = array_ops.shape(x)
            packed_shape = array_ops.stack(shape)
            return control_flow_ops.Assert(
                math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
                    "Expected shape for Tensor %s is " % x.name, packed_shape,
                    " but saw shape: ", x_shape
                ])

        if sequence_length is not None:
            # Perform some shape validation
            with ops.control_dependencies(
                [_assert_has_shape(sequence_length, [batch_size])]):
                sequence_length = array_ops.identity(sequence_length,
                                                     name="CheckSeqLen")

        inputs = nest.pack_sequence_as(structure=inputs,
                                       flat_sequence=flat_input)

        (outputs, final_state) = _dynamic_rnn_loop(
            cell,
            inputs,
            state,
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory,
            sequence_length=sequence_length,
            dtype=dtype)

        # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
        # If we are performing batch-major calculations, transpose output back
        # to shape [batch, time, depth]
        if not time_major:
            # (T,B,D) => (B,T,D)
            flat_output = nest.flatten(outputs)
            flat_output = [
                array_ops.transpose(output, [1, 0, 2])
                for output in flat_output
            ]
            outputs = nest.pack_sequence_as(structure=outputs,
                                            flat_sequence=flat_output)

        return (outputs, final_state)
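The batch-major/time-major handling above transposes every flattened input and later re-packs the outputs. With the public tf.nest API the same transpose can be written as a single map_structure call; shapes below are illustrative:

import tensorflow as tf

inputs = {"tokens": tf.zeros([4, 10, 16]), "extra": tf.zeros([4, 10, 2])}
# (B, T, D) -> (T, B, D) on every leaf; the nested structure is preserved.
time_major = tf.nest.map_structure(
    lambda t: tf.transpose(t, [1, 0, 2]), inputs)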
Example #56
0
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name, collections=collections)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    control_manager = AutomaticControlDependencies
  else:
    control_manager = ops.NullContextmanager
  with func_graph.as_default(), control_manager() as a:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    func_args = _get_defun_inputs_from_args(args, arg_names)
    func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through polymorphic
    # functions, we don't have access to FunctionSpec, so we need to call the
    # TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, nest.flatten(func_kwargs))

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.contrib.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = a.mark_as_return(x)
      return x

    this_tape = tape.push_new_tape()
    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          # Note: functions annotated with @tf.function should always be
          # converted even though they would meet autograph's whitelisting
          # criteria.
          # If this assumption is ever broken, converted_call will need to
          # handle the possibility of original_func still being a shim, e.g.
          # bound to WeakrefSelf.
          return autograph.converted_call(
              original_func, None,
              autograph.ConversionOptions(
                  verbose=autograph.Verbosity.BRIEF,
                  recursive=True,
                  strip_decorators=(def_function.function,),
                  optional_features=autograph_options,
                  force_conversion=True,
              ), *args, **kwargs)

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        tf_decorator.rewrap(python_func, original_func, converted_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, IndexedSlices,
      # SparseTensors, TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs)

      check_mutation(func_args_before, func_args)
      check_mutation(func_kwargs_before, func_kwargs)
    finally:
      tape.pop_tape(this_tape)
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    tape_variables = this_tape.watched_variables()
    arg_variables = set()
    inputs = []
    for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.captures.pop(arg.handle, None)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in tape_variables if v not in arg_variables]
    func_graph.inputs = inputs + list(func_graph.captures.values())

    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  # Register any other functions defined in the graph.
  with ops.init_scope():
    if context.executing_eagerly():
      for f in func_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        context.add_function(f._c_func.func)  # pylint: disable=protected-access

  return func_graph
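The pack_sequence_as(flatten(x)) trick above rebuilds the containers (lists, tuples, dicts) without copying the leaves, which is what makes the later mutation check possible. A small sketch with illustrative data:

import tensorflow as tf

func_args = ([tf.constant(1.0)], {"k": tf.constant(2.0)})
# Copy the container structure but keep the same leaf tensors.
args_before = tf.nest.pack_sequence_as(func_args, tf.nest.flatten(func_args))
func_args[0].append(tf.constant(3.0))            # mutate the original structure
try:
  tf.nest.assert_same_structure(args_before, func_args)
except ValueError:
  print("structure was mutated")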
Example #57
0
def pack_iterable_as(structure, flat_iterable):
    """See `nest.pack_sequence_as`. Provided for named-arg compatibility."""
    return nest.pack_sequence_as(structure, flat_iterable)
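Usage is identical to nest.pack_sequence_as; illustrative values:

structure = (1, {"a": 2, "b": 3})
flat = [10, 20, 30]
packed = pack_iterable_as(structure, flat)   # -> (10, {"a": 20, "b": 30})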
Example #58
0
def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
    """Internal implementation of Dynamic RNN.

  Args:
    cell: An instance of RNNCell.
    inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
      tuple of such elements.
    initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
      `cell.state_size` is a tuple, then this should be a tuple of
      tensors having shapes `[batch_size, s] for s in cell.state_size`.
    parallel_iterations: Positive Python int.
    swap_memory: A Python boolean
    sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
    dtype: (optional) Expected dtype of output. If not specified, inferred from
      initial_state.

  Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, then this returns a (possibly nested) tuple of Tensors matching
      the corresponding shapes.
    final_state:
      A `Tensor`, or possibly nested tuple of Tensors, matching in length
      and shapes to `initial_state`.

  Raises:
    ValueError: If the input depth cannot be inferred via shape inference
      from the inputs.
  """
    state = initial_state
    assert isinstance(parallel_iterations,
                      int), "parallel_iterations must be int"

    state_size = cell.state_size

    flat_input = nest.flatten(inputs)
    flat_output_size = nest.flatten(cell.output_size)

    # Construct an initial output
    input_shape = array_ops.shape(flat_input[0])
    time_steps = input_shape[0]
    batch_size = input_shape[1]

    inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
                             for input_ in flat_input)

    const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]

    for shape in inputs_got_shape:
        if not shape[2:].is_fully_defined():
            raise ValueError(
                "Input size (depth of inputs) must be accessible via shape inference,"
                " but saw value None.")
        got_time_steps = shape[0].value
        got_batch_size = shape[1].value
        if const_time_steps != got_time_steps:
            raise ValueError(
                "Time steps is not the same for all the elements in the input in a "
                "batch.")
        if const_batch_size != got_batch_size:
            raise ValueError(
                "Batch_size is not the same for all the elements in the input."
            )

    # Prepare dynamic conditional copying of state & output
    def _create_zero_arrays(size):
        size = _state_size_with_prefix(size, prefix=[batch_size])
        return array_ops.zeros(array_ops.stack(size),
                               _infer_state_dtype(dtype, state))

    flat_zero_output = tuple(
        _create_zero_arrays(output) for output in flat_output_size)
    zero_output = nest.pack_sequence_as(structure=cell.output_size,
                                        flat_sequence=flat_zero_output)

    if sequence_length is not None:
        min_sequence_length = math_ops.reduce_min(sequence_length)
        max_sequence_length = math_ops.reduce_max(sequence_length)

    time = array_ops.constant(0, dtype=dtypes.int32, name="time")

    with ops.name_scope("dynamic_rnn") as scope:
        base_name = scope

    def _create_ta(name, dtype):
        return tensor_array_ops.TensorArray(dtype=dtype,
                                            size=time_steps,
                                            tensor_array_name=base_name + name)

    output_ta = tuple(
        _create_ta("output_%d" % i, _infer_state_dtype(dtype, state))
        for i in range(len(flat_output_size)))
    input_ta = tuple(
        _create_ta("input_%d" % i, flat_input[0].dtype)
        for i in range(len(flat_input)))

    input_ta = tuple(
        ta.unstack(input_) for ta, input_ in zip(input_ta, flat_input))

    def _time_step(time, output_ta_t, state):
        """Take a time step of the dynamic RNN.

    Args:
      time: int32 scalar Tensor.
      output_ta_t: List of `TensorArray`s that represent the output.
      state: nested tuple of vector tensors that represent the state.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """

        input_t = tuple(ta.read(time) for ta in input_ta)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            input_.set_shape(shape[1:])

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)
        call_cell = lambda: cell(input_t, state)

        if sequence_length is not None:
            (output,
             new_state) = _rnn_step(time=time,
                                    sequence_length=sequence_length,
                                    min_sequence_length=min_sequence_length,
                                    max_sequence_length=max_sequence_length,
                                    zero_output=zero_output,
                                    state=state,
                                    call_cell=call_cell,
                                    state_size=state_size,
                                    skip_conditionals=True)
        else:
            (output, new_state) = call_cell()

        # Pack state if using state tuples
        output = nest.flatten(output)

        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, output))

        return (time + 1, output_ta_t, new_state)

    _, output_final_ta, final_state = control_flow_ops.while_loop(
        cond=lambda time, *_: time < time_steps,
        body=_time_step,
        loop_vars=(time, output_ta, state),
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # Unpack final output if not using output tuples.
    final_outputs = tuple(ta.stack() for ta in output_final_ta)

    # Restore some shape information
    for output, output_size in zip(final_outputs, flat_output_size):
        shape = _state_size_with_prefix(
            output_size, prefix=[const_time_steps, const_batch_size])
        output.set_shape(shape)

    final_outputs = nest.pack_sequence_as(structure=cell.output_size,
                                          flat_sequence=final_outputs)

    return (final_outputs, final_state)
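The per-leaf TensorArray pattern above (one array per flattened input, unstacked along time, then a single step re-packed into the input structure) can be sketched compactly with public APIs; shapes are illustrative:

import tensorflow as tf

inputs = {"x": tf.random.normal([5, 2, 3]), "y": tf.random.normal([5, 2, 1])}
flat_input = tf.nest.flatten(inputs)                     # time-major leaves
input_tas = [tf.TensorArray(t.dtype, size=5).unstack(t) for t in flat_input]
# Read time step 0 from every array and restore the nested structure.
step_0 = tf.nest.pack_sequence_as(inputs, [ta.read(0) for ta in input_tas])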
Example #59
0
def _graph_mode_decorator(f, *args, **kwargs):
    """Implement custom gradient decorator for graph mode."""
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()
    args = [ops.convert_to_tensor(x) for x in args]

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = set(current_var_scope.global_variables() +
                      current_var_scope.local_variables())
    with backprop.GradientTape() as tape:
        result, grad_fn = f(*args)
    after_vars = set(current_var_scope.global_variables() +
                     current_var_scope.local_variables())
    new_vars = after_vars - before_vars
    for v in new_vars:
        if not isinstance(v, resource_variable_ops.ResourceVariable):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")
    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    variables = list(set(tape.watched_variables()) - set(args))
    grad_argspec = tf_inspect.getargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or grad_argspec.keywords)
    if variables and not variables_in_signature:
        raise TypeError("If using @custom_gradient with a function that "
                        "uses variables, then grad_fn must accept a keyword "
                        "argument 'variables'.")
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        if not variable_scope.get_variable_scope().use_resource:
            raise TypeError(
                "If using @custom_gradient with a function that "
                "uses variables, the enclosing variable scope must "
                "have use_resource=True.")
        else:
            logging.warn(
                "@custom_gradient grad_fn has 'variables' in signature, but "
                "no ResourceVariables were used on the forward pass.")
    flat_result = nest.flatten(result)
    all_tensors = flat_result + args + variables

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        result_grads = result_grads[:len(flat_result)]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * len(flat_result)) + input_grads + variable_grads

    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:len(flat_result)])
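The graph-mode decorator above sits behind the public tf.custom_gradient API. A hedged usage sketch, with an illustrative clipping rule:

import tensorflow as tf

@tf.custom_gradient
def clipped_identity(x):
  def grad(dy):
    # Illustrative custom backward rule: clip the upstream gradient's norm.
    return tf.clip_by_norm(dy, 0.5)
  return tf.identity(x), grad

x = tf.constant([3.0, 4.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = clipped_identity(x)
print(tape.gradient(y, x))   # upstream gradient of ones, clipped to norm 0.5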
Example #60
0
        def body(time, elements_finished, current_input, emit_ta, state,
                 loop_state):
            """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                # pylint: disable=g-long-lambda,cell-var-from-loop
                result_flat = [
                    _on_device(lambda: array_ops.where(elements_finished,
                                                       current_i, candidate_i),
                               device=candidate_i.op.device)
                    for (current_i,
                         candidate_i) in zip(current_flat, candidate_flat)
                ]
                # pylint: enable=g-long-lambda,cell-var-from-loop
                return nest.pack_sequence_as(structure=current,
                                             flat_sequence=result_flat)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            emit_ta_flat = [
                ta.write(time, emit)
                for (ta, emit) in zip(emit_ta_flat, emit_output_flat)
            ]

            emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                            flat_sequence=emit_ta_flat)

            return (next_time, elements_finished, next_input, emit_ta,
                    next_state, loop_state)
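The _copy_some_through helper above is the usual copy-through idiom; with tf.nest.map_structure and a broadcasting tf.where it can be sketched as follows (structures and values are illustrative):

import tensorflow as tf

elements_finished = tf.constant([True, False])
current = {"h": tf.zeros([2, 3]), "c": tf.ones([2, 3])}
candidate = {"h": tf.fill([2, 3], 2.0), "c": tf.fill([2, 3], 3.0)}
# Finished batch entries keep `current`; unfinished entries take `candidate`.
copied = tf.nest.map_structure(
    lambda cur, cand: tf.where(elements_finished[:, None], cur, cand),
    current, candidate)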