Example #1
  def _batch_all_reduce(self, aggregation, per_device_values):
    """All reduce algorithm in a batch."""
    logging.log_first_n(
        logging.INFO, "batch_all_reduce invoked for batches size = %d with "
        "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and "
        "agg_small_grads_max_group = %d" %
        (len(per_device_values), self._all_reduce_alg, self._num_packs,
         self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)
    destinations = per_device_values[0].devices
    grouped = _group_value_by_device(per_device_values)

    device_grad_packs, tensor_packer = _pack_tensors(
        grouped, self._num_packs, self._agg_small_grads_max_bytes,
        self._agg_small_grads_max_group)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_tower_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
                                      aggregation)
  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use the
    # timer.last_triggered_step as the timer might record a different global
    # step value such that the comparison could be unreliable. For simplicity,
    # we just compare the stale_global_step with previously recorded version.
    if stale_global_step == self._last_global_step:
      # Here, we give a warning in the first 5 times if we have observed that
      # the global step has not been increased. For some Optimizers, the global
      # step is not increased each time by design. For example,
      # SyncReplicaOptimizer doesn't increase the global step in worker's main
      # train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stable): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step
  def _ProcessHealthPillSummary(self, value, event):
    """Process summaries containing health pills.

    These summaries are distinguished by the fact that they have a Tensor field
    and have a special tag value.

    This method emits ERROR-level messages to the logs if it encounters Tensor
    summaries that it cannot process.

    Args:
      value: A summary_pb2.Summary.Value with a Tensor field.
      event: The event_pb2.Event containing that value.
    """
    elements = np.fromstring(value.tensor.tensor_content, dtype=np.float64)

    # The node_name property of the value object is actually a watch key: a
    # combination of node name, output slot, and a suffix. We capture the
    # actual node name and the output slot with a regular expression.
    match = re.match(r'^(.*):(\d+):DebugNumericSummary$', value.node_name)
    if not match:
      logging.log_first_n(
          logging.ERROR,
          'Unsupported watch key %s for health pills; skipping this sequence.',
          1,
          value.node_name)
      return

    node_name = match.group(1)
    output_slot = int(match.group(2))
    self._ProcessHealthPill(
        event.wall_time, event.step, node_name, output_slot, elements)
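To make the watch-key parsing above concrete, here is a small self-contained sketch; the node name `dense/BiasAdd` is a hypothetical example rather than anything produced by this plugin:

```python
import re

# Hypothetical watch key: <node name>:<output slot>:DebugNumericSummary.
watch_key = "dense/BiasAdd:0:DebugNumericSummary"
match = re.match(r'^(.*):(\d+):DebugNumericSummary$', watch_key)
if match:
    node_name = match.group(1)         # "dense/BiasAdd"
    output_slot = int(match.group(2))  # 0
```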
  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values):
    """All-reduce across all workers in a batch."""

    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
        "num_workers = %d" % (len(per_replica_values), self._num_workers), 10)

    chunked_gv = self._make_gradient_chunks(per_replica_values,
                                            self._all_reduce_merge_scope)

    reduced_gv_list = []
    for chunk in chunked_gv:
      with ops.name_scope("allreduce"):
        for grad_and_vars in chunk:
          # Gradients for the same variable but from different devices.
          scaled_grads = [g for g, _ in grad_and_vars]
          collective_reduced = cross_device_utils.build_collective_reduce(
              scaled_grads, self._num_workers, self._collective_keys, "Add",
              "Id")
          result = []
          for (_, v), g in zip(grad_and_vars, collective_reduced):
            result.append([g, v])
          reduced_gv_list.append(result)

    new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
    return _ungroup_and_make_mirrored(
        new_device_grads,
        per_replica_values[0],
        reduce_op,
        num_between_graph_workers=self._num_workers)
 def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
   """Run batch all-reduce for sparse values."""
   logging.log_first_n(
       logging.WARN,
       "Efficient allreduce is not supported for %d IndexedSlices" %
       len(sparse_values), 10)
   # Use `sparse_values` as destinations to do all-reduces. It is effectively
   # an allgather under the hood but not an efficient one.
   return self._simple_cross_replica_ops.batch_reduce(
       reduce_op, zip(sparse_values, sparse_values))
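All of the snippets collected here go through the same helper. As a minimal sketch (assuming `tf.compat.v1.logging`, which exposes the same `log_first_n(level, msg, n, *args)` signature as the `logging` modules imported above), the message is `%`-formatted with `args` and emitted at most `n` times for a given call site:

```python
import tensorflow as tf

logging = tf.compat.v1.logging

for step in range(100):
    # Logged at most 3 times, even though the loop body runs 100 times.
    logging.log_first_n(logging.WARN, "slow step detected at step %d", 3, step)
```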
Example #6
 def _reduce(self, reduce_op, per_replica_value, destinations):
   assert check_destinations(destinations)
   devices = get_devices_from(destinations)
   reduce_to_device = self.reduce_to_device or devices[0]
   logging.log_first_n(
       logging.INFO,
       "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
   reduced = _simple_reduce(per_replica_value, reduce_to_device,
                            self.accumulation_fn, reduce_op)
   return self.broadcast(reduced, destinations)
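The reduce-then-broadcast pattern used by `_reduce` can be sketched with plain TensorFlow ops. The device string and per-replica values below are hypothetical stand-ins for what `_simple_reduce` and `broadcast` actually operate on:

```python
import tensorflow as tf

# Hypothetical per-replica values that would normally live on different devices.
per_replica = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
reduce_to_device = "/cpu:0"

with tf.device(reduce_to_device):
    reduced = tf.add_n(per_replica)            # reduce on a single device
mirrored = [tf.identity(reduced) for _ in per_replica]  # one copy per replica
```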
Example #7
  def gradient(self, target, sources, output_gradients=None):
    """Computes the gradient using operations recorded in context of this tape.

    Args:
      target: Tensor (or list of tensors) to be differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
       than once on a non-persistent tape.
    """
    if self._tape is None:
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")
    if self._recording:
      if not self._persistent:
        self._pop_tape()
      else:
        logging.log_first_n(logging.WARN,
                            "Calling GradientTape.gradient on a persistent "
                            "tape inside it's context is significantly less "
                            "efficient than calling it outside the context (it "
                            "causes the gradient ops to be recorded on the "
                            "tape, leading to increased CPU and memory usage). "
                            "Only call GradientTape.gradient inside the "
                            "context if you actually want to trace the "
                            "gradient in order to compute higher order "
                            "derrivatives.", 1)

    flat_sources = nest.flatten(sources)
    flat_sources = [_handle_or_self(x) for x in flat_sources]

    if output_gradients is not None:
      output_gradients = [None if x is None else ops.convert_to_tensor(x)
                          for x in nest.flatten(output_gradients)]

    flat_grad = imperative_grad.imperative_grad(
        self._tape,
        nest.flatten(target),
        flat_sources,
        output_gradients=output_gradients)

    if not self._persistent:
      self._tape = None

    grad = nest.pack_sequence_as(sources, flat_grad)
    return grad
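A brief usage sketch of the API documented above (TF 2.x eager execution assumed); a non-persistent tape allows exactly one `gradient` call:

```python
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)               # constants are not watched automatically
    y = x * x
dy_dx = tape.gradient(y, x)     # 6.0; a second call raises RuntimeError
```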
  def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
    """All-reduce IndexedSlices across all workers in a batch."""

    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
        "%d all-reduces, num_workers = %d" %
        (len(per_replica_values), self._num_workers), 10)

    chunked_gv = self._make_gradient_chunks(per_replica_values,
                                            self._all_reduce_merge_scope)

    reduced_gv_list = []
    for chunk in chunked_gv:
      with ops.name_scope("allreduce"):
        for grad_and_vars in chunk:
          # Gradients for the same variable but from different devices.
          scaled_grads = [g for g, _ in grad_and_vars]

          values = [g.values for g in scaled_grads]
          indices = [g.indices for g in scaled_grads]
          assert len(values) == len(indices)

          # Build two separate allgathers, one for values, the other one for
          # indices.
          gathered_values = cross_device_utils.build_collective_gather(
              values, self._num_workers, self._collective_keys)
          gathered_indices = cross_device_utils.build_collective_gather(
              indices, self._num_workers, self._collective_keys)
          assert len(gathered_values) == len(gathered_indices)

          collective_reduced = []
          for i in range(len(values)):
            reduced = ops.IndexedSlices(
                gathered_values[i],
                gathered_indices[i],
                dense_shape=scaled_grads[i].dense_shape)
            collective_reduced.append(reduced)

          result = []
          for (_, v), g in zip(grad_and_vars, collective_reduced):
            result.append([g, v])
          reduced_gv_list.append(result)

    new_device_grads = [list(x) for x in zip(*reduced_gv_list)]
    return _ungroup_and_make_mirrored(
        new_device_grads,
        per_replica_values[0],
        reduce_op,
        num_between_graph_workers=self._num_workers)
Example #9
  def _batch_all_reduce(self, aggregation, per_device_values):
    """All reduce algorithm in a batch."""
    logging.log_first_n(
        logging.INFO,
        "distributed batch_all_reduce invoked for batches size = %d with "
        "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d "
        "and agg_small_grads_max_group = %d" %
        (len(per_device_values), self._all_reduce_spec, self._num_packs,
         self._agg_small_grads_max_bytes, self._agg_small_grads_max_group), 10)

    destinations = sorted(per_device_values[0].devices)
    device_grads = _group_value_by_device(per_device_values)

    # The all reduce library requires fully defined shapes.
    # TODO(yuefengz): when tensor sharding is not needed, static shapes are not
    # required as well.
    for device_grad in device_grads:
      for grad, _ in device_grad:
        if not grad.shape.is_fully_defined():
          raise ValueError("Shape is unknown for node %r" % grad)

    remaining_grads = device_grads
    aggregated_grads = []
    for spec_tuple in self._all_reduce_spec:
      if spec_tuple.limit < 0:
        this_grads = remaining_grads
        remaining_grads = []
      else:
        (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size(
            spec_tuple.limit, remaining_grads)
      if this_grads:
        device_grad_packs, tensor_packer = _pack_tensors(
            this_grads, self._num_packs, self._agg_small_grads_max_bytes,
            self._agg_small_grads_max_group)
        range_agg_grads = cross_tower_utils.sum_gradients_all_reduce(
            self._worker_devices, device_grad_packs, len(self._worker_devices),
            spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
        range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)

        if not aggregated_grads:
          aggregated_grads = range_agg_grads
        else:
          assert len(aggregated_grads) == len(range_agg_grads)
          for i in range(len(aggregated_grads)):
            aggregated_grads[i] += range_agg_grads[i]
    assert not remaining_grads

    return _ungroup_and_make_mirrored(aggregated_grads, destinations,
                                      aggregation)
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
    all_devices_match = _all_devices_match(value_destination_pairs)
    if all_devices_match:
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      if not all_devices_match:
        logging.log_first_n(
            logging.WARN, "Efficient batch_reduce is not supported if "
            "destinations are different.", 10)

      return [
          self.reduce_implementation(reduce_op, t, destinations=v)
          for t, v in value_destination_pairs
      ]
Example #11
 def run_loop(self):
     # Count the steps.
     current_step = training_util.global_step(self._sess, self._sv.global_step)
     added_steps = current_step - self._last_step
     self._last_step = current_step
     # Measure the elapsed time.
     current_time = time.time()
     elapsed_time = current_time - self._last_time
     self._last_time = current_time
     # Reports the number of steps done per second
     steps_per_sec = added_steps / elapsed_time
     summary = Summary(value=[
         Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)])
     if self._sv.summary_writer:
         self._sv.summary_writer.add_summary(summary, current_step)
     logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag, steps_per_sec)
  def _reduce(self, aggregation, per_device_value, destinations):
    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
        per_device_value)
    if ((destinations is None or _devices_match(per_device_value, destinations))
        and not context.executing_eagerly()
        and not contains_indexed_slices):
      return self._batch_all_reduce(aggregation, [per_device_value])[0]
    else:
      if contains_indexed_slices:
        logging.log_first_n(
            logging.WARN,
            "Efficient allreduce is not supported for IndexedSlices.", 10)

      devices = get_devices_from(destinations or per_device_value)
      reduce_to_device = devices[0]
      reduced = _simple_reduce(per_device_value, reduce_to_device,
                               math_ops.add_n, aggregation)
      return self.broadcast(reduced, devices)
Example #13
  def watch(self, tensor):
    """Ensures that `tensor` is being traced by this tape.

    Args:
      tensor: a Tensor or list of Tensors.
    """
    for t in nest.flatten(tensor):
      if not t.dtype.is_floating:
        logging.log_first_n(
            logging.WARN, "The dtype of the watched tensor must be "
            "floating (e.g. tf.float32), got %r", 5, t.dtype)
      if hasattr(t, "handle"):
        # There are many variable-like objects, all of them currently have
        # `handle` attribute that points to a tensor. If this changes, internals
        # of watch_variable need to change as well.
        tape.watch_variable(self._tape, t)
      else:
        tape.watch(self._tape, t)
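A short sketch of the warning path above: watching a non-floating tensor does not raise, it only triggers the dtype warning (logged at most five times):

```python
import tensorflow as tf

x_int = tf.constant(2)  # int32, not a floating dtype
with tf.GradientTape() as tape:
    # Does not raise; only logs "The dtype of the watched tensor must be
    # floating (e.g. tf.float32) ..." for the first few occurrences.
    tape.watch(x_int)
```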
Example #14
  def _batch_all_reduce(self, aggregation, per_device_values):
    """All-reduce across all workers in a batch."""
    if context.executing_eagerly():
      raise ValueError(
          "Eager execution with collective ops is not supported yet.")

    logging.log_first_n(
        logging.INFO, "Collective All-reduce invoked with batches size = %d, "
        "num_workers = %d" % (len(per_device_values), self._num_workers), 10)

    grouped_by_tower = _group_value_by_device(per_device_values)

    grouped_by_var = list(zip(*grouped_by_tower))
    # grouped_by_var is grouped by variables and takes the following format:
    # [((grad0_gpu0, v0_gpu0), (grad0_gpu1, v0_gpu1), (grad0_gpu2, v0_gpu2) ..),
    #  ((grad1_gpu0, v1_gpu0), (grad1_gpu1, v1_gpu1), (grad1_gpu2, v1_gpu2) ..),
    #  ((grad2_gpu0, v2_gpu0), (grad2_gpu1, v2_gpu1), (grad2_gpu2, v2_gpu2) ..),
    #  ...
    # ]
    chunked_gv = [
        grouped_by_var[x:x + self._all_reduce_merge_scope]
        for x in range(0, len(grouped_by_var), self._all_reduce_merge_scope)
    ]

    reduced_gv_list = []
    for chunk in chunked_gv:
      with ops.name_scope("allreduce"):
        for grad_and_vars in chunk:
          scaled_grads = [g for g, _ in grad_and_vars]
          collective_reduced = cross_tower_utils.build_collective_reduce(
              scaled_grads, self._num_workers, self._collective_keys, "Add",
              "Id")
          result = []
          for (_, v), g in zip(grad_and_vars, collective_reduced):
            result.append([g, v])
          reduced_gv_list.append(result)

    new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
    return _ungroup_and_make_mirrored(
        new_tower_grads,
        per_device_values[0].devices,
        aggregation,
        num_between_graph_workers=self._num_workers)
  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All reduce algorithm in a batch."""
    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
        "num_workers = %d" % (len(per_replica_values), self._num_workers), 10)

    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))
Example #16
  def _batch_reduce(self, aggregation, value_destination_pairs):
    all_devices_match = _all_devices_match(value_destination_pairs)
    contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
        value_destination_pairs)
    if (all_devices_match and not context.executing_eagerly()
        and not contains_indexed_slices):
      return self._batch_all_reduce(aggregation,
                                    [v[0] for v in value_destination_pairs])
    else:
      if not all_devices_match:
        logging.log_first_n(logging.WARN,
                            "Efficient batch_reduce is not supported if "
                            "destinations are different.",
                            10)

      return [
          self._reduce(aggregation, t, destinations=v)
          for t, v in value_destination_pairs
      ]
Example #17
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs):
    if cross_device_utils.contains_indexed_slices(value_destination_pairs):
      raise ValueError(
          "`IndexSlices` is not supported for Collective All-Reduce.")

    all_devices_match = _all_devices_match(value_destination_pairs)
    if all_devices_match:
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      if not all_devices_match:
        logging.log_first_n(
            logging.WARN, "Efficient batch_reduce is not supported if "
            "destinations are different.", 10)

      return [
          self.reduce_implementation(reduce_op, t, destinations=v)
          for t, v in value_destination_pairs
      ]
  def _reduce(self, reduce_op, per_replica_value, destinations):
    contains_indexed_slices = cross_device_utils.contains_indexed_slices(
        per_replica_value)
    if (_devices_match(per_replica_value, destinations)
        and not context.executing_eagerly()
        and not contains_indexed_slices):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      if contains_indexed_slices:
        logging.log_first_n(
            logging.WARN,
            "Efficient allreduce is not supported for IndexedSlices.", 10)

      if check_destinations(destinations):
        devices = get_devices_from(destinations)
      else:
        devices = get_devices_from(per_replica_value)
      reduce_to_device = devices[0]
      reduced = _simple_reduce(per_replica_value, reduce_to_device,
                               math_ops.add_n, reduce_op)
      return self.broadcast(reduced, devices)
Example #19
  def _batch_reduce(self, aggregation, value_destination_pairs):
    if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
      raise ValueError(
          "`IndexSlices` is not supported for Collective All-Reduce.")
    if context.executing_eagerly():
      raise ValueError(
          "Eager execution is not supported for Collective All-Reduce")

    all_devices_match = _all_devices_match(value_destination_pairs)
    if all_devices_match:
      return self._batch_all_reduce(aggregation,
                                    [v[0] for v in value_destination_pairs])
    else:
      if not all_devices_match:
        logging.log_first_n(
            logging.WARN, "Efficient batch_reduce is not supported if "
            "destinations are different.", 10)

      return [
          self._reduce(aggregation, t, destinations=v)
          for t, v in value_destination_pairs
      ]
Example #20
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  Args:
    o: A Python entity.
  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)
  for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
    if m.__name__.startswith(prefix):
      return True

  if hasattr(o, 'autograph_info__'):
    return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.log_first_n(
          logging.level_warning(),
          'Entity {} looks like a namedtuple subclass. If it has any custom'
          ' methods, they will not be converted by AutoGraph.'.format(o), 1)
    return True

  return False
  def watch(self, tensor):
    """Ensures that `tensor` is being traced by this tape.

    Args:
      tensor: a Tensor or list of Tensors.

    Raises:
      ValueError: if it encounters something that is not a tensor.
    """
    for t in nest.flatten(tensor, expand_composites=True):
      if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
        raise ValueError("Passed in object of type {}, not tf.Tensor".format(
            type(t)))
      if not backprop_util.IsTrainable(t):
        logging.log_first_n(
            logging.WARN, "The dtype of the watched tensor must be "
            "floating (e.g. tf.float32), got %r", 5, t.dtype)
      if hasattr(t, "handle"):
        # There are many variable-like objects, all of them currently have
        # `handle` attribute that points to a tensor. If this changes, internals
        # of watch_variable need to change as well.
        tape.watch_variable(self._tape, t)
      else:
        tape.watch(self._tape, t)
Example #22
  def gradient(self,
               target,
               sources,
               output_gradients=None,
               unconnected_gradients=UnconnectedGradients.NONE):
    """Computes the gradient using operations recorded in context of this tape.

    Note: Unless you set `persistent=True` a GradientTape can only be used to
    compute one set of gradients (or jacobians).

    Args:
      target: a list or nested structure of Tensors or Variables to be
        differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: If called on a used, non-persistent tape.
      RuntimeError: If called inside the context of the tape.
      ValueError: If the target is a variable or if unconnected gradients is
       called with an unknown value.
    """
    if self._tape is None:
      raise RuntimeError("A non-persistent GradientTape can only be used to"
                         "compute one set of gradients (or jacobians)")
    if self._recording:
      if not self._persistent:
        self._pop_tape()
      else:
        logging.log_first_n(
            logging.WARN, "Calling GradientTape.gradient on a persistent "
            "tape inside its context is significantly less "
            "efficient than calling it outside the context (it "
            "causes the gradient ops to be recorded on the "
            "tape, leading to increased CPU and memory usage). "
            "Only call GradientTape.gradient inside the "
            "context if you actually want to trace the "
            "gradient in order to compute higher order "
            "derivatives.", 1)

    num_ndarrays = 0
    flat_targets = []
    for t in nest.flatten(target):
      if not backprop_util.IsTrainable(t):
        logging.vlog(
            logging.WARN, "The dtype of the target tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)
      if resource_variable_ops.is_resource_variable(t):
        with self:
          t = ops.convert_to_tensor(t)
      elif isinstance(t, np_arrays.ndarray):
        t = t.data
        num_ndarrays += 1
      flat_targets.append(t)
    # Only rewrap if all targets are ndarray. If not, prefer tensors.
    rewrap_as_ndarray = num_ndarrays == len(flat_targets)

    flat_sources = nest.flatten(sources)
    flat_sources_raw = flat_sources
    flat_sources = [_handle_or_self(x) for x in flat_sources]
    for t in flat_sources_raw:
      if not backprop_util.IsTrainable(t):
        logging.vlog(
            logging.WARN, "The dtype of the source tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)
      if getattr(t, "is_packed", False):
        raise ValueError(
            "GradientTape.gradient is not supported on packed EagerTensors yet."
        )

    if output_gradients is not None:
      output_gradients = [None if x is None else ops.convert_to_tensor(x)
                          for x in nest.flatten(output_gradients)]

    flat_grad = imperative_grad.imperative_grad(
        self._tape,
        flat_targets,
        flat_sources,
        output_gradients=output_gradients,
        sources_raw=flat_sources_raw,
        unconnected_gradients=unconnected_gradients)

    if not self._persistent:
      # Keep track of watched variables before setting tape to None
      self._watched_variables = self._tape.watched_variables()
      self._tape = None

    if rewrap_as_ndarray:
      def _tensor_to_ndarray(x):
        if x is not None:
          return np_arrays.tensor_to_ndarray(x)
        return None
      flat_grad = nest.map_structure(_tensor_to_ndarray, flat_grad)

    grad = nest.pack_sequence_as(sources, flat_grad)
    return grad
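A short sketch of the `unconnected_gradients` argument described above; with the default `'none'` the gradient for an unconnected source is `None`, while `'zero'` yields a zero tensor instead:

```python
import tensorflow as tf

x = tf.constant(1.0)
z = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch([x, z])
    y = x * x                   # y does not depend on z
grads = tape.gradient(
    y, [x, z], unconnected_gradients=tf.UnconnectedGradients.ZERO)
# grads[0] == 2.0; grads[1] == 0.0 (it would be None with the default 'none')
```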
Example #23
 def disable_partitioned_variables(getter, *args, **kwargs):
   if kwargs.pop("partitioner", None) is not None:
     tf_logging.log_first_n(
         tf_logging.WARN, "Partitioned variables are disabled when using "
         "DistributionStrategy.", 1)
   return getter(*args, **kwargs)
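A hedged sketch of how a custom getter like the one above is typically installed. The scope name, variable shape, and partitioner are hypothetical, and TF 1.x graph mode is assumed:

```python
import tensorflow.compat.v1 as tf1

def _drop_partitioner(getter, *args, **kwargs):
    # Same idea as disable_partitioned_variables above: discard any partitioner.
    if kwargs.pop("partitioner", None) is not None:
        tf1.logging.log_first_n(
            tf1.logging.WARN, "Partitioned variables are disabled.", 1)
    return getter(*args, **kwargs)

tf1.disable_eager_execution()
with tf1.variable_scope("tower", custom_getter=_drop_partitioner):
    # The partitioner is silently dropped; the warning is logged only once.
    v = tf1.get_variable(
        "w", shape=[1024], partitioner=tf1.fixed_size_partitioner(4))
```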
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `dtype` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension.  The nested sequence of the
      resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`. Use
      `RaggedTensorType` to declare an output of type `RaggedTensor`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A possibly nested sequence of potentially ragged tensors.  Each
    tensor packs the results of applying `fn` to tensors unpacked from `elems`
    along the first dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  #### Examples:

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```

    ```python
    elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
    mean = map_fn(tf.reduce_mean, elems)
    # mean == [2, 4, 6]
    ```

    ```python
    elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64)
    out = map_fn(fn=lambda x: x+1, elems,
      dtype=ragged.RaggedTensorType(type=tf.int64, ragged_rank=0))
    # out = ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
    ```
  """
  if not callable(fn):
    raise TypeError("fn must be callable.")

  if isinstance(elems, sparse_tensor.SparseTensor):
    raise TypeError(
        "To perform a map on the values of a sparse tensor use either "
        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
        " SparseTensor(input.indices, map_fn(fn, input.values), "
        "input.dense_shape)")

  in_graph_mode = not context.executing_eagerly()
  # Set the default number of parallel_iterations depending on graph/eager mode.
  if in_graph_mode and not parallel_iterations:
    parallel_iterations = 10
  elif not in_graph_mode and not parallel_iterations:
    parallel_iterations = 1

  if not in_graph_mode and parallel_iterations > 1:
    logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
                        "effect when executing eagerly. Consider calling map_fn"
                        " with tf.contrib.eager.defun to execute fn in "
                        "parallel.", 1)
    parallel_iterations = 1

  input_is_sequence = nest.is_sequence(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  elems_flat = input_flatten(elems)

  with ops.name_scope(name, "map", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    elems_flat = [
        ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
            elem, name="elem") for elem in elems_flat
    ]

    # We can either infer the output, or we can assume that it will be the same
    # as the input structure.
    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])

    # Find the number of iterations, n may be known statically.
    if isinstance(elems_flat[0], ragged_tensor.RaggedTensor):
      n = ragged_array_ops.nrows(elems_flat[0], out_type=dtypes.int32)
    else:
      static_shape = elems_flat[0].shape
      if static_shape.ndims is not None and static_shape.ndims < 1:
        if len(elems_flat) == 1:
          raise ValueError(
              "elems must be a 1+ dimensional Tensor, not a scalar")
        else:
          raise ValueError(
              "elements in elems must be 1+ dimensional Tensors, not scalars")
      n = static_shape[0].value or array_ops.shape(elems_flat[0])[0]

    # Create a flat list of TAs.

    # Flatten the dtype structure to a list.
    dtype_flat = nest.flatten(dtype)

    # decompose to components
    dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat]
    dtype_components_flat = nest.flatten(dtype_components)

    # Create TensorArrays.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=t, dynamic_size=False, infer_shape=infer_shape, size=n)
        for t in dtype_components_flat
    ]

    i = constant_op.constant(0)

    def compute(i, tas):
      """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
      # Get Tensors or RaggedTensors sliced at i, then pack it back to the
      # original structure.
      packed_values = input_pack([elem_flat[i] for elem_flat in elems_flat])
      packed_fn_values = fn(packed_values)

      # Check that the structure of the output matches what was declared or
      # inferred.
      # nest.assert_same_structure(dtype or elems, packed_fn_values)

      # Flatten and decompose to a list of Tensors
      flat_fn_values = nest.flatten(packed_fn_values)

      # If we declared that we are expecting a RaggedTensor output but we get a
      # Tensor output, try to convert it to a RaggedTensor.
      flat_fn_composite_tensors = list(
          _convert_declared(flat_fn_values, dtype_flat))

      flat_fn_components = [
          _maybe_decompose_tensor(t) for t in flat_fn_composite_tensors
      ]
      flat_fn_tensors = nest.flatten(flat_fn_components)

      # Write to TAs.
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_tensors)]

      return (i + 1, tas)

    _, r_a = control_flow_ops.while_loop(
        lambda i, _: i < n, compute, (i, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    # Pack back into a list of components
    results_as_components = nest.pack_sequence_as(dtype_components, r_a)

    # Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs.
    def _stack_or_concat(e):
      if isinstance(e, _RaggedTensorComponents):
        return _concat_ragged_tensor_components(e)
      else:
        result = e.stack()
        return result

    results_flat_components = [
        _stack_or_concat(e) for e in results_as_components
    ]

    results_packed = [
        _maybe_recompose_tensor(c) for c in results_flat_components
    ]
    results_packed = nest.pack_sequence_as(dtype, results_packed)
    return results_packed
Example #25
    def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
                                   experimental_hints):
        """All-reduce across all workers in a batch."""

        batch_size = len(per_replica_values)
        # Pass self._communication to the runtime as a communication hint.
        communication = self._communication.value
        # For now, we use NCCL only when batch_size > 1.
        # TODO(b/132575814): switch to NCCL for all collectives when communication
        # is NCCL.
        if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
            communication = CollectiveCommunication.AUTO.value

        # Reverse the lists so that there's a better chance that values follow
        # the order in which they are calculated (e.g. when they're gradients),
        # so as to overlap calculation with communication. However, this may not
        # be optimal for cases like gradients of complicated non-sequential
        # models.
        #
        # Note that we reverse the list before packing so that the first pack won't
        # be too small, since it's more likely for first few packs to have long
        # queuing time due to concurrent intense computation.
        #
        # TODO(b/147393503): explore solutions for optimal ordering.
        packs = cross_device_utils.pack_by_size(
            list(reversed(per_replica_values)),
            experimental_hints.bytes_per_pack)

        if batch_size > 1:
            logging.info(
                "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
                "group_size = %d, communication_hint = %s, num_packs = %d",
                batch_size, len(self._devices), self._group_size,
                communication, len(packs))
        else:
            logging.log_first_n(
                logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
                "num_devices = %d, group_size = %d, communication_hint = %s, "
                "num_packs = %d" %
                (batch_size, len(self._devices), self._group_size,
                 communication, len(packs)), 10)

        reduced_values = []
        with self._lock:
            for pack in packs:
                # By placing all CollectiveReduce ops in a pack under single name scope,
                # we ensure they will be picked up by the `ScopedAllocator` grappler
                # optimizer and packed into a single all-reduce.
                with ops.name_scope("allreduce"):
                    for per_replica in pack:
                        # Add control dependencies per device from the last gradients to the
                        # current set, in order to serialize NCCL launches.
                        if (communication == CollectiveCommunication.NCCL.value
                                and reduced_values):
                            control_inputs = list(reduced_values[-1])
                        else:
                            control_inputs = None
                        reduced_values.append(
                            cross_device_utils.build_collective_reduce(
                                per_replica.values,
                                self._devices,
                                self._group_size,
                                self._collective_keys,
                                "Add",
                                "Id",
                                communication,
                                control_inputs,
                                executors=self._executors,
                                timeout=experimental_hints.timeout_seconds))

        for e in self._executors:
            e.wait()

        mirrored = []
        # Reverse the order of reduced values to recover the order in the input.
        for value in reversed(reduced_values):
            if reduce_op == reduce_util.ReduceOp.MEAN:
                for i, v in enumerate(value):
                    with ops.device(v.device):
                        value[i] = v / self._group_size
            mirrored.append(
                distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
        return mirrored
Example #26
 def LogErrorOnce(msg):
     logging.log_first_n(logging.ERROR, msg, 1)
Example #27
  def gradient(self, target, sources, output_gradients=None):
    """Computes the gradient using operations recorded in context of this tape.

    Args:
      target: Tensor (or list of tensors) to be differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
       than once on a non-persistent tape.
      ValueError: if called on variable target.
    """
    if self._tape is None:
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")
    if self._recording:
      if not self._persistent:
        self._pop_tape()
      else:
        logging.log_first_n(
            logging.WARN, "Calling GradientTape.gradient on a persistent "
            "tape inside its context is significantly less "
            "efficient than calling it outside the context (it "
            "causes the gradient ops to be recorded on the "
            "tape, leading to increased CPU and memory usage). "
            "Only call GradientTape.gradient inside the "
            "context if you actually want to trace the "
            "gradient in order to compute higher order "
            "derivatives.", 1)

    flat_targets = []
    for t in nest.flatten(target):
      if not t.dtype.is_floating:
        logging.vlog(
            logging.WARN, "The dtype of the target tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)
      if resource_variable_ops.is_resource_variable(t):
        with self:
          t = ops.convert_to_tensor(t)
      flat_targets.append(t)

    flat_sources = nest.flatten(sources)
    flat_sources_raw = flat_sources
    flat_sources = [_handle_or_self(x) for x in flat_sources]
    for t in flat_sources_raw:
      if not t.dtype.is_floating:
        logging.vlog(
            logging.WARN, "The dtype of the source tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)

    if output_gradients is not None:
      output_gradients = [None if x is None else ops.convert_to_tensor(x)
                          for x in nest.flatten(output_gradients)]

    flat_grad = imperative_grad.imperative_grad(
        self._tape,
        flat_targets,
        flat_sources,
        output_gradients=output_gradients,
        sources_raw=flat_sources_raw)

    if not self._persistent:
      self._tape = None

    grad = nest.pack_sequence_as(sources, flat_grad)
    return grad
Example #28
  def gradient(self,
               target,
               sources,
               output_gradients=None,
               unconnected_gradients=UnconnectedGradients.NONE):
    """Computes the gradient using operations recorded in context of this tape.

    Args:
      target: Tensor (or list of tensors) to be differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
       than once on a non-persistent tape.
      ValueError: if the target is a variable or if unconnected gradients is
       called with an unknown value.
    """
    if self._tape is None:
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")
    if self._recording:
      if not self._persistent:
        self._pop_tape()
      else:
        logging.log_first_n(
            logging.WARN, "Calling GradientTape.gradient on a persistent "
            "tape inside its context is significantly less "
            "efficient than calling it outside the context (it "
            "causes the gradient ops to be recorded on the "
            "tape, leading to increased CPU and memory usage). "
            "Only call GradientTape.gradient inside the "
            "context if you actually want to trace the "
            "gradient in order to compute higher order "
            "derivatives.", 1)

    flat_targets = []
    for t in nest.flatten(target):
      if not t.dtype.is_floating:
        logging.vlog(
            logging.WARN, "The dtype of the target tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)
      if resource_variable_ops.is_resource_variable(t):
        with self:
          t = ops.convert_to_tensor(t)
      flat_targets.append(t)

    flat_sources = nest.flatten(sources)
    flat_sources_raw = flat_sources
    flat_sources = [_handle_or_self(x) for x in flat_sources]
    for t in flat_sources_raw:
      if not t.dtype.is_floating:
        logging.vlog(
            logging.WARN, "The dtype of the source tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
            "got %r", t.dtype)

    if output_gradients is not None:
      output_gradients = [None if x is None else ops.convert_to_tensor(x)
                          for x in nest.flatten(output_gradients)]

    flat_grad = imperative_grad.imperative_grad(
        self._tape,
        flat_targets,
        flat_sources,
        output_gradients=output_gradients,
        sources_raw=flat_sources_raw,
        unconnected_gradients=unconnected_gradients)

    if not self._persistent:
      self._tape = None

    grad = nest.pack_sequence_as(sources, flat_grad)
    return grad
Example #29
def warn_first_n(msg, *args, **kwargs):
  logging.log_first_n(logging.WARNING, msg, *args, **kwargs)
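Note the argument order inherited from `log_first_n(level, msg, n, *args)`: the repetition count comes right after the message, before the format arguments. A hedged usage sketch, assuming the wrapper above is in scope:

```python
# Logged at most 3 times; 3 is the count, 1.25 fills the %.2f placeholder.
warn_first_n("step took %.2fs", 3, 1.25)
```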
Example #30
def dropin(x, rate, noise_shape=None, seed=None, name=None):
    """Computes dropin.
    With probability `rate`, drops in elements of `x`. Input that are dropped in are
    scaled up by `1 / (1 - rate)`, otherwise outputs `0`.    The scaling is so that
    the expected sum is unchanged.
    
    Args:
        x: A floating point tensor.
        rate: A scalar `Tensor` with the same type as x. The probability
            that each element is dropped. For example, setting rate=0.1 would drop
            10% of input elements.
        noise_shape: A 1-D `Tensor` of type `int32`, representing the
            shape for randomly generated keep/drop flags.
        seed: A Python integer. Used to create random seeds. See
            `tf.compat.v1.set_random_seed` for behavior.
        name: A name for this operation (optional).
    Returns:
        A Tensor of the same shape of `x`.
    Raises:
        ValueError: If `rate` is not in `(0, 1]` or if `x` is not a floating point
            tensor.
    """
    with ops.name_scope(name, "dropin", [x]) as name:
        x = ops.convert_to_tensor(x, name="x")
        if not x.dtype.is_floating:
            raise ValueError(
                "x has to be a floating point tensor since it's going to"
                " be scaled. Got a %s tensor instead." % x.dtype)
        if isinstance(rate, numbers.Real):
            if not (rate >= 0 and rate < 1):
                raise ValueError(
                    "rate must be a scalar tensor or a float in the "
                    "range [0, 1), got %g" % rate)
            if rate < 0.5:
                logging.log_first_n(
                    logging.WARN, "Low dropin rate: %g (<0.5). In TensorFlow "
                    "2.x, dropin() uses dropin rate instead of keep_prob. "
                    "Please ensure that this is intended.", 5, rate)

        # Early return if nothing needs to be dropped.
        if isinstance(rate, numbers.Real) and rate == 0:
            return x
        if context.executing_eagerly():
            if isinstance(rate, ops.EagerTensor):
                if rate.numpy() == 0:
                    return x
        else:
            rate = ops.convert_to_tensor(rate, dtype=x.dtype, name="rate")
            rate.get_shape().assert_is_compatible_with(tensor_shape.scalar())

            # Do nothing if we know rate == 0
            if tensor_util.constant_value(rate) == 0:
                return x

        noise_shape = _get_noise_shape(x, noise_shape)
        # Sample a uniform distribution on [0.0, 1.0) and select values less
        # than `rate` as the elements to keep.
        #
        # NOTE: Random uniform actually can only generate 2^23 floats on [1.0, 2.0)
        # and subtract 1.0.
        random_tensor = random_ops.random_uniform(noise_shape,
                                                  seed=seed,
                                                  dtype=x.dtype)

        scale = 1 / rate
        # NOTE: a generated float exactly equal to `rate` is not selected by the
        # `<` comparison below, so each element is kept with probability `rate`.
        keep_mask = random_tensor < rate
        ret = x * scale * math_ops.cast(keep_mask, x.dtype)
        if not context.executing_eagerly():
            ret.set_shape(x.get_shape())

        return ret
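A usage sketch for `dropin` as defined above (assuming its module-level imports are available). With `rate=0.2`, roughly 20% of the elements are kept and scaled by `1 / 0.2 = 5`, and the low-rate warning fires because `rate < 0.5`:

```python
import tensorflow as tf

x = tf.ones([4, 4])
y = dropin(x, rate=0.2)   # ~20% of entries become 5.0, the rest 0.0
```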
Example #31
def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    global_collection_variables = ops.get_collection(
        ops.GraphKeys.GLOBAL_VARIABLES)
    local_collection_variables = ops.get_collection(
        ops.GraphKeys.LOCAL_VARIABLES)
    existing_captures = {id(c) for c in graph.internal_captures}
    lifted_variables = {}

    def _should_lift_variable(v):
      return ((v._in_graph_mode  # pylint: disable=protected-access
               and v.graph.building_function)
              and isinstance(v, resource_variable_ops.BaseResourceVariable)
              and id(v.handle) not in existing_captures)

    for old_variable in global_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))

    for old_variable in local_collection_variables:
      if _should_lift_variable(old_variable):
        new_variable = _lift_single_variable(
            old_variable, graph, variable_holder)
        lifted_variables[id(old_variable)] = new_variable
        existing_captures.add(id(old_variable.handle))
        if new_variable._in_graph_mode:  # pylint: disable=protected-access
          outer_graph = new_variable.graph
          # Variables are added to the global collection by default. In this
          # case we only want the variable in the local collection, so we'll pop
          # it out.
          global_collection = outer_graph.get_collection_ref(
              ops.GraphKeys.GLOBAL_VARIABLES)
          global_collection.remove(new_variable)
          outer_graph.add_to_collection(
              ops.GraphKeys.LOCAL_VARIABLES, new_variable)

    # Update the FuncGraph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(id(current), current)
        if not resource_variable_ops.is_resource_variable(
            mutable_collection[index]):
          logging.log_first_n(
              logging.WARN,
              "Unable to create a python object for variable {} because it is "
              "a reference variable. It may not be visible to training APIs. "
              "If this is a problem, consider rebuilding the SavedModel after "
              "running tf.compat.v1.enable_resource_variables().".format(
                  mutable_collection[index]),
              5)
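The collection update at the end of _lift_unlifted_variables swaps each entry of a mutable collection for its lifted counterpart, keyed by id(). A self-contained sketch of that swap-by-id pattern follows; the names (swap_lifted, old_to_new) are illustrative and not part of the TensorFlow API.

def swap_lifted(collection, old_to_new):
    # Replace entries in place when a lifted counterpart exists; leave the
    # rest untouched, so running the swap again is a no-op.
    for index, current in enumerate(collection):
        collection[index] = old_to_new.get(id(current), current)

old_a, old_b, new_a = object(), object(), object()
collection = [old_a, old_b]
swap_lifted(collection, {id(old_a): new_a})
assert collection == [new_a, old_b]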
Example #32
0
    def __init__(self,
                 max_tokens=None,
                 num_oov_indices=1,
                 mask_token=None,
                 oov_token=-1,
                 vocabulary=None,
                 idf_weights=None,
                 invert=False,
                 output_mode="int",
                 sparse=False,
                 pad_to_max_tokens=False,
                 **kwargs):
        allowed_dtypes = [tf.int64]

        # Support deprecated args for this layer.
        if "max_values" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "max_values is deprecated, use max_tokens instead.", 1)
            max_tokens = kwargs["max_values"]
            del kwargs["max_values"]
        if "mask_value" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "mask_value is deprecated, use mask_token instead.", 1)
            mask_token = kwargs["mask_value"]
            del kwargs["mask_value"]
        if "oov_value" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "oov_value is deprecated, use oov_token instead.", 1)
            oov_token = kwargs["oov_value"]
            del kwargs["oov_value"]

        if "dtype" in kwargs and kwargs["dtype"] not in allowed_dtypes:
            raise ValueError(
                "The value of the dtype argument for IntegerLookup may "
                "only be one of %s." % (allowed_dtypes, ))

        if "dtype" not in kwargs:
            kwargs["dtype"] = tf.int64

        # If max_tokens is set, the value must be greater than 1 - otherwise we
        # are creating a 0-element vocab, which doesn't make sense.
        if max_tokens is not None and max_tokens <= 1:
            raise ValueError(
                f"If `max_tokens` is set for `IntegerLookup`, it must be "
                f"greater than 1. Received: max_tokens={max_tokens}.")

        if num_oov_indices < 0:
            raise ValueError(
                f"The value of `num_oov_indices` argument for `IntegerLookup` "
                f"must >= 0. Received num_oov_indices="
                f"{num_oov_indices}.")

        # Make sure mask and oov are of the dtype we want.
        mask_token = None if mask_token is None else np.int64(mask_token)
        oov_token = None if oov_token is None else np.int64(oov_token)

        super(IntegerLookup,
              self).__init__(max_tokens=max_tokens,
                             num_oov_indices=num_oov_indices,
                             mask_token=mask_token,
                             oov_token=oov_token,
                             vocabulary=vocabulary,
                             idf_weights=idf_weights,
                             invert=invert,
                             output_mode=output_mode,
                             sparse=sparse,
                             pad_to_max_tokens=pad_to_max_tokens,
                             **kwargs)
        base_preprocessing_layer.keras_kpl_gauge.get_cell("IntegerLookup").set(
            True)
Example #33
0
 def LogErrorOnce(msg):
   logging.log_first_n(logging.ERROR, msg, 1)
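log_first_n here is the absl/tf.compat.v1 logging helper with signature log_first_n(level, msg, n, *args); it counts per call site, so wrapping it with n=1 gives a log-once helper. A hedged usage sketch, assuming absl logging as the backing module (the message text is illustrative):

from absl import logging

def LogErrorOnce(msg):
    logging.log_first_n(logging.ERROR, msg, 1)

for attempt in range(10):
    # Only the first iteration produces an ERROR line; the other nine are
    # suppressed because the call site's counter has already reached n=1.
    LogErrorOnce("backend unavailable, falling back to CPU")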
Example #34
0
 def disable_partitioned_variables(getter, *args, **kwargs):
   if kwargs.pop("partitioner", None) is not None:
     tf_logging.log_first_n(
         tf_logging.WARN, "Partitioned variables are disabled when using "
         "DistributionStrategy.", 1)
   return getter(*args, **kwargs)
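disable_partitioned_variables is written as a custom getter: it pops any partitioner kwarg before delegating to the wrapped getter. A hedged sketch of wiring such a getter into tf.compat.v1.variable_scope; the scope and variable names are illustrative, the warning is omitted for brevity, and the example runs in graph mode for simplicity.

import tensorflow as tf

def disable_partitioned_variables(getter, *args, **kwargs):
    kwargs.pop("partitioner", None)  # drop the partitioner before delegating
    return getter(*args, **kwargs)

with tf.Graph().as_default():
    with tf.compat.v1.variable_scope(
            "tower", custom_getter=disable_partitioned_variables):
        # The requested partitioner is discarded by the getter, so a single
        # unpartitioned variable is created.
        v = tf.compat.v1.get_variable(
            "w", shape=[4, 4],
            partitioner=tf.compat.v1.fixed_size_partitioner(2))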
Example #35
0
 def _use_merge_call(self):
     logging.log_first_n(
         logging.WARN, "XLA is not supported for multi-worker "
         "strategy.", 1)
     return True
Example #36
0
    def __init__(self,
                 max_tokens=None,
                 num_oov_indices=1,
                 mask_token=None,
                 oov_token=-1,
                 vocabulary=None,
                 vocabulary_dtype="int64",
                 idf_weights=None,
                 invert=False,
                 output_mode="int",
                 sparse=False,
                 pad_to_max_tokens=False,
                 **kwargs):
        if not tf.dtypes.as_dtype(vocabulary_dtype).is_integer:
            raise ValueError("`vocabulary_dtype` must be an integer dtype. "
                             f"Received: {vocabulary_dtype}")

        # Legacy versions of the IntegerLookup layer set layer dtype to int64,
        # instead of the output type. If we see this and output mode is not "int",
        # clear the setting so we don't switch types for old SavedModels.
        if output_mode != "int" and "dtype" in kwargs and (
                kwargs["dtype"] == tf.int64 or kwargs["dtype"] == "int64"):
            del kwargs["dtype"]

        # Support deprecated args for this layer.
        if "max_values" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "max_values is deprecated, use max_tokens instead.", 1)
            max_tokens = kwargs["max_values"]
            del kwargs["max_values"]
        if "mask_value" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "mask_value is deprecated, use mask_token instead.", 1)
            mask_token = kwargs["mask_value"]
            del kwargs["mask_value"]
        if "oov_value" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "oov_value is deprecated, use oov_token instead.", 1)
            oov_token = kwargs["oov_value"]
            del kwargs["oov_value"]

        # If max_tokens is set, the value must be greater than 1 - otherwise we
        # are creating a 0-element vocab, which doesn't make sense.
        if max_tokens is not None and max_tokens <= 1:
            raise ValueError(
                f"If `max_tokens` is set for `IntegerLookup`, it must be "
                f"greater than 1. Received: max_tokens={max_tokens}.")

        if num_oov_indices < 0:
            raise ValueError(
                f"The value of `num_oov_indices` argument for `IntegerLookup` "
                f"must >= 0. Received num_oov_indices="
                f"{num_oov_indices}.")

        # Make sure mask and oov are of the dtype we want.
        mask_token = None if mask_token is None else np.int64(mask_token)
        oov_token = None if oov_token is None else np.int64(oov_token)

        super().__init__(max_tokens=max_tokens,
                         num_oov_indices=num_oov_indices,
                         mask_token=mask_token,
                         oov_token=oov_token,
                         vocabulary=vocabulary,
                         vocabulary_dtype=vocabulary_dtype,
                         idf_weights=idf_weights,
                         invert=invert,
                         output_mode=output_mode,
                         sparse=sparse,
                         pad_to_max_tokens=pad_to_max_tokens,
                         **kwargs)
        base_preprocessing_layer.keras_kpl_gauge.get_cell("IntegerLookup").set(
            True)
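The three deprecation branches above share one pattern: warn once, copy the old kwarg onto the new argument, then delete it so the base class never sees it. Below is a generic sketch of that pattern; the helper name and the _RENAMES map are my own, and note that because log_first_n counts per call site this single helper warns at most once in total, unlike the three separate calls above.

from absl import logging

_RENAMES = {
    "max_values": "max_tokens",
    "mask_value": "mask_token",
    "oov_value": "oov_token",
}

def pop_renamed_kwargs(kwargs, current_args):
    # Move deprecated kwargs onto their new names, warning on first use.
    for old_name, new_name in _RENAMES.items():
        if old_name in kwargs:
            logging.log_first_n(
                logging.WARN, "%s is deprecated, use %s instead.", 1,
                old_name, new_name)
            current_args[new_name] = kwargs.pop(old_name)
    return current_args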
Example #37
0
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None,
           fn_output_signature=None):
    """Transforms `elems` by applying `fn` to each element unstacked on axis 0.

  See also `tf.scan`.

  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
  calls `fn` to transform each element; and then stacks the transformed
  values back together.

  #### Mapping functions with single-Tensor inputs and outputs

  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
  then `map_fn(fn, elems)` is equivalent to
  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`.  E.g.:

  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.

  #### Mapping functions with multi-arity inputs and outputs

  `map_fn` also supports functions with multi-arity inputs and outputs:

  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
    must all have the same outer-dimension size (`num_elems`); and `fn` is
    used to transform each tuple (or structure) of corresponding slices from
    `elems`.  E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
    transform each tuple of slices `(t1[i], t2[i], t3[i])`
    (where `0 <= i < num_elems`).

  * If `fn` returns a tuple (or nested structure) of tensors, then the
    result is formed by stacking corresponding elements from those structures.

  #### Specifying `fn`'s output signature

  If `fn`'s input and output signatures are different, then the output
  signature must be specified using `fn_output_signature`.  (The input and
  output signatures differ if their structures, dtypes, or tensor types do
  not match).  E.g.:

  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
  ...           elems=tf.constant(["hello", "moon"]),
  ...           fn_output_signature=tf.int32)
  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
  ...           fn_output_signature=tf.string)
  <tf.Tensor: shape=(2,), dtype=string,
   numpy=array([b'TheDog', b'ACat'], dtype=object)>

  `fn_output_signature` can be specified using any of the following:

  * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
  * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
  * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
  * A (possibly nested) tuple, list, or dict containing the above types.

  #### RaggedTensors

  `map_fn` supports `tf.RaggedTensor` inputs and outputs.  In particular:

  * If `elems` is a `RaggedTensor`, then `fn` will be called with each
    row of that ragged tensor.
    * If `elems` has only one ragged dimension, then the values passed to
      `fn` will be `tf.Tensor`s.
    * If `elems` has multiple ragged dimensions, then the values passed to
      `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.

  * If the result of `map_fn` should be a `RaggedTensor`, then use a
    `tf.RaggedTensorSpec` to specify `fn_output_signature`.
    * If `fn` returns `tf.Tensor`s with varying sizes, then use a
      `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
      single ragged tensor (which will have ragged_rank=1).
    * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
      with the same `ragged_rank`.

  >>> # Example: RaggedTensor input
  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>

  >>> # Example: RaggedTensor output
  >>> elems = tf.constant([3, 5, 0, 2])
  >>> tf.map_fn(tf.range, elems,
  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
  ...                                                   dtype=tf.int32))
  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `RaggedTensor`.  If you wish to map a function over the
  individual values, then you should use:

  * `tf.ragged.map_flat_values(fn, rt)`
    (if fn is expressible as TensorFlow ops)
  * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
    (otherwise)

  E.g.:

  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>

  #### SparseTensors

  `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs.  In particular:

  * If `elems` is a `SparseTensor`, then `fn` will be called with each row
    of that sparse tensor. In particular, the value passed to `fn` will be a
    `tf.sparse.SparseTensor` with one fewer dimension than `elems`.

  * If the result of `map_fn` should be a `SparseTensor`, then use a
    `tf.SparseTensorSpec` to specify `fn_output_signature`.  The individual
    `SparseTensor`s returned by `fn` will be stacked into a single
    `SparseTensor` with one more dimension.

  >>> # Example: SparseTensor input
  >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>

  >>> # Example: SparseTensor output
  >>> tf.sparse.to_dense(
  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
    array([[[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]],
           [[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]]], dtype=float32)>

  Note: `map_fn` should only be used if you need to map a function over the
  *rows* of a `SparseTensor`.  If you wish to map a function over the nonzero
  values, then you should use:

  * If the function is expressible as TensorFlow ops, use:
    ```python
    tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
    ```
  * Otherwise, use:
    ```python
    tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values),
                           st.dense_shape)
    ```

  #### `map_fn` vs. vectorized operations

  `map_fn` will apply the operations used by `fn` to each element of `elems`,
  resulting in `O(elems.shape[0])` total operations.  This is somewhat
  mitigated by the fact that `map_fn` can process elements in parallel.
  However, a transform expressed using `map_fn` is still typically less
  efficient than an equivalent transform expressed using vectorized operations.

  `map_fn` should typically only be used if one of the following is true:

  * It is difficult or expensive to express the desired transform with
    vectorized operations.
  * `fn` creates large intermediate values, so an equivalent vectorized
    transform would take too much memory.
  * Processing elements in parallel is more efficient than an equivalent
    vectorized transform.
  * Efficiency of the transform is not critical, and using `map_fn` is
    more readable.

  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
  across `elems` could be rewritten more efficiently using vectorized ops:

  >>> elems = tf.constant([3, 5, 2])
  >>> tf.range(3) + tf.expand_dims(elems, 1)
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>

  In some cases, `tf.vectorized_map` can be used to automatically convert a
  function to a vectorized equivalent.

  #### Eager execution

  When executing eagerly, `map_fn` does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.function` decorator:

  >>> fn=lambda t: tf.range(t, t + 3)
  >>> @tf.function
  ... def func(elems):
  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
  >>> func(tf.constant([3, 5, 2]))
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[3, 4, 5],
           [5, 6, 7],
           [2, 3, 4]], dtype=int32)>


  Note: if you use the `tf.function` decorator, any non-TensorFlow Python
  code that you may have written in your function won't get executed. See
  `tf.function` for more  details. The recommendation would be to debug without
  `tf.function` but switch to it to get performance benefits of running `map_fn`
  in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will have
      the same (possibly nested) structure as `elems`.  Its output must have the
      same structure as `fn_output_signature` if one is provided; otherwise it
      must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unstacked along their first dimension.  `fn` will be applied to the
      nested sequence of the resulting slices.  `elems` may include ragged and
      sparse tensors. `elems` must consist of at least one tensor.
    dtype: Deprecated: Equivalent to `fn_output_signature`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) False disables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.
    fn_output_signature: The output signature of `fn`. Must be specified if
      `fn`'s input and output signatures are different (i.e., if their
      structures, dtypes, or tensor types do not match).
      `fn_output_signature` can be specified using any of the following:

      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
      * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`)
      * A (possibly nested) tuple, list, or dict containing the above types.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor stacks the
    results of applying `fn` to tensors unstacked from `elems` along the first
    dimension, from first to last.  The result may include ragged and sparse
    tensors.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `fn_output_signature` do not match.
    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
      do not match, or if the `elems` does not contain any tensor.

  Examples:

    >>> elems = np.array([1, 2, 3, 4, 5, 6])
    >>> tf.map_fn(lambda x: x * x, elems)
    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>

    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>

    >>> elems = np.array([1, 2, 3])
    >>> tf.map_fn(lambda x: (x, -x), elems,
    ...          fn_output_signature=(tf.int64, tf.int64))
    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
  """
    # This function uses a `while_loop` to call `fn` on each value of the input
    # tensor(s) (unstacked on dimension 0).  The following sequence of variables
    # are used to transform the input tensor(s) (`elems`) into the output
    # tensor(s) (`result`):
    #
    #   - Preparing and unstacking input values for the while_loop:
    #     - elems: The input tensor(s) to map_fn. May include composite tensors.
    #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
    #                   May include composite tensors.
    #     - elems_batchable: Concatenation of "batchable tensor lists" for each
    #                        tensor in elems_flat.  This "boxes" composite tensors
    #                        into sliceable tf.Tensor objects.  For more info see:
    #                        TensorSpec._to_batched_tensor_list
    #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
    #                           in elems_batchable into elems_value_batchable.
    #
    #   - Calling `fn` on each unstacked value in the body of the while_loop:
    #     - elems_value_batchable: Single unstacked value from elems_batchable.
    #     - elems_value_flat: Single unstacked value from elems_flat,
    #                         constructed from elems_value_batchable (using
    #                         TensorSpec._from_tensor_list).
    #     - elems_value: Single unstacked value from elems (the input to fn).
    #     - result_value: Result of calling `fn(elems_value)`.  May contain
    #                     composite tensors.
    #     - result_value_flat: Flattened list of tensors from result_value.
    #                          May contain composite tensors.
    #     - result_value_batchable: Concatenation of batchable tensor lists for
    #                               each tensor in result_value_flat
    #                               (using TensorSpec._to_tensor_list).
    #
    #   - Collecting and stacking output values from the while_loop:
    #     - result_batchable_ta: List of TensorArrays used to stack each tensor
    #                            in result_value_batchable into result_batchable.
    #     - result_batchable: Stacked tensors from result_batchable_ta.
    #     - result_flat: Flat list of tensors for the result, constructed from
    #                    result_batchable (using TensorSpec._from_tensor_list).
    #     - result: Structured result value packed from result_flat
    #               (using nest.pack_sequence_as).

    if fn_output_signature is None:
        fn_output_signature = dtype

    if not callable(fn):
        raise TypeError(f"The provided function {fn.__name__} is not callable."
                        "fn must be callable.")

    in_graph_mode = not context.executing_eagerly()
    # Set the default number of parallel_iterations depending on graph/eager mode.
    if in_graph_mode and not parallel_iterations:
        parallel_iterations = 10
    elif not in_graph_mode and not parallel_iterations:
        parallel_iterations = 1
    elif not in_graph_mode and parallel_iterations > 1:
        logging.log_first_n(
            logging.WARN, "Setting parallel_iterations > 1 has no "
            "effect when executing eagerly. Consider calling map_fn"
            " with tf.function to execute fn in "
            "parallel.", 1)
        parallel_iterations = 1

    # Flatten the input tensors, and get the TypeSpec for each one.
    elems_flat = nest.flatten(elems)

    # Check in case this is an empty list
    if len(elems_flat) == 0:
        raise ValueError(
            "elems must be a Tensor or (possibly nested) sequence of Tensors. "
            "Got {}, which does not contain any Tensors.".format(elems))

    elems_flat_signature = [
        type_spec.type_spec_from_value(e) for e in elems_flat
    ]
    elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)

    # Flatten fn's output signature.
    if fn_output_signature is None:
        # If fn_output_signature was not specified, then assume that it matches the
        # input signature.
        result_flat_signature = [
            _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
            for s in elems_flat_signature
        ]
        result_unflatten = elems_unflatten
    else:
        result_flat_signature = [
            _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
        ]
        result_unflatten = lambda x: nest.pack_sequence_as(
            fn_output_signature, x)

    with ops.name_scope(name, "map", elems_flat):
        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode:
            # Any get_variable calls in fn will cache the first call locally
            # and not issue repeated network I/O requests for each iteration.
            varscope = vs.get_variable_scope()
            varscope_caching_device_was_none = False
            if varscope.caching_device is None:
                # TODO(ebrevdo): Change to using colocate_with here and in other
                # methods.
                varscope.set_caching_device(lambda op: op.device)
                varscope_caching_device_was_none = True

        elems_flat = [
            ops.convert_to_tensor_or_composite(t, name="elem")
            for t in elems_flat
        ]

        # Check that inputs are not scalars.
        first_elem = elems_flat[0]
        if hasattr(first_elem, "shape"):
            elems_static_shape = first_elem.shape
            if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
                raise ValueError(
                    "Elements in elems must be 1+ dimensional Tensors, not scalars"
                )

        # Box any composite tensors into tensor lists.
        elems_batchable = _elems_flat_to_batchable(elems_flat)

        # Find the number of iterations, n.  (may be known statically.)
        n_static = tensor_shape.Dimension(
            tensor_shape.dimension_value(
                elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
        for tensor in elems_batchable[1:]:
            n_static.assert_is_compatible_with(
                tensor_shape.Dimension(
                    tensor_shape.dimension_value(
                        tensor.get_shape().with_rank_at_least(1)[0])))
        n = n_static.value or array_ops.shape(elems_batchable[0])[0]

        # Convert elems to tensor array.
        # TODO(edloper): Should we set infer_shape=False for composite tensors?
        elems_batchable_ta = [
            tensor_array_ops.TensorArray(dtype=t.dtype,
                                         size=n,
                                         dynamic_size=False,
                                         infer_shape=True)
            for t in elems_batchable
        ]
        # Unpack elements
        elems_batchable_ta = [
            ta.unstack(t)
            for (ta, t) in zip(elems_batchable_ta, elems_batchable)
        ]

        i = constant_op.constant(0)

        # Prepare result tensor array.
        # TODO(edloper): Should we set infer_shape=False for composite tensors?
        result_batchable_tensor_spec = (
            _result_flat_signature_to_batchable_tensor_spec(
                result_flat_signature))
        result_batchable_ta = []
        for spec in result_batchable_tensor_spec:
            result_batchable_ta.append(
                tensor_array_ops.TensorArray(dtype=spec.dtype,
                                             size=n,
                                             dynamic_size=False,
                                             infer_shape=infer_shape,
                                             element_shape=spec.shape))

        def compute(i, tas):
            """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if fn_output_signature and result_value structure don't match
        ValueError: if fn_output_signature and result_value lengths don't match
      """
            elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
            elems_value_flat = _elems_value_batchable_to_flat(
                elems_value_batchable, elems_flat_signature)
            elems_value = elems_unflatten(elems_value_flat)
            ag_ctx = autograph_ctx.control_status_ctx()
            autographed_fn = autograph.tf_convert(fn, ag_ctx)
            result_value = autographed_fn(elems_value)
            nest.assert_same_structure(fn_output_signature or elems,
                                       result_value)
            result_value_flat = nest.flatten(result_value)
            result_value_batchable = _result_value_flat_to_batchable(
                result_value_flat, result_flat_signature)
            tas = [
                ta.write(i, value)
                for (ta, value) in zip(tas, result_value_batchable)
            ]
            return (i + 1, tas)

        _, r_a = control_flow_ops.while_loop(
            lambda i, _: i < n,
            compute, (i, result_batchable_ta),
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            swap_memory=swap_memory,
            maximum_iterations=n)
        result_batchable = [r.stack() for r in r_a]

        # Update each output tensor w/ static shape info about the outer dimension.
        for r in result_batchable:
            r.set_shape(
                tensor_shape.TensorShape(n_static).concatenate(
                    r.get_shape()[1:]))

        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode and varscope_caching_device_was_none:
            varscope.set_caching_device(None)

        result_flat = _result_batchable_to_flat(result_batchable,
                                                result_flat_signature,
                                                n_static)
        result = result_unflatten(result_flat)
        return result
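Stripped of composite-tensor packing and output-signature handling, the core of the function above is a tf.while_loop that reads one slice per step from a TensorArray and writes the transformed slice into another. The following is a minimal sketch of that skeleton for a single dense input and output of the same dtype; tiny_map_fn is an illustrative name, not a TensorFlow API.

import tensorflow as tf

def tiny_map_fn(fn, elems, parallel_iterations=10):
    # Unstack the input along axis 0 into a TensorArray, apply fn element by
    # element inside a while_loop, and stack the results back together.
    n = tf.shape(elems)[0]
    elems_ta = tf.TensorArray(elems.dtype, size=n).unstack(elems)
    result_ta = tf.TensorArray(elems.dtype, size=n)

    def body(i, ta):
        return [i + 1, ta.write(i, fn(elems_ta.read(i)))]

    _, result_ta = tf.while_loop(
        lambda i, _: i < n, body, [tf.constant(0), result_ta],
        parallel_iterations=parallel_iterations)
    return result_ta.stack()

# e.g. tiny_map_fn(lambda t: t * t, tf.constant([1., 2., 3.]))  # [1., 4., 9.]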
Example #38
0
def map_fn(fn,
           elems,
           dtype=None,
           parallel_iterations=None,
           back_prop=True,
           swap_memory=False,
           infer_shape=True,
           name=None):
    """map on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the
  tensors unpacked from `elems`. `dtype` is the data type of the return
  value of `fn`. Users must provide `dtype` if it is different from
  the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension.  The signature of `fn` may
  match the structure of `elems`.  That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Furthermore, `fn` may emit a different structure than its input.  For example,
  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
  nested) tuple of types matching the output of `fn`.

  To apply a functional operation to the nonzero elements of a SparseTensor
  one of the following methods is recommended. First, if the function is
  expressible as TensorFlow ops, use

  ```python
    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
  ```

  If, however, the function is not expressible as a TensorFlow op, then use

  ```python
  result = SparseTensor(
    input.indices, map_fn(fn, input.values), input.dense_shape)
  ```

  instead.

  When executing eagerly, map_fn does not execute in parallel even if
  `parallel_iterations` is set to a value > 1. You can still get the
  performance benefits of running a function in parallel by using the
  `tf.contrib.eager.defun` decorator,

  ```python
  # Assume the function being used in map_fn is fn.
  # To ensure map_fn calls fn in parallel, use the defun decorator.
  @tf.contrib.eager.defun
  def func(tensor):
    return tf.map_fn(fn, tensor)
  ```

  Note that if you use the defun decorator, any non-TensorFlow Python code
  that you may have written in your function won't get executed. See
  `tf.contrib.eager.defun` for more details. The recommendation would be to
  debug without defun but switch to defun to get performance benefits of
  running map_fn in parallel.

  Args:
    fn: The callable to be performed.  It accepts one argument, which will
      have the same (possibly nested) structure as `elems`.  Its output
      must have the same structure as `dtype` if one is provided, otherwise
      it must have the same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which
      will be unpacked along their first dimension.  The nested sequence
      of the resulting slices will be applied to `fn`.
    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
      of Tensors differing from the structure of `elems`, then `dtype` is not
      optional and must have the same structure as the output of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel. When graph building, the default value is 10. While executing
      eagerly, the default value is set to 1.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, from first to last.

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `dtype` do not match, or if elems is a SparseTensor.
    ValueError: if the lengths of the output of `fn` and `dtype` do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```

    ```python
    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
    # alternate == [-1, 2, -3]
    ```

    ```python
    elems = np.array([1, 2, 3])
    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
    # alternates[0] == [1, 2, 3]
    # alternates[1] == [-1, -2, -3]
    ```
  """
    if not callable(fn):
        raise TypeError("fn must be callable.")

    if isinstance(elems, sparse_tensor.SparseTensor):
        raise TypeError(
            "To perform a map on the values of a sparse tensor use either "
            " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
            " SparseTensor(input.indices, map_fn(fn, input.values), "
            "input.dense_shape)")

    in_graph_mode = not context.executing_eagerly()
    # Set the default number of parallel_iterations depending on graph/eager mode.
    if in_graph_mode and not parallel_iterations:
        parallel_iterations = 10
    elif not in_graph_mode and not parallel_iterations:
        parallel_iterations = 1

    if not in_graph_mode and parallel_iterations > 1:
        logging.log_first_n(
            logging.WARN, "Setting parallel_iterations > 1 has no "
            "effect when executing eagerly. Consider calling map_fn"
            " with tf.contrib.eager.defun to execute fn in "
            "parallel.", 1)
        parallel_iterations = 1

    input_is_sequence = nest.is_sequence(elems)
    input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

    def input_pack(x):
        return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

    if dtype is None:
        output_is_sequence = input_is_sequence
        output_flatten = input_flatten
        output_pack = input_pack
    else:
        output_is_sequence = nest.is_sequence(dtype)
        output_flatten = lambda x: nest.flatten(
            x) if output_is_sequence else [x]

        def output_pack(x):
            return (nest.pack_sequence_as(dtype, x)
                    if output_is_sequence else x[0])

    elems_flat = input_flatten(elems)

    with ops.name_scope(name, "map", elems_flat):
        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode:
            # Any get_variable calls in fn will cache the first call locally
            # and not issue repeated network I/O requests for each iteration.
            varscope = vs.get_variable_scope()
            varscope_caching_device_was_none = False
            if varscope.caching_device is None:
                # TODO(ebrevdo): Change to using colocate_with here and in other
                # methods.
                varscope.set_caching_device(lambda op: op.device)
                varscope_caching_device_was_none = True

        elems_flat = [
            ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
        ]

        dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
        dtype_flat = output_flatten(dtype)

        # Convert elems to tensor array. n may be known statically.
        static_shape = elems_flat[0].shape
        if static_shape.ndims is not None and static_shape.ndims < 1:
            if len(elems_flat) == 1:
                raise ValueError(
                    "elems must be a 1+ dimensional Tensor, not a scalar")
            else:
                raise ValueError(
                    "elements in elems must be 1+ dimensional Tensors, not scalars"
                )
        n = static_shape[0].value or array_ops.shape(elems_flat[0])[0]

        # TensorArrays are always flat
        elems_ta = [
            tensor_array_ops.TensorArray(dtype=elem.dtype,
                                         size=n,
                                         dynamic_size=False,
                                         infer_shape=True)
            for elem in elems_flat
        ]
        # Unpack elements
        elems_ta = [
            elem_ta.unstack(elem)
            for elem_ta, elem in zip(elems_ta, elems_flat)
        ]

        i = constant_op.constant(0)

        accs_ta = [
            tensor_array_ops.TensorArray(dtype=dt,
                                         size=n,
                                         dynamic_size=False,
                                         infer_shape=infer_shape)
            for dt in dtype_flat
        ]

        def compute(i, tas):
            """The loop body of map_fn.

      Args:
        i: the loop counter
        tas: the flat TensorArray accumulator list

      Returns:
        (i + 1, tas): the updated counter + updated TensorArrays

      Raises:
        TypeError: if dtype and packed_fn_values structure do not match
        ValueError: if dtype and packed_fn_values lengths do not match
      """
            packed_values = input_pack(
                [elem_ta.read(i) for elem_ta in elems_ta])
            packed_fn_values = fn(packed_values)
            nest.assert_same_structure(dtype or elems, packed_fn_values)
            flat_fn_values = output_flatten(packed_fn_values)
            tas = [
                ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)
            ]
            return (i + 1, tas)

        _, r_a = control_flow_ops.while_loop(
            lambda i, _: i < n,
            compute, (i, accs_ta),
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            swap_memory=swap_memory,
            maximum_iterations=n)
        results_flat = [r.stack() for r in r_a]

        n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0]
        for elem in elems_flat[1:]:
            n_static.merge_with(elem.get_shape().with_rank_at_least(1)[0])
        for r in results_flat:
            r.set_shape(
                tensor_shape.TensorShape(n_static).concatenate(
                    r.get_shape()[1:]))

        # TODO(akshayka): Remove the in_graph_mode check once caching devices are
        # supported in Eager
        if in_graph_mode and varscope_caching_device_was_none:
            varscope.set_caching_device(None)

        return output_pack(results_flat)
Example #39
0
def warn_first_n(msg, *args, **kwargs):
    logging.log_first_n(logging.WARNING, msg, *args, **kwargs)
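A hedged usage sketch of the wrapper above, assuming absl logging as the backing module: log_first_n keys its counter on the call site, so a warn_first_n call inside a hot loop emits its message only for the first n iterations (here n=1). The message text is illustrative.

from absl import logging

def warn_first_n(msg, *args, **kwargs):
    logging.log_first_n(logging.WARNING, msg, *args, **kwargs)

for step in range(1000):
    # Emitted once; later iterations are suppressed because the counter for
    # the log_first_n call site inside warn_first_n has already reached 1.
    warn_first_n("checkpoint hook is slow (step %d)", 1, step)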
Example #40
0
    def __init__(self,
                 max_tokens=None,
                 num_oov_indices=1,
                 mask_token=None,
                 oov_token=-1,
                 vocabulary=None,
                 invert=False,
                 output_mode=index_lookup.INT,
                 sparse=False,
                 pad_to_max_tokens=False,
                 **kwargs):
        allowed_dtypes = [tf.int64]

        # Support deprecated args for this layer.
        if "max_values" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "max_values is deprecated, use max_tokens instead.", 1)
            max_tokens = kwargs["max_values"]
            del kwargs["max_values"]
        if "mask_value" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "mask_value is deprecated, use mask_token instead.", 1)
            mask_token = kwargs["mask_value"]
            del kwargs["mask_value"]
        if "oov_value" in kwargs:
            logging.log_first_n(
                logging.WARN,
                "oov_value is deprecated, use oov_token instead.", 1)
            oov_token = kwargs["oov_value"]
            del kwargs["oov_value"]

        if "dtype" in kwargs and kwargs["dtype"] not in allowed_dtypes:
            raise ValueError(
                "The value of the dtype argument for IntegerLookup may "
                "only be one of %s." % (allowed_dtypes, ))

        if "dtype" not in kwargs:
            kwargs["dtype"] = tf.int64

        # If max_tokens is set, the value must be greater than 1 - otherwise we
        # are creating a 0-element vocab, which doesn't make sense.
        if max_tokens is not None and max_tokens <= 1:
            raise ValueError("If set, max_tokens must be greater than 1. "
                             "You passed %s" % (max_tokens, ))

        if num_oov_indices < 0:
            raise ValueError(
                "num_oov_indices must be greater than or equal to 0. You passed %s"
                % (num_oov_indices, ))

        super(IntegerLookup,
              self).__init__(max_tokens=max_tokens,
                             num_oov_indices=num_oov_indices,
                             mask_token=mask_token,
                             oov_token=oov_token,
                             vocabulary=vocabulary,
                             invert=invert,
                             output_mode=output_mode,
                             sparse=sparse,
                             pad_to_max_tokens=pad_to_max_tokens,
                             **kwargs)
        base_preprocessing_layer.keras_kpl_gauge.get_cell("IntegerLookup").set(
            True)
Example #41
0
def is_whitelisted_for_graph(o):
    """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  Args:
    o: A Python entity.
  Returns:
    Boolean
  """
    # TODO(b/120224672): Fix this.
    if isinstance(o, functools.partial):
        # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
        # functools.partial objects do not have a __module__ attribute.
        m = functools
    else:
        m = tf_inspect.getmodule(o)
    if not hasattr(m, '__name__'):
        logging.vlog(1, '%s is NOT whitelisted for graph: unknown module name',
                     o)
        return False

    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
        if m.__name__.startswith(prefix):
            logging.vlog(1, '%s is whitelisted: name starts with "%s"', o,
                         prefix)
            return True

    if hasattr(o, 'autograph_info__'):
        return True

    if (not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o)
            and hasattr(o, '__call__') and hasattr(o, '__class__')):
        # Callable objects: whitelisted if their __call__ method is.
        retval = is_whitelisted_for_graph(o.__call__)
        logging.vlog(1, '%s is whitelisted: object __call__ whitelisted', o)
        return retval

    if tf_inspect.ismethod(o):
        # Methods of whitelisted classes are also whitelisted, even if they are
        # bound via user subclasses.
        #
        # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
        # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
        # whitelisted.
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # For the example above, if `Custom` did overload `bar`, then it would no
        # longer be whitelisted.

        owner_class = inspect_utils.getmethodclass(o)
        if owner_class is not None:
            owner_class = inspect_utils.getdefiningclass(o, owner_class)
            if is_whitelisted_for_graph(owner_class):
                logging.vlog(1, '%s is whitelisted: owner is whitelisted %s',
                             o, owner_class)
                return True

    if inspect_utils.isnamedtuple(o):
        # Due to the way they're constructed, namedtuple types cannot be converted
        # because they don't expose source code. But we assume they are safe for
        # graph mode since they are just containers.
        if tf_inspect.isclass(o) and len(o.__bases__) > 1:
            logging.log_first_n(
                logging.level_warning(),
                'Entity {} looks like a namedtuple subclass. If it has any custom'
                ' methods, they will not be converted by AutoGraph.'.format(o),
                1)
        logging.vlog(1, '%s is whitelisted: named tuple', o)
        return True

    logging.vlog(1, '%s is NOT whitelisted for graph', o)
    return False
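The module check above iterates `for prefix, in config.DEFAULT_UNCOMPILED_MODULES`: the trailing comma unpacks one-element tuples, so the config holds (module_name,) tuples rather than bare strings. A minimal sketch of that prefix test follows; the module list is illustrative, not AutoGraph's real configuration.

DEFAULT_UNCOMPILED_MODULES = {("tensorflow",), ("numpy",)}

def is_module_whitelisted(module_name):
    # `for prefix, in ...` unpacks each single-element tuple into `prefix`.
    return any(module_name.startswith(prefix)
               for prefix, in DEFAULT_UNCOMPILED_MODULES)

assert is_module_whitelisted("tensorflow.python.ops")
assert not is_module_whitelisted("my_project.models")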
Example #42
0
    def gradient(self,
                 target,
                 sources,
                 output_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
        """Computes the gradient using operations recorded in context of this tape.

    Args:
      target: Tensor (or list of tensors) to be differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
       than once on a non-persistent tape.
      ValueError: if the target is a variable or if unconnected gradients is
       called with an unknown value.
    """
        if self._tape is None:
            raise RuntimeError(
                "GradientTape.gradient can only be called once on "
                "non-persistent tapes.")
        if self._recording:
            if not self._persistent:
                self._pop_tape()
            else:
                logging.log_first_n(
                    logging.WARN,
                    "Calling GradientTape.gradient on a persistent "
                    "tape inside its context is significantly less "
                    "efficient than calling it outside the context (it "
                    "causes the gradient ops to be recorded on the "
                    "tape, leading to increased CPU and memory usage). "
                    "Only call GradientTape.gradient inside the "
                    "context if you actually want to trace the "
                    "gradient in order to compute higher order "
                    "derivatives.", 1)

        flat_targets = []
        for t in nest.flatten(target):
            if resource_variable_ops.is_resource_variable(t):
                with self:
                    t = ops.convert_to_tensor(t)
            flat_targets.append(t)

        flat_sources = nest.flatten(sources)
        flat_sources = [_handle_or_self(x) for x in flat_sources]

        if output_gradients is not None:
            output_gradients = [
                None if x is None else ops.convert_to_tensor(x)
                for x in nest.flatten(output_gradients)
            ]

        flat_grad = imperative_grad.imperative_grad(
            self._tape,
            flat_targets,
            flat_sources,
            output_gradients=output_gradients,
            unconnected_gradients=unconnected_gradients)

        if not self._persistent:
            self._tape = None

        grad = nest.pack_sequence_as(sources, flat_grad)
        return grad
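A hedged usage sketch of the behaviour the warning above refers to: with a persistent tape, gradient() can be called several times, and calling it outside the `with` block avoids recording the gradient computation itself onto the tape. The variable names are illustrative.

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)          # x is a constant, so it must be watched explicitly
    y = x * x
    z = y * y

# Outside the context: the gradient ops are not traced onto the tape.
dy_dx = tape.gradient(y, x)   # 2 * x    -> 6.0
dz_dx = tape.gradient(z, x)   # 4 * x**3 -> 108.0
del tape                      # release the persistent tape's resources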
Example #43
0
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  Args:
    o: A Python entity.
  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)
  if not hasattr(m, '__name__'):
    logging.vlog(1, '%s is NOT whitelisted for graph: unknown module name', o)
    return False

  for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
    if m.__name__.startswith(prefix):
      logging.vlog(1, '%s is whitelisted: name starts with "%s"', o, prefix)
      return True

  if hasattr(o, 'autograph_info__'):
    return True

  if (not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o) and
      hasattr(o, '__call__') and hasattr(o, '__class__')):
    # Callable objects: whitelisted if their __call__ method is.
    retval = is_whitelisted_for_graph(o.__call__)
    logging.vlog(1, '%s is whitelisted: object __call__ whitelisted', o)
    return retval

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.vlog(1, '%s is whitelisted: owner is whitelisted %s', o,
                     owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.log_first_n(
          logging.level_warning(),
          'Entity {} looks like a namedtuple subclass. If it has any custom'
          ' methods, they will not be converted by AutoGraph.'.format(o), 1)
    logging.vlog(1, '%s is whitelisted: named tuple', o)
    return True

  logging.vlog(1, '%s is NOT whitelisted for graph', o)
  return False