示例#1
0
def _map_resources(accessible_objects):
  """Builds graph-side replacements for the resource variables in a list.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the C++
  loader API to interact with variables.

  Only elements accepted by `resource_variable_ops.is_resource_variable` are
  processed; every other element is ignored.

  Args:
    accessible_objects: A list of objects, some of which may be resource
      variables, to create replacements for.

  Returns:
    A tuple of (object_map, resource_map):
      object_map: A dictionary mapping each resource variable found in
        `accessible_objects` to the uninitialized copy created for it.
      resource_map: A dictionary mapping the `handle` of each such variable
        to the `handle` of its copy.
  """
  # TODO(allenl, rohanj): Map generic resources rather than just variables.
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = {}
  resource_map = {}
  for candidate in accessible_objects:
    if not resource_variable_ops.is_resource_variable(candidate):
      continue
    replacement = resource_variable_ops.copy_to_graph_uninitialized(candidate)
    object_map[candidate] = replacement
    resource_map[candidate.handle] = replacement.handle
  return object_map, resource_map
  def __init__(self, var, slice_spec, name):
    """Sets up a saveable for a resource variable or a tensor read from one.

    Args:
      var: Either an `ops.Tensor`, in which case the first input of its
        producing op is recorded as the handle, or a resource variable.
      slice_spec: Forwarded unchanged to `saveable_object.SaveSpec`.
      name: Forwarded to `saveable_object.SaveSpec` and to the parent
        constructor.

    Raises:
      ValueError: If `var` is neither an `ops.Tensor` nor a resource variable.
    """
    self._var_device = var.device
    self._var_shape = var.shape

    def _cpu_reader_for(variable):
      """Builds a deferred read of `variable` whose result ends on CPU."""

      def _read_on_cpu():
        with ops.device(variable.device):
          value = variable.read_value()
          # To allow variables placed on non-CPU devices to be checkpointed,
          # we copy them to CPU on the same machine first.
          with ops.device("/device:CPU:0"):
            return array_ops.identity(value)

      return _read_on_cpu

    if isinstance(var, ops.Tensor):
      handle, tensor = var.op.inputs[0], var
    elif resource_variable_ops.is_resource_variable(var):
      # The read is wrapped in a callable rather than performed here.
      handle, tensor = var.handle, _cpu_reader_for(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    self.handle_op = handle

    spec = saveable_object.SaveSpec(
        tensor, slice_spec, name, dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)
示例#3
0
def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto.

  Exactly one field of `proto` is populated, chosen by the type of `obj`:
  asset, variable, function, bare concrete function, constant, resource, or
  (as the fallback) user object.

  Args:
    obj: The object to describe.
    proto: The SavedObject message to fill in place.
    asset_file_def_index: A mapping indexed by `obj`; consulted only when
      `obj` is a `tracking.TrackableAsset`.

  Raises:
    ValueError: If `obj` is a resource variable whose name does not end
      with ":0".
  """
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      # The message has a `%s` placeholder; interpolate the variable name so
      # the error actually identifies the offending variable.
      raise ValueError("Cowardly refusing to save variable %s because of"
                       " unexpected suffix which won't be restored."
                       % (obj.name,))
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
示例#4
0
 def _get_tensor_from_node(self, node_id):
   """Returns the tensor that stands for the node stored at `node_id`.

   Resource variables yield their `handle`; `tracking.TrackableAsset` nodes
   yield `asset_path.handle`.

   Raises:
     ValueError: If the node is of neither kind.
   """
   node = self._nodes[node_id]
   if resource_variable_ops.is_resource_variable(node):
     return node.handle
   if isinstance(node, tracking.TrackableAsset):
     return node.asset_path.handle
   raise ValueError("Can't convert node %s to tensor" % (type(node)))
示例#5
0
def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Writes a serialized SavedObjectGraph proto under `export_dir`.

  Args:
    saveable_view: Provides `nodes` and `fill_object_graph_proto`.
    export_dir: Directory under which the extra-assets directory is created.
    asset_file_def_index: Forwarded to `_write_object_proto` for every node.
  """
  # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  # Index every node, plus the handle tensor of variables and assets, by the
  # node's position in `saveable_view.nodes`.
  node_ids = util.ObjectIdentityDictionary()
  for index, node in enumerate(saveable_view.nodes):
    node_ids[node] = index
    if resource_variable_ops.is_resource_variable(node):
      node_ids[node.handle] = index
    elif isinstance(node, tracking.TrackableAsset):
      node_ids[node.asset_path.handle] = index

  for node, node_proto in zip(saveable_view.nodes, proto.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  serialized_graph = proto.SerializeToString()
  file_io.write_string_to_file(object_graph_filename, serialized_graph)
示例#6
0
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
  """Helper function for creating a slot variable.

  The partitioner of the current variable scope is cleared while the slot is
  created and is always put back afterwards, even if variable creation raises.

  Args:
    primary: The object the slot belongs to. Decides `use_resource` and, when
      it is a `variables.Variable` carrying save-slice info, the slice info
      given to the slot.
    val: Initializer passed to `variable_scope.get_variable`; either a value
      or a callable.
    scope: Passed as the first argument of `variable_scope.get_variable`.
    validate_shape: Forwarded to `variable_scope.get_variable`.
    shape: Shape of the slot; only used when `val` is callable.
    dtype: Forwarded to `variable_scope.get_variable`.

  Returns:
    The newly created slot variable.
  """

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  try:
    # When init from val instead of callable initializer, the shape is expected
    # to be None, not <unknown> or any fully defined shape.
    shape = shape if callable(val) else None
    slot = variable_scope.get_variable(
        scope, initializer=val, trainable=False,
        use_resource=resource_variable_ops.is_resource_variable(primary),
        shape=shape, dtype=dtype,
        validate_shape=validate_shape)
  finally:
    # Restore on every exit path so a failed creation does not leave the
    # enclosing scope with its partitioner cleared.
    variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable.  Slots have the same partitioning
    # as their primaries.
    # For examples when using AdamOptimizer in linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weight". We want to get 'Adam' as real_slot_name, so we
    # remove "'linear//weight' + '/'" and ':0'.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
        slice_info.full_name + "/" + real_slot_name,
        slice_info.full_shape[:],
        slice_info.var_offset[:],
        slice_info.var_shape[:]))
  # pylint: enable=protected-access
  return slot
示例#7
0
def _write_object_proto(obj, proto, asset_file_def_index, node_ids):
  """Fills `proto` with the SavedObject description of `obj`.

  Args:
    obj: The object to describe.
    proto: The SavedObject message to fill in place.
    asset_file_def_index: A mapping indexed by `obj`; consulted only when
      `obj` is a `tracking.TrackableAsset`.
    node_ids: Forwarded to the function serialization helpers.
  """
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
    return

  if resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
    return

  if isinstance(obj, def_function.Function):
    serialized = function_serialization.serialize_function(obj, node_ids)
    proto.function.CopyFrom(serialized)
    return

  if isinstance(obj, defun.ConcreteFunction):
    serialized = function_serialization.serialize_concrete_function(
        obj, node_ids)
    proto.concrete_function.CopyFrom(serialized)
    return

  # Anything else is stored as a user object.
  user_object_proto = revived_types.serialize(obj)
  if user_object_proto is None:
    # Fallback for types with no matching registration
    fallback_version = versions_pb2.VersionDef(
        producer=1, min_consumer=1, bad_consumers=[])
    user_object_proto = saved_object_graph_pb2.SavedUserObject(
        identifier="_generic_user_object", version=fallback_version)
  proto.user_object.CopyFrom(user_object_proto)
示例#8
0
  def _restore_checkpoint(self):
    """Restores checkpointed state into the deserialized objects.

    Raises:
      NotImplementedError: If restore ops are pending for an object that is
        not a resource variable.
    """
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    root_view = graph_view.ObjectGraphView(self.get(0))
    saver = util.TrackableSaver(root_view)
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    load_status = saver.restore(variables_path)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    # When running in eager mode, the `restore` call above has already run and
    # restored the state of trackables, call `position.restore_ops()` will
    # return an empty list as there is nothing left to do. In graph mode, that
    # will return the list of ops that must run to restore the object on that
    # position. We have to wire them in the initializers of the objects so that
    # they get initialized properly when using common practices (e.g. the ones
    # used by ManagedSession) without further user action.
    #
    # The loop walks a copy of the id -> object mapping.
    objects_snapshot = dict(checkpoint.object_by_proto_id)
    for proto_id, trackable in objects_snapshot.items():
      position = base.CheckpointPosition(
          checkpoint=checkpoint, proto_id=proto_id)
      pending_ops = position.restore_ops()
      if not pending_ops:
        continue
      if not resource_variable_ops.is_resource_variable(trackable):
        raise NotImplementedError(
            ("Missing functionality to restore state of object "
             "%r from the checkpoint." % trackable))
      trackable._initializer_op = pending_ops
示例#9
0
    def update(v, g):
      """Builds the update op that applies gradient `g` to variable `v`.

      Raises:
        TypeError: If `g` cannot be converted to, or does not end up as, a
          `Tensor` or `IndexedSlices`.
      """
      assert v is not None

      # Convert the grad to Tensor or IndexedSlices if necessary.
      try:
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      accepted_types = (ops.Tensor, ops.IndexedSlices)
      if not isinstance(g, accepted_types):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      processor = _get_processor(v)

      # The scope name comes from the variable's name (minus any ":N" suffix)
      # when executing eagerly or when `v` is a resource variable not in graph
      # mode; otherwise it comes from the op name.
      # pylint: disable=protected-access
      name_from_variable = context.executing_eagerly() or (
          resource_variable_ops.is_resource_variable(v) and
          not v._in_graph_mode)
      # pylint: enable=protected-access
      if name_from_variable:
        scope_name = v.name.split(":")[0]
      else:
        scope_name = v.op.name

      # device_policy is set because non-mirrored tensors will be read in
      # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
      # is an example.
      with ops.name_scope("update_" + scope_name):
        return processor.update_op(self, g)
示例#10
0
  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Returns the non-slot variable for `name`, creating it on first use.

    Variables are cached in `self._non_slot_dict` under `(name, graph)`, where
    `graph` is `None` when executing eagerly and `colocate_with.graph`
    otherwise.

    Args:
      initial_value: Initial value for a newly created variable. When
        executing eagerly it is replaced by the value returned from
        `_preload_simple_restoration` if that value is not `None`.
      name: Name of the variable; also part of the cache key.
      colocate_with: Object the new variable is colocated with; it also
        decides whether a resource variable is created.

    Returns:
      The cached or newly created variable.
    """
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    if eager:
      graph = None
    else:
      graph = colocate_with.graph

    key = (name, graph)
    existing = self._non_slot_dict.get(key, None)
    if existing is not None:
      return existing

    self._maybe_initialize_trackable()
    strategy = distribute_ctx.get_strategy()
    with strategy.extended.colocate_vars_with(colocate_with):
      if eager:
        restored = self._preload_simple_restoration(name=name, shape=None)
        if restored is not None:
          initial_value = restored
      wants_resource = resource_variable_ops.is_resource_variable(
          colocate_with)
      created = variable_scope.variable(
          initial_value, name=name, trainable=False,
          use_resource=wants_resource)
    # Restore this variable by name if necessary, but don't add a
    # Trackable dependency. Optimizers return the current graph's
    # non-slot variables from _checkpoint_dependencies explicitly rather
    # than unconditionally adding dependencies (since there may be multiple
    # non-slot variables with the same name in different graphs, trying to
    # save all of them would result in errors).
    self._handle_deferred_dependencies(name=name, trackable=created)
    self._non_slot_dict[key] = created
    return created
示例#11
0
  def map_resources(self):
    """Creates graph counterparts for the resources reachable from `self.nodes`.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Tensors captured by `self.concrete_functions` that are not yet tracked and
    whose dtype is not in `_UNCOPIABLE_DTYPES` are copied into constants and
    appended to `self.nodes` as `_CapturedConstant` entries.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from each resource variable to the
          replacement object created to hold the new resource tensor.
        resource_map: A dictionary mapping from the original resource or
          captured tensors to the newly created tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from the nodes.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    # First pass: resources, variables and assets already present as nodes.
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        graph_handle = obj._create_resource()  # pylint: disable=protected-access
        eager_handle = obj.resource_handle
        resource_map[eager_handle] = graph_handle
        self.captured_tensor_node_ids[eager_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        graph_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = graph_variable
        resource_map[obj.handle] = graph_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    # Second pass: captured tensors that no node accounts for yet.
    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (not tensor_util.is_tensor(capture)
            or capture.dtype in _UNCOPIABLE_DTYPES
            or capture in self.captured_tensor_node_ids):
          continue
        copied_tensor = constant_op.constant(
            tensor_util.constant_value(capture))
        new_node_id = len(self.nodes)
        constant_node = _CapturedConstant(
            eager_tensor=capture, graph_tensor=copied_tensor)
        self.nodes.append(constant_node)
        self.node_ids[capture] = new_node_id
        self.node_ids[constant_node] = new_node_id
        self.captured_tensor_node_ids[capture] = new_node_id
        resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
示例#12
0
 def _get_tensor_from_node(self, node_id):
   """Resolves a node id into a tensor to be captured for a function.

   The lookup runs inside `ops.init_scope()`. Resource variables yield their
   `handle`, assets their `asset_path`, tensors themselves, and capturable
   resources their `resource_handle`.

   Raises:
     ValueError: If the node is none of the supported kinds.
   """
   with ops.init_scope():
     node = self._nodes[node_id]
     if resource_variable_ops.is_resource_variable(node):
       return node.handle
     if isinstance(node, tracking.TrackableAsset):
       return node.asset_path
     if tensor_util.is_tensor(node):
       return node
     if isinstance(node, tracking.CapturableResource):
       # Note: this executes restored functions in the CapturableResource.
       return node.resource_handle
     raise ValueError("Can't convert node %s to tensor" % (type(node)))
示例#13
0
def _get_processor(v):
  """Picks the processor object used to apply updates to `v`.

  Raises:
    NotImplementedError: If no processor matches `v`.
  """
  if context.executing_eagerly():
    # Under eager execution only tensors get the tensor processor.
    if isinstance(v, ops.Tensor):
      processor_cls = _TensorProcessor
    else:
      processor_cls = _DenseResourceVariableProcessor
    return processor_cls(v)

  # pylint: disable=protected-access
  # True if and only if `v` was initialized eagerly.
  eagerly_initialized = (
      resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode)
  # pylint: enable=protected-access
  if eagerly_initialized or v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)
示例#14
0
def _get_expanded_variable_list(var_list):
  """Flattens a list of variables, expanding anything that is not one.

  Args:
    var_list: A list of variables.

  Returns:
    A list in which every `variable_ops.Variable` or resource variable from
    `var_list` appears as is, and every other entry (expected to be a
    partitioned variable) is replaced by the items obtained by iterating it.
  """
  expanded = []
  for entry in var_list:
    is_single_variable = (
        isinstance(entry, variable_ops.Variable) or
        resource_variable_ops.is_resource_variable(entry))
    if is_single_variable:
      expanded.append(entry)
    else:
      # Must be a PartitionedVariable, so splice in its components.
      expanded += list(entry)
  return expanded
示例#15
0
def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """Initialize 'ref' with all zeros, ref tensor should be uninitialized.

  If already initialized, you will get ValueError. This op is intended to
  save memory during initialization.

  The `_variable_ops.so` op library is loaded on every call before the op is
  built.

  Args:
    ref: ref of the tensor need to be zero initialized. Resource variables
      are handled through their `handle`, `shape` and `dtype`.
    use_locking: Accepted for the signature; not read by this function.
    name: optional name for this operation.

  Returns:
    ref that initialized.

  Raises:
    ValueError: If ref tensor is initialized.
  """
  library_path = resource_loader.get_path_to_datafile("_variable_ops.so")
  loader.load_op_library(library_path)
  if not resource_variable_ops.is_resource_variable(ref):
    return gen_variable_ops.zero_initializer(ref, name=name)
  return gen_variable_ops.zero_var_initializer(
      ref.handle, shape=ref.shape, dtype=ref.dtype, name=name)
示例#16
0
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a variable or tensor, does not
      have an integer base dtype, or has a fully defined shape whose rank is
      not 0.
  """
  is_supported_type = (
      isinstance(global_step_tensor, (variables.Variable, ops.Tensor)) or
      resource_variable_ops.is_resource_variable(global_step_tensor))
  if not is_supported_type:
    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
                    global_step_tensor)

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    global_step_tensor.dtype)

  # Only a fully defined, non-scalar shape is rejected.
  static_shape = global_step_tensor.get_shape()
  if static_shape.ndims != 0 and static_shape.is_fully_defined():
    raise TypeError('Existing "global_step" is not scalar: %s' %
                    static_shape)
  def __init__(self, variable):
    """Wraps `variable` in an AutoCastVariable.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a resource variable, or if its dtype is
        not floating point.
    """
    is_resource = resource_variable_ops.is_resource_variable(variable)
    if not is_resource:
      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
                       '%s' % variable)
    wrapped_dtype = variable.dtype
    if not wrapped_dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable

    # Checkpointing is delegated to the wrapped variable by reusing its
    # saveable-gathering method.
    # pylint: disable=protected-access
    self._gather_saveables_for_checkpoint = (
        self._variable._gather_saveables_for_checkpoint)
    # pylint: enable=protected-access
示例#18
0
def _map_resources(accessible_objects):
  """Creates graph counterparts for resources, variables and assets.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the C++
  loader API to interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain resources,
      to create replacements for.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping each resource variable in
        `accessible_objects` to the replacement created for it.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced from
        accessible_objects.
  """
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = util.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for candidate in accessible_objects:
    # The order of these checks is significant: trackable resources are
    # tested before resource variables, and assets last.
    if isinstance(candidate, tracking.TrackableResource):
      graph_resource = candidate.create_resource()
      resource_map[candidate.resource_handle] = graph_resource
      continue
    if resource_variable_ops.is_resource_variable(candidate):
      graph_variable = resource_variable_ops.copy_to_graph_uninitialized(
          candidate)
      object_map[candidate] = graph_variable
      resource_map[candidate.handle] = graph_variable.handle
      continue
    if isinstance(candidate, tracking.TrackableAsset):
      _process_asset(candidate, asset_info, resource_map)
  return object_map, resource_map, asset_info
示例#19
0
def _format_device(var):
  """Formats the device of `var` together with its variable kind.

  The kind is "(resource)" for a ResourceVariable and "(legacy)" for a normal
  tf.Variable. When `var.device` is empty only the kind is returned.

  For example:
  `(legacy)`
  `(resource)`
  `/job:learner/task:0/device:CPU:* (legacy)`
  `/job:learner/task:0/device:CPU:* (resource)`

  Args:
    var: The Tensorflow Variable or `ResourceVariable` to print.

  Returns:
    The annotation string, prefixed by the device and a space when a device
    is set.
  """
  kind = ("(resource)" if resource_variable_ops.is_resource_variable(var)
          else "(legacy)")
  device = var.device
  if not device:
    return kind
  return "{} {}".format(device, kind)
示例#20
0
    def map_resources(self):
        """Creates graph counterparts for the resources reachable from `self.nodes`.

        Creates resource handle ops in the current default graph, whereas
        `accessible_objects` will be from an eager context. Resource mapping
        adds resource handle ops to the main GraphDef of a SavedModel, which
        allows the C++ loader API to interact with variables.

        Capturable resources are re-created under their `_resource_device`.
        Tensors captured by `self.concrete_functions` that are not yet tracked
        and whose dtype is not in `_UNCOPIABLE_DTYPES` are copied into
        constants and appended to `self.nodes` as `_CapturedConstant` entries.

        Returns:
          A tuple of (object_map, resource_map, asset_info):
            object_map: A dictionary mapping from each resource variable to
              the replacement object created to hold the new resource tensor.
            resource_map: A dictionary mapping from the original resource or
              captured tensors to the newly created tensors.
            asset_info: An _AssetInfo tuple describing external assets
              referenced from the nodes.
        """
        # Only makes sense when adding to the export Graph
        assert not context.executing_eagerly()
        # TODO(allenl): Handle MirroredVariables and other types of variables which
        # may need special casing.
        object_map = object_identity.ObjectIdentityDictionary()
        resource_map = {}
        asset_info = _AssetInfo(
            asset_defs=[],
            asset_initializers_by_resource={},
            asset_filename_map={},
            asset_index={},
        )

        # First pass: resources, variables and assets already present as nodes.
        for node_id, obj in enumerate(self.nodes):
            if isinstance(obj, tracking.CapturableResource):
                # pylint: disable=protected-access
                with ops.device(obj._resource_device):
                    graph_handle = obj._create_resource()
                # pylint: enable=protected-access
                eager_handle = obj.resource_handle
                resource_map[eager_handle] = graph_handle
                self.captured_tensor_node_ids[eager_handle] = node_id
            elif resource_variable_ops.is_resource_variable(obj):
                graph_variable = (
                    resource_variable_ops.copy_to_graph_uninitialized(obj))
                object_map[obj] = graph_variable
                resource_map[obj.handle] = graph_variable.handle
                self.captured_tensor_node_ids[obj.handle] = node_id
            elif isinstance(obj, tracking.TrackableAsset):
                _process_asset(obj, asset_info, resource_map)
                self.captured_tensor_node_ids[obj.asset_path] = node_id

        # Second pass: captured tensors that no node accounts for yet.
        for concrete_function in self.concrete_functions:
            for capture in concrete_function.captured_inputs:
                if (not tensor_util.is_tensor(capture)
                        or capture.dtype in _UNCOPIABLE_DTYPES
                        or capture in self.captured_tensor_node_ids):
                    continue
                copied_tensor = constant_op.constant(
                    tensor_util.constant_value(capture))
                new_node_id = len(self.nodes)
                constant_node = _CapturedConstant(
                    eager_tensor=capture, graph_tensor=copied_tensor)
                self.nodes.append(constant_node)
                self.node_ids[capture] = new_node_id
                self.node_ids[constant_node] = new_node_id
                self.captured_tensor_node_ids[capture] = new_node_id
                resource_map[capture] = copied_tensor

        return object_map, resource_map, asset_info
示例#21
0
    def doTestBasic(self, use_resource=False, use_callable_params=False):
        # Runs three Adam steps on two 2-element variables for each of
        # half/float32/float64 and compares the variables and the beta power
        # accumulators against values computed with `adam_update_numpy`.
        #
        # use_resource: create ResourceVariables instead of RefVariables.
        # use_callable_params: keep the hyperparameters as zero-argument
        #   callables instead of resolving them to numbers.
        if context.executing_eagerly() and not use_resource:
            self.skipTest(
                "Skipping test with use_resource=False and executing eagerly.")
        for i, dtype in enumerate(
            [dtypes.half, dtypes.float32, dtypes.float64]):
            # Each dtype gets its own fresh graph and session.
            with self.session(graph=ops.Graph()):
                # Initialize variables for numpy implementation.
                m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
                var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
                grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
                var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
                grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

                if use_resource:
                    # Resource variables get a per-dtype name; the slot-name
                    # assertion at the end of the step loop relies on it.
                    var0 = resource_variable_ops.ResourceVariable(
                        var0_np, name="var0_%d" % i)
                    var1 = resource_variable_ops.ResourceVariable(
                        var1_np, name="var1_%d" % i)
                else:
                    var0 = variables.RefVariable(var0_np)
                    var1 = variables.RefVariable(var1_np)
                grads0 = constant_op.constant(grads0_np)
                grads1 = constant_op.constant(grads1_np)

                # Hyperparameters start as callables and are resolved to
                # plain numbers unless callables were requested.
                learning_rate = lambda: 0.001
                beta1 = lambda: 0.9
                beta2 = lambda: 0.999
                epsilon = lambda: 1e-8
                if not use_callable_params:
                    learning_rate = learning_rate()
                    beta1 = beta1()
                    beta2 = beta2()
                    epsilon = epsilon()

                # NOTE(review): only `learning_rate` is handed to the
                # optimizer; `beta1`, `beta2` and `epsilon` computed above are
                # not used anywhere below. Confirm whether they were meant to
                # be passed as well.
                opt = adam.AdamOptimizer(learning_rate=learning_rate)
                update = opt.apply_gradients(
                    zip([grads0, grads1], [var0, var1]))
                opt_variables = opt.variables()
                beta1_power, beta2_power = opt._get_beta_accumulators()
                self.assertTrue(beta1_power is not None)
                self.assertTrue(beta2_power is not None)
                self.assertIn(beta1_power, opt_variables)
                self.assertIn(beta2_power, opt_variables)
                # Ensure that non-slot variables are the same type as the requested
                # variables.
                self.assertEqual(
                    use_resource,
                    resource_variable_ops.is_resource_variable(beta1_power))
                self.assertEqual(
                    use_resource,
                    resource_variable_ops.is_resource_variable(beta2_power))

                if not context.executing_eagerly():
                    with ops.Graph().as_default():
                        # Shouldn't return non-slot variables from other graphs.
                        self.assertEqual(0, len(opt.variables()))
                    self.evaluate(variables.global_variables_initializer())
                    # Fetch params to validate initial values
                    self.assertAllClose([1.0, 2.0], self.evaluate(var0))
                    self.assertAllClose([3.0, 4.0], self.evaluate(var1))

                beta1_power, beta2_power = opt._get_beta_accumulators()

                # Run 3 steps of Adam
                for t in range(1, 4):
                    if not context.executing_eagerly():
                        self.evaluate(update)
                    elif t > 1:
                        # Under eager execution nothing is applied here for
                        # t == 1; presumably the `apply_gradients` call above
                        # already performed that first step -- TODO confirm.
                        opt.apply_gradients(zip([grads0, grads1],
                                                [var0, var1]))

                    # After step t the accumulators are expected to hold the
                    # betas (0.9 and 0.999) raised to the power t + 1.
                    self.assertAllCloseAccordingToType(
                        0.9**(t + 1), self.evaluate(beta1_power))
                    self.assertAllCloseAccordingToType(
                        0.999**(t + 1), self.evaluate(beta2_power))

                    # Advance the numpy-side state by one step.
                    var0_np, m0, v0 = adam_update_numpy(
                        var0_np, grads0_np, t, m0, v0)
                    var1_np, m1, v1 = adam_update_numpy(
                        var1_np, grads1_np, t, m1, v1)

                    # Validate updated params
                    self.assertAllCloseAccordingToType(var0_np,
                                                       self.evaluate(var0))
                    self.assertAllCloseAccordingToType(var1_np,
                                                       self.evaluate(var1))
                    if use_resource:
                        # The "m" slot of var0 is expected to be named after
                        # the variable with an "/Adam:0" suffix.
                        self.assertEqual("var0_%d/Adam:0" % (i, ),
                                         opt.get_slot(var=var0, name="m").name)
示例#22
0
    def gradient(self,
                 target,
                 sources,
                 output_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
        """Computes the gradient using operations recorded in context of this tape.

        Args:
          target: Tensor (or list of tensors) to be differentiated.
          sources: a list or nested structure of Tensors or Variables. `target`
            will be differentiated against elements in `sources`.
          output_gradients: a list of gradients, one for each element of
            target. Defaults to None.
          unconnected_gradients: a value which can either hold 'none' or 'zero'
            and alters the value which will be returned if the target and
            sources are unconnected. The possible values and effects are
            detailed in 'UnconnectedGradients' and it defaults to 'none'.

        Returns:
          a list or nested structure of Tensors (or IndexedSlices, or None),
          one for each element in `sources`. Returned structure is the same as
          the structure of `sources`.

        Raises:
          RuntimeError: if the tape has already been released, i.e. this is
            called more than once on a non-persistent tape.
          ValueError: if any target is a resource variable. An unknown
            `unconnected_gradients` value is not validated in this method;
            presumably it is rejected downstream by `imperative_grad`.
        """
        if self._tape is None:
            raise RuntimeError(
                "GradientTape.gradient can only be called once on "
                "non-persistent tapes.")
        if self._recording:
            if not self._persistent:
                # A non-persistent tape stops recording as soon as a gradient
                # is requested from inside its context.
                self._pop_tape()
            else:
                logging.log_first_n(
                    logging.WARN,
                    "Calling GradientTape.gradient on a persistent "
                    "tape inside its context is significantly less "
                    "efficient than calling it outside the context (it "
                    "causes the gradient ops to be recorded on the "
                    "tape, leading to increased CPU and memory usage). "
                    "Only call GradientTape.gradient inside the "
                    "context if you actually want to trace the "
                    "gradient in order to compute higher order "
                    "derivatives.", 1)

        flat_targets = nest.flatten(target)
        for t in flat_targets:
            if resource_variable_ops.is_resource_variable(t):
                raise ValueError(
                    "GradientTape.gradient is not supported for variable "
                    "targets.")

        # Sources are mapped through `_handle_or_self` before being handed to
        # the gradient computation.
        flat_sources = nest.flatten(sources)
        flat_sources = [_handle_or_self(x) for x in flat_sources]

        if output_gradients is not None:
            # `None` entries are kept as-is; everything else becomes a tensor.
            output_gradients = [
                None if x is None else ops.convert_to_tensor(x)
                for x in nest.flatten(output_gradients)
            ]

        flat_grad = imperative_grad.imperative_grad(
            self._tape,
            flat_targets,
            flat_sources,
            output_gradients=output_gradients,
            unconnected_gradients=unconnected_gradients)

        if not self._persistent:
            # Release the tape; a later call then hits the RuntimeError above.
            self._tape = None

        grad = nest.pack_sequence_as(sources, flat_grad)
        return grad
示例#23
0
 def testIsResourceVariable(self):
     """Checks that `create_variable()` yields a resource variable."""
     created = self.create_variable()
     is_resource = resource_variable_ops.is_resource_variable(created)
     self.assertTrue(is_resource)
示例#24
0
    def _setup_functions_captures(self):
        """Setup captures and variables in restored functions.

        For every concrete function listed in `self._proto.concrete_functions`
        this resolves the function's bound inputs to restored objects, rewires
        the function graph's capture placeholders to them, and finally installs
        the resulting list as the function's external captures.
        """
        # Sorting the (name, proto) items gives a deterministic processing
        # order.
        concrete_functions = sorted(self._proto.concrete_functions.items())
        for name, proto in concrete_functions:
            concrete_function = self._concrete_functions[name]
            # Restored objects/tensors for every bound input of this function.
            bound_inputs = [
                self._get_tensor_from_node(node_id, name)
                for node_id in proto.bound_inputs
            ]
            # Subset of the bound inputs whose saved node kind is "variable".
            bound_variables = [
                self._nodes[node_id] for node_id in proto.bound_inputs
                if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
            ]
            # TODO(andresp): This is only injecting the captured inputs into the
            # concrete function, note that we did not modify the FuncGraph
            # itself.
            captured_inputs_list = []
            concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
            if bound_inputs:
                # Pair each bound input with one of the last
                # `len(bound_inputs)` function inputs (presumably the capture
                # placeholders trail the regular arguments -- TODO confirm).
                for bound_input, internal_capture in zip(
                        bound_inputs,
                        concrete_function.inputs[-len(bound_inputs):]):
                    if distribute_utils.is_distributed_variable(bound_input):
                        # Distributed variables go through a dedicated capture
                        # API on the graph.
                        concrete_function.graph.capture_distributed_variable(
                            bound_input, internal_capture)
                        captured_inputs_list.append(bound_input)
                    elif distribute_utils.is_distributed_table(bound_input):
                        # Distributed tables are registered as deferred
                        # captures built from the closure/spec they provide.
                        closure, spec = bound_input.resource_handle_call_time_value(
                        )
                        concrete_function.graph.replace_capture_with_deferred_capture(
                            bound_input._coordinator_instance.resource_handle,  # pylint: disable=protected-access
                            closure,
                            spec,
                            default_value=bound_input._coordinator_instance.
                            resource_handle,  # pylint: disable=protected-access
                            placeholder=internal_capture)
                        # The last deferred external capture (presumably the
                        # one just registered) is what gets passed as the
                        # captured input.
                        captured_inputs_list.append(
                            concrete_function.graph.
                            deferred_external_captures[-1])

                    else:
                        captured_inputs_list.append(bound_input)
                        concrete_function.graph.replace_capture(
                            bound_input, internal_capture)
                        if internal_capture.dtype == dtypes.resource:
                            # Resource placeholders additionally get handle
                            # data copied onto them from the bound input.
                            if resource_variable_ops.is_resource_variable(
                                    bound_input):
                                try:
                                    handle = bound_input.handle
                                except ValueError:
                                    # For mirrored variables we'll copy handle data for components
                                    # as they get captured.
                                    pass
                                else:
                                    handle_data_util.copy_handle_data(
                                        handle, internal_capture)
                            else:
                                handle_data_util.copy_handle_data(
                                    bound_input, internal_capture)
                        # Setting "captures" first means "capture" won't create a new
                        # placeholder for this input.
                        concrete_function.graph.capture(bound_input)

            # Installed even when there are no bound inputs (empty list).
            concrete_function.set_external_captures(captured_inputs_list)
示例#25
0
  def gradient(self,
               target,
               sources,
               output_gradients=None,
               unconnected_gradients=UnconnectedGradients.NONE):
    """Differentiates `target` with respect to `sources` using this tape.

    Args:
      target: Tensor (or list of tensors) to be differentiated. Resource
        variables are converted to tensors while this tape is active.
      sources: a list or nested structure of Tensors or Variables that
        `target` is differentiated against.
      output_gradients: optional list of gradients, one per element of
        `target`; `None` entries are left untouched. Defaults to None.
      unconnected_gradients: forwarded unchanged to the gradient computation;
        defaults to `UnconnectedGradients.NONE`.

    Returns:
      The computed gradients packed into the same structure as `sources`.

    Raises:
      RuntimeError: if the tape has already been released, i.e. this is
        called more than once on a non-persistent tape.
    """

    def _log_if_not_floating(tensor, message):
      """Emits `message` (formatted with the dtype) for non-float tensors."""
      if not tensor.dtype.is_floating:
        logging.vlog(logging.WARN, message, tensor.dtype)

    if self._tape is None:
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")

    if self._recording:
      if self._persistent:
        logging.log_first_n(
            logging.WARN, "Calling GradientTape.gradient on a persistent "
            "tape inside its context is significantly less "
            "efficient than calling it outside the context (it "
            "causes the gradient ops to be recorded on the "
            "tape, leading to increased CPU and memory usage). "
            "Only call GradientTape.gradient inside the "
            "context if you actually want to trace the "
            "gradient in order to compute higher order "
            "derivatives.", 1)
      else:
        # A non-persistent tape stops recording once a gradient is requested.
        self._pop_tape()

    target_dtype_message = (
        "The dtype of the target tensor must be "
        "floating (e.g. tf.float32) when calling GradientTape.gradient, "
        "got %r")
    source_dtype_message = (
        "The dtype of the source tensor must be "
        "floating (e.g. tf.float32) when calling GradientTape.gradient, "
        "got %r")

    flat_targets = []
    for candidate in nest.flatten(target):
      _log_if_not_floating(candidate, target_dtype_message)
      if resource_variable_ops.is_resource_variable(candidate):
        # Re-enter the tape so the variable read is itself recorded.
        with self:
          candidate = ops.convert_to_tensor(candidate)
      flat_targets.append(candidate)

    flat_sources_raw = nest.flatten(sources)
    flat_sources = [_handle_or_self(src) for src in flat_sources_raw]
    for src in flat_sources_raw:
      _log_if_not_floating(src, source_dtype_message)

    if output_gradients is not None:
      output_gradients = [
          ops.convert_to_tensor(g) if g is not None else None
          for g in nest.flatten(output_gradients)
      ]

    flat_grad = imperative_grad.imperative_grad(
        self._tape,
        flat_targets,
        flat_sources,
        output_gradients=output_gradients,
        sources_raw=flat_sources_raw,
        unconnected_gradients=unconnected_gradients)

    if not self._persistent:
      # Release the tape; a later call then hits the RuntimeError above.
      self._tape = None

    return nest.pack_sequence_as(sources, flat_grad)
示例#26
0
def _graph_mode_decorator(f, args, kwargs):
    """Implement custom gradient decorator for graph mode.

    Runs `f(*args)` under a `GradientTape`, then routes the outputs, the
    inputs and every variable the function used through a single `IdentityN`
    op whose gradient is overridden by the user supplied `grad_fn`. The same
    gradient function is also recorded on the tape.

    Args:
      f: Function returning a `(result, grad_fn)` pair.
      args: Positional arguments for `f`; each one is converted to a tensor.
      kwargs: Keyword arguments for `f`. Must be empty in graph mode.

    Returns:
      The `IdentityN` outputs that correspond to `result`, packed into the
      same structure as `result`.

    Raises:
      ValueError: If `kwargs` is non-empty. The gradient function built here
        additionally raises ValueError (when gradients are computed) if
        `grad_fn` returns a number of variable gradients different from the
        number of variables.
      TypeError: If `f` created a variable that is not a resource variable, if
        variables are used but `grad_fn` cannot accept a `variables` keyword
        argument, or if `grad_fn` declares `variables` but none were found
        and the enclosing variable scope has `use_resource` unset.
    """
    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()
    args = [ops.convert_to_tensor(x) for x in args]

    # Checking global and local variables attempts to ensure that no non-resource
    # Variables are added to the graph.
    current_var_scope = variable_scope.get_variable_scope()
    before_vars = {
        v.experimental_ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    }
    with backprop.GradientTape() as tape:
        result, grad_fn = f(*args)
    after_vars = {
        v.experimental_ref() for v in current_var_scope.global_variables() +
        current_var_scope.local_variables()
    }
    new_vars = after_vars - before_vars
    new_vars_list = [v.deref() for v in new_vars]
    for v in new_vars_list:
        if not resource_variable_ops.is_resource_variable(v):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")
    # The variables that grad_fn needs to return gradients for are the set of
    # variables used that are *not* part of the inputs.
    inputs = args
    variables_in_tape = frozenset(
        [v.experimental_ref()
         for v in tape.watched_variables()]) - frozenset(v.experimental_ref()
                                                         for v in inputs)
    variables_in_subgraph = frozenset([
        v.experimental_ref()
        for v in get_dependent_variables(input_ops=inputs, output_ops=result)
    ])
    # Set iteration order is arbitrary; sorting by name makes the order in
    # which variables are handed to `grad_fn` (and wired into IdentityN)
    # reproducible from one graph construction to the next.
    variables = sorted(
        [v.deref() for v in variables_in_subgraph.union(variables_in_tape)],
        key=lambda v: v.name)

    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or grad_argspec.varkw)
    # A keyword-only parameter, e.g. `def grad(*dys, variables=None)`, can
    # receive `variables=` just as well as a positional-or-keyword one.
    accepts_variables = (variables_in_signature
                         or "variables" in grad_argspec.kwonlyargs)
    if variables and not accepts_variables:
        raise TypeError("If using @custom_gradient with a function that "
                        "uses variables, then grad_fn must accept a keyword "
                        "argument 'variables'.")
    # Deliberately keyed on the narrower predicate so that functions which
    # built without complaint before keyword-only support keep doing so.
    if variables_in_signature and not variables:
        # User seems to intend to use variables but none were captured.
        if not variable_scope.get_variable_scope().use_resource:
            raise TypeError(
                "If using @custom_gradient with a function that "
                "uses variables, the enclosing variable scope must "
                "have use_resource=True.")
        else:
            logging.warn(
                "@custom_gradient grad_fn has 'variables' in signature, but "
                "no ResourceVariables were used on the forward pass.")
    flat_result = nest.flatten(result)
    flat_result_len = len(flat_result)

    # Layout of the IdentityN inputs/outputs: results, then args, then
    # variables. `tape_grad_fn` returns gradients in the same layout.
    all_tensors = flat_result + args + variables

    def tape_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        # Only the gradients flowing into the result slots are meaningful.
        result_grads = result_grads[:flat_result_len]
        if variables:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")
        else:
            input_grads = grad_fn(*result_grads)
            variable_grads = []

        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        input_grads = nest.flatten(input_grads)
        return ([None] * flat_result_len) + input_grads + variable_grads

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        return tape_grad_fn(*result_grads)

    original_tensors = all_tensors
    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)

    original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

    # Propagate handle data for happier shape inference for resource variables.
    for i, t in enumerate(original_tensors):
        if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
            all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
    tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                              tape_grad_fn)
    for ot, t in zip(original_tensors, all_tensors):
        copy_handle_data(ot, t)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:flat_result_len])
示例#27
0
def _graph_mode_decorator(f, args, kwargs):
    """Graph-mode implementation of the custom gradient decorator.

    Calls `f(*args)` while watching variable accesses, then passes the flat
    results, the flat inputs and all variables involved through one
    `IdentityN` op whose gradient is replaced by the user's `grad_fn`. The
    same gradient function is recorded on the tape.

    Args:
      f: Function returning a `(result, grad_fn)` pair.
      args: Positional arguments for `f` (a nested structure; every leaf is
        converted to a tensor).
      kwargs: Keyword arguments for `f`. Must be empty in graph mode.

    Returns:
      The `IdentityN` outputs that correspond to `result`, packed into the
      same structure as `result`.

    Raises:
      ValueError: If `kwargs` is non-empty or the tensor outputs of `f` come
        from more than one graph.
      TypeError: If `f` created a variable that is not a resource variable, or
        if variables are used but `grad_fn` cannot accept a `variables`
        keyword argument.
    """

    def _scope_variable_refs(scope):
        """Returns refs of all global and local variables reported by `scope`."""
        return {
            var.ref()
            for var in scope.global_variables() + scope.local_variables()
        }

    # TODO(rsepassi): Add support for kwargs
    if kwargs:
        raise ValueError(
            "The custom_gradient decorator currently supports keywords "
            "arguments only when eager execution is enabled.")
    name = "CustomGradient-%s" % ops.uid()
    args = nest.map_structure(ops.convert_to_tensor, args)

    # Snapshot the scope's variables before and after running `f` so that any
    # variable created by `f` can be checked to be a resource variable.
    var_scope = variable_scope.get_variable_scope()
    refs_before = _scope_variable_refs(var_scope)
    with tape_lib.VariableWatcher() as variable_watcher:
        result, grad_fn = f(*args)

    args = nest.flatten(args)
    flat_result = nest.flatten(result)
    flat_result_len = len(flat_result)

    refs_after = _scope_variable_refs(var_scope)
    created_variables = [ref.deref() for ref in refs_after - refs_before]
    for created in created_variables:
        if not resource_variable_ops.is_resource_variable(created):
            raise TypeError(
                "All variables used by a function wrapped with @custom_gradient must "
                "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
                "with `use_resource=False`.")

    # grad_fn has to return gradients for every variable that was used but is
    # not one of the inputs.
    variables_in_tape = frozenset(
        watched.ref() for watched in variable_watcher.watched_variables())

    # Outputs need not all be tensors, but those that are must share a graph;
    # only inputs living in that graph are used for the variable search.
    result_graphs = {getattr(out, "graph", None) for out in flat_result}
    result_graphs.discard(None)  # Drop outputs that carry no graph.
    if not result_graphs:
        filtered_input_tensors = args
    else:
        if len(result_graphs) > 1:
            raise ValueError(
                "All custom_gradient outputs should be from the same graph")
        output_graph = result_graphs.pop()
        filtered_input_tensors = [
            arg for arg in args if arg.graph == output_graph
        ]

    variables_in_subgraph = frozenset(
        dependent.ref()
        for dependent in _get_dependent_variables(
            input_ops=filtered_input_tensors, output_ops=flat_result))
    variables = sorted(
        (ref.deref() for ref in variables_in_subgraph | variables_in_tape),
        key=lambda v: v.name)

    grad_argspec = tf_inspect.getfullargspec(grad_fn)
    variables_in_signature = ("variables" in grad_argspec.args
                              or "variables" in grad_argspec.kwonlyargs
                              or grad_argspec.varkw)
    if variables and not variables_in_signature:
        raise TypeError(
            "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
            "since function uses variables: {}".format(variables))
    if variables_in_signature and not variables:
        # grad_fn declares `variables`, yet nothing was captured.
        logging.warn(
            "@custom_gradient grad_fn has 'variables' in signature, but "
            "no ResourceVariables were used on the forward pass.")

    # IdentityN layout: results first, then inputs, then variables.
    identity_inputs = flat_result + args + variables

    def tape_grad_fn(*result_grads):
        """Custom grad fn wrapper."""
        result_grads = result_grads[:flat_result_len]
        if not variables:
            input_grads = grad_fn(*result_grads)
            variable_grads = []
        else:
            input_grads, variable_grads = grad_fn(*result_grads,
                                                  variables=variables)
            if len(variable_grads) != len(variables):
                raise ValueError("Must return gradient for each variable from "
                                 "@custom_gradient grad_fn.")

        # IdentityN needs one gradient per input, so the result slots are
        # padded with None in front of the input and variable gradients.
        padding = [None] * flat_result_len
        return padding + nest.flatten(input_grads) + variable_grads

    @ops.RegisterGradient(name)
    def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        """Custom grad fn wrapper."""
        return tape_grad_fn(*result_grads)

    with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
        all_tensors = array_ops.identity_n(identity_inputs)

    original_tensors = [ops.convert_to_tensor(x) for x in identity_inputs]

    # Carry handle data over to the IdentityN outputs so shape inference for
    # resource variables keeps working.
    for wrapped, source in zip(all_tensors, original_tensors):
        if source.dtype == dtypes.resource and hasattr(source, "_handle_data"):
            wrapped._handle_data = source._handle_data  # pylint: disable=protected-access
    tape_lib.record_operation(f.__name__, all_tensors, original_tensors,
                              tape_grad_fn)
    for source, wrapped in zip(original_tensors, all_tensors):
        copy_handle_data(source, wrapped)
    return nest.pack_sequence_as(structure=result,
                                 flat_sequence=all_tensors[:flat_result_len])
示例#28
0
  def doTestBasic(self, use_resource=False, use_callable_params=False):
    """Runs three Adam steps and checks them against the numpy reference.

    Args:
      use_resource: If True the test variables are `ResourceVariable`s,
        otherwise `RefVariable`s. The non-resource case is skipped when
        executing eagerly.
      use_callable_params: If True the learning rate is handed to the
        optimizer as a zero-argument callable instead of a plain float.
    """
    if context.executing_eagerly() and not use_resource:
      self.skipTest(
          "Skipping test with use_resource=False and executing eagerly.")
    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
      with self.session(graph=ops.Graph()):
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

        if use_resource:
          var0 = resource_variable_ops.ResourceVariable(
              var0_np, name="var0_%d" % i)
          var1 = resource_variable_ops.ResourceVariable(
              var1_np, name="var1_%d" % i)
        else:
          var0 = variables.RefVariable(var0_np)
          var1 = variables.RefVariable(var1_np)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        # NOTE(review): only the learning rate is exercised in callable form.
        # beta1/beta2/epsilon are left at the optimizer's defaults, which the
        # 0.9 / 0.999 power assertions below rely on. Confirm whether callable
        # betas/epsilon were meant to be covered as well.
        learning_rate = (lambda: 0.001) if use_callable_params else 0.001

        opt = adam.AdamOptimizer(learning_rate=learning_rate)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        opt_variables = opt.variables()
        beta1_power, beta2_power = opt._get_beta_accumulators()
        self.assertIsNotNone(beta1_power)
        self.assertIsNotNone(beta2_power)
        self.assertIn(beta1_power, opt_variables)
        self.assertIn(beta2_power, opt_variables)
        # Ensure that non-slot variables are the same type as the requested
        # variables.
        self.assertEqual(
            use_resource,
            resource_variable_ops.is_resource_variable(beta1_power))
        self.assertEqual(
            use_resource,
            resource_variable_ops.is_resource_variable(beta2_power))

        if not context.executing_eagerly():
          with ops.Graph().as_default():
            # Shouldn't return non-slot variables from other graphs.
            self.assertEqual(0, len(opt.variables()))
          self.evaluate(variables.global_variables_initializer())
          # Fetch params to validate initial values
          self.assertAllClose([1.0, 2.0], self.evaluate(var0))
          self.assertAllClose([3.0, 4.0], self.evaluate(var1))

        beta1_power, beta2_power = opt._get_beta_accumulators()

        # Run 3 steps of Adam
        for t in range(1, 4):
          if not context.executing_eagerly():
            self.evaluate(update)
          elif t > 1:
            # In eager mode the first step already ran when `update` was
            # created above, so only steps 2 and 3 are applied here.
            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

          self.assertAllCloseAccordingToType(0.9**(t + 1),
                                             self.evaluate(beta1_power))
          self.assertAllCloseAccordingToType(0.999**(t + 1),
                                             self.evaluate(beta2_power))

          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
          if use_resource:
            self.assertEqual("var0_%d/Adam:0" % (i,),
                             opt.get_slot(var=var0, name="m").name)
示例#29
0
def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
                     gate_gradients, aggregation_method, stop_gradients):
  """Implementation of gradients().

  Adds ops to the default graph that compute the gradients of `ys` with
  respect to `xs`, by visiting the ops between them backwards starting from
  the ops that produce `ys`.

  Args:
    ys: A tensor or list of tensors to be differentiated.
    xs: A tensor/variable or list of them to differentiate against. Resource
      variables are replaced by their handles.
    grad_ys: Optional initial gradients, one per element of `ys`. When `None`,
      a list of `None`s is used and completed by `_DefaultGradYs`.
    name: Name for the enclosing name scope; "gradients" is the default name.
    colocate_gradients_with_ops: Forwarded to `_DefaultGradYs`,
      `_PendingCount` and `_maybe_colocate_with`.
    gate_gradients: If truthy, the input gradients of an op are grouped with
      `control_flow_ops.tuple` whenever more than one of them is not `None`.
    aggregation_method: Forwarded to `_AggregatedGrads`.
    stop_gradients: Optional tensor or list of tensors whose ops are handed to
      `_StopOps`.

  Returns:
    A list with one entry per element of `xs`, each obtained from `_GetGrad`.

  Raises:
    RuntimeError: If eager execution is enabled.
    LookupError: If an op that needs a gradient has no gradient function
      registered.
    ValueError: If a computed input gradient has a shape that is incompatible
      with the corresponding op input.
  """
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients not supported when eager execution "
                       "is enabled. Use tf.contrib.eager.GradientTape "
                       "instead.")
  # Normalize every argument to a list.
  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    # Differentiate with respect to the handle of resource variables.
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    if len(ys) > 1:
      # With several ys, any y that already has consumers is wrapped in an
      # identity op.
      ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    pending_count, loop_state = _PendingCount(
        ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into a Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      # pylint: disable=protected-access
      ready = (pending_count[op._id] == 0)
      if ready and op._id not in to_ops_set:
        to_ops_set.add(op._id)
        queue.append(op)
      # pylint: enable=protected-access

    if loop_state:
      # Trainable loop exits reported as unused get a zero gradient and are
      # queued as well.
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if _IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        # pylint: disable=protected-access
        func_call = None
        is_func_call = ops.get_default_graph()._is_function(op.type)
        # True if at least one output gradient is a tensor or otherwise
        # truthy.
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op._id not in stop_ops):
          if is_func_call:
            func_call = ops.get_default_graph()._get_function(op.type)
            grad_fn = func_call.python_grad_func
            # pylint: enable=protected-access
          else:
            # A grad_fn must be defined, either as a function or as None
            # for ops that do not have gradients.
            try:
              grad_fn = ops.get_gradient_function(op)
            except LookupError:
              raise LookupError(
                  "No gradient defined for operation '%s' (op type: %s)" %
                  (op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)
        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLike(op, i)
              else:
                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with ops.get_default_graph()._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                # Gating: group the gradients with device and colocation
                # constraints cleared.
                with ops.device(None):
                  with ops.colocate_with(None, ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(op.inputs)
        for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
          if in_grad is not None:
            # For non-resource inputs the gradient tensor takes on the static
            # shape of the input; a mismatch is reported with full context.
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient.  Forward operation: %s.  Input index: %d. "
                    "Original input shape: %s.  "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x) for x in xs]
示例#30
0
def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
    """Implementation of gradients().

    Walks the ops between `ys` and `xs` backwards using a work queue, calls
    each op's gradient function (or builds a symbolic gradient for function
    call ops) and accumulates the results in a local `grads` dictionary.

    Args:
      ys: A tensor or list of tensors; normalized with `_AsList` and converted
        with `ops.convert_n_to_tensor_or_indexed_slices`.
      xs: A tensor/variable or list of them. Resource variables are replaced
        by their `handle` before conversion.
      grad_ys: Optional initial gradients for `ys`. When None, a list of
        `None` of length `len(ys)` is handed to `_DefaultGradYs`.
      name: Name passed to `ops.name_scope` (default scope name "gradients").
      colocate_gradients_with_ops: Forwarded to `_DefaultGradYs`,
        `_PendingCount` and `_maybe_colocate_with`.
      gate_gradients: If truthy, whenever an op yields more than one non-None
        input gradient they are passed through `control_flow_ops.tuple`.
      aggregation_method: Forwarded to `_AggregatedGrads`.
      stop_gradients: Optional tensor or list of tensors whose ops are handed
        to `_StopOps`; no gradient function is looked up for ops in the
        resulting stop set.
      unconnected_gradients: Value convertible to `UnconnectedGradients`;
        forwarded to `_GetGrad` when the result list is built.
      src_graph: Graph used for function lookups and `_original_op`. Defaults
        to `ops.get_default_graph()`.

    Returns:
      A list with one entry per element of `xs`, each produced by
      `_GetGrad(grads, x, unconnected_gradients)`.

    Raises:
      RuntimeError: If called while eager execution is enabled.
      ValueError: If `unconnected_gradients` is not a valid
        `UnconnectedGradients` value, or if a computed input gradient has a
        shape incompatible with the corresponding op input.
      LookupError: If an op has no registered gradient function and is not a
        function call.
    """
    if context.executing_eagerly():
        raise RuntimeError(
            "tf.gradients is not supported when eager execution "
            "is enabled. Use tf.GradientTape instead.")
    if src_graph is None:
        src_graph = ops.get_default_graph()
    try:
        unconnected_gradients = UnconnectedGradients(unconnected_gradients)
    except ValueError:
        raise ValueError("Unknown value for unconnected_gradients: %r" %
                         unconnected_gradients)

    # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
    # ancestor graphs. This is necessary for correctly handling captured values.
    func_graphs = []
    curr_graph = src_graph
    while _IsFunction(curr_graph):
        func_graphs.append(curr_graph)
        if isinstance(curr_graph, FuncGraph):
            curr_graph = curr_graph.outer_graph
        else:
            assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
            curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

    ys = _AsList(ys)
    xs = _AsList(xs)
    stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
    if grad_ys is None:
        grad_ys = [None] * len(ys)
    else:
        grad_ys = _AsList(grad_ys)

    with ops.name_scope(
            name, "gradients",
            list(ys) + list(xs) + list(stop_gradients) +
            list(grad_ys)) as grad_scope:
        # Get a uid for this call to gradients that can be used to help
        # cluster ops for compilation.
        gradient_uid = ops.get_default_graph().unique_name("uid")
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
        # Swap resource variables for their handle tensors before conversion.
        xs = [
            x.handle if resource_variable_ops.is_resource_variable(x) else x
            for x in xs
        ]
        xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs,
                                                                name="x",
                                                                as_ref=True)
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                                 gradient_uid)

        # The approach we take here is as follows: Create a list of all ops in the
        # subgraph between the ys and xs.  Visit these ops in reverse order of ids
        # to ensure that when we visit an op the gradients w.r.t its outputs have
        # been collected.  Then aggregate these gradients if needed, call the op's
        # gradient function, and add the generated gradients to the gradients for
        # its input.

        # Initialize the pending count for ops in the connected subgraph from ys
        # to the xs.
        to_ops = [t.op for t in ys]
        from_ops = [t.op for t in xs]
        stop_gradient_ops = [t.op for t in stop_gradients]
        reachable_to_ops, pending_count, loop_state = _PendingCount(
            to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)

        # Iterate over the collected ops.
        #
        # grads: op => list of gradients received on each output endpoint of the
        # op.  The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint we
        # aggregate the list of received gradients into a Add() Operation if there
        # is more than one.
        grads = {}

        # Add the initial gradients for the ys.
        for y, grad_y in zip(ys, grad_ys):
            _SetGrad(grads, y, grad_y)

        # Initialize queue with to_ops.
        queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        to_ops_set = set()
        for op in to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            ready = (pending_count[op] == 0)
            if ready and op not in to_ops_set and op in reachable_to_ops:
                to_ops_set.add(op)
                queue.append(op)

        if loop_state:
            loop_exits = loop_state.ProcessUnusedLoopExits(
                pending_count, to_ops_set)
            for y in loop_exits:
                if IsTrainable(y):
                    _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
                    queue.append(y.op)

        stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
        while queue:
            # generate gradient subgraph for op.
            op = queue.popleft()
            with _maybe_colocate_with(op, gradient_uid,
                                      colocate_gradients_with_ops):
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=True)
                out_grads = _AggregatedGrads(grads, op, gradient_uid,
                                             loop_state, aggregation_method)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=True)

                grad_fn = None
                func_call = None
                is_partitioned_call = _IsPartitionedCall(op)
                # pylint: disable=protected-access
                is_func_call = (src_graph._is_function(op.type)
                                or is_partitioned_call)
                # pylint: enable=protected-access
                # An output gradient counts if it is a Tensor or any other
                # truthy value.
                has_out_grads = any(
                    isinstance(g, ops.Tensor) or g for g in out_grads)
                if has_out_grads and (op not in stop_ops):
                    try:
                        grad_fn = ops.get_gradient_function(op)
                    except LookupError:
                        if is_func_call:
                            if is_partitioned_call:
                                func_call = src_graph._get_function(  # pylint: disable=protected-access
                                    compat.as_bytes(op.get_attr("f").name))
                            else:
                                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
                            # Note that __defun is not set if the graph is
                            # imported. If it's set, we prefer to access the original
                            # defun.
                            func_call = getattr(op, "__defun", func_call)
                            grad_fn = func_call.python_grad_func
                        else:
                            raise LookupError(
                                "No gradient defined for operation '%s' (op type: %s)"
                                % (op.name, op.type))
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=False)

                # NOTE(skyewm): We don't support computing gradients wrt a loop variable
                # unless it's within the context of a single iteration (i.e. the
                # gradient is wrt to the loop parameter in the body function, not wrt or
                # through the initial value). This means if we're in a while loop
                # context, we should never see a switch node from this context.
                # pylint: disable=protected-access
                if (control_flow_util.IsSwitch(op)
                        and op._control_flow_context is not None
                        and op._control_flow_context.IsWhileContext()
                        and op._control_flow_context ==
                        ops.get_default_graph()._get_control_flow_context()):
                    _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
                # pylint: enable=protected-access

                if (grad_fn or is_func_call) and has_out_grads:
                    # NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    # output, it means that the cost does not depend on output[i],
                    # therefore dC/doutput[i] is 0.
                    for i, out_grad in enumerate(out_grads):
                        if (not isinstance(out_grad, ops.Tensor)
                                and not out_grad) and (
                                    (not grad_fn and is_func_call)
                                    or IsTrainable(op.outputs[i])):
                            # Only trainable outputs or outputs for a function call that
                            # will use SymbolicGradient get a zero gradient. Gradient
                            # functions should ignore the gradient for other outputs.
                            # TODO(apassos) gradients of resource handles might be an
                            # issue here because of zeros.
                            if loop_state:
                                out_grads[i] = loop_state.ZerosLike(op, i)
                            else:
                                out_grads[
                                    i] = control_flow_ops.ZerosLikeOutsideLoop(
                                        op, i)
                    with ops.name_scope(op.name + "_grad"):
                        # pylint: disable=protected-access
                        with src_graph._original_op(op):
                            # pylint: enable=protected-access
                            if grad_fn:
                                # If grad_fn was found, do not use SymbolicGradient even for
                                # functions.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: grad_fn(op, *out_grads))
                            else:
                                # For function call ops, we add a 'SymbolicGradient'
                                # node to the graph to compute gradients.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: _SymGrad(op, out_grads))
                            in_grads = _AsList(in_grads)
                            _VerifyGeneratedGradients(in_grads, op)
                            # With gating on and several non-None input
                            # gradients, pass them through
                            # control_flow_ops.tuple.
                            if gate_gradients and len(
                                [x for x in in_grads if x is not None]) > 1:
                                with ops.device(None):
                                    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                                            None,
                                            gradient_uid,
                                            ignore_existing=True):
                                        in_grads = control_flow_ops.tuple(
                                            in_grads)
                    _LogOpGradients(op, out_grads, in_grads)
                else:
                    # If no grad_fn is defined or none of out_grads is available,
                    # just propagate a list of None backwards.
                    in_grads = [None] * len(_NonEagerInputs(op, xs))
                # Record every non-None input gradient; gradients that are
                # Tensors for non-resource inputs are first checked against
                # the input's shape via set_shape().
                for i, (t_in, in_grad) in enumerate(
                        zip(_NonEagerInputs(op, xs), in_grads)):
                    if in_grad is not None:
                        if (isinstance(in_grad, ops.Tensor)
                                and t_in.dtype != dtypes.resource):
                            try:
                                in_grad.set_shape(t_in.get_shape())
                            except ValueError:
                                raise ValueError(
                                    "Incompatible shapes between op input and calculated "
                                    "input gradient.  Forward operation: %s.  Input index: %d. "
                                    "Original input shape: %s.  "
                                    "Calculated input gradient shape: %s" %
                                    (op.name, i, t_in.shape, in_grad.shape))
                        _SetGrad(grads, t_in, in_grad)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=False)

            # Update pending count for the inputs of op and enqueue ready ops.
            _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count,
                                          loop_state, xs)

    if loop_state:
        loop_state.PostProcessing()
    return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
示例#31
0
def _lift_unlifted_variables(graph, variable_holder):
    """Lifts resource variables of `graph` that are not yet captured.

    Importing a GraphDef inside a wrap_function runs no Python graph-building
    code, so the graph ends up with VarHandleOps that create variable
    resources but have no Python objects behind them. That works, but users
    then cannot interact with or modify those variables from outside the
    graph.

    This function scans the global and local variable collections, lifts each
    eligible variable out as a regular variable object and marks it as a
    capture of the FuncGraph.

    Args:
      graph: The FuncGraph to lift variables from.
      variable_holder: A VariableHolder to record the lifted variables in.
    """
    with graph.as_default():
        global_key = ops.GraphKeys.GLOBAL_VARIABLES
        local_key = ops.GraphKeys.LOCAL_VARIABLES
        globals_snapshot = ops.get_collection(global_key)
        locals_snapshot = ops.get_collection(local_key)
        captured_ids = set(map(id, graph.internal_captures))
        # Maps id(original variable) -> its lifted replacement.
        replacements = {}

        def _needs_lift(candidate):
            """Tells whether `candidate` is an uncaptured in-function resource variable."""
            # pylint: disable=protected-access
            if not (candidate._in_graph_mode
                    and candidate.graph.building_function):
                return False
            # pylint: enable=protected-access
            if not isinstance(candidate,
                              resource_variable_ops.BaseResourceVariable):
                return False
            return id(candidate.handle) not in captured_ids

        def _lift(original):
            """Lifts `original`, records the replacement and marks its handle captured."""
            lifted = _lift_single_variable(original, graph, variable_holder)
            replacements[id(original)] = lifted
            captured_ids.add(id(original.handle))
            return lifted

        for original in globals_snapshot:
            if _needs_lift(original):
                _lift(original)

        for original in locals_snapshot:
            if not _needs_lift(original):
                continue
            lifted = _lift(original)
            if not lifted._in_graph_mode:  # pylint: disable=protected-access
                continue
            # New variables land in the global collection by default. A local
            # variable must live only in the local collection, so move it.
            outer_graph = lifted.graph
            outer_graph.get_collection_ref(global_key).remove(lifted)
            outer_graph.add_to_collection(local_key, lifted)

        # Rewrite the FuncGraph's own collections: this helps the user and
        # keeps the function idempotent when prune() calls it again.
        for key in (global_key, local_key):
            collection_ref = ops.get_collection_ref(key)
            for slot, existing in enumerate(collection_ref):
                replacement = replacements.get(id(existing), existing)
                collection_ref[slot] = replacement
                if resource_variable_ops.is_resource_variable(replacement):
                    continue
                logging.warning(
                    "Unable to create a python object for variable {} because it is "
                    "a reference variable. It may not be visible to training APIs. "
                    "If this is a problem, consider rebuilding the SavedModel after "
                    "running tf.compat.v1.enable_resource_variables().".
                    format(replacement))
示例#32
0
文件: utils.py 项目: lucaslingle/kfac
def is_reference_variable(x):
    """Returns whether `x` should be treated as a reference variable.

    True when `x` is a `tf.Variable` that is not a resource variable, or when
    `x` has a `_should_act_as_ref_variable` attribute.
    """
    if isinstance(x, tf.Variable):
        if not resource_variable_ops.is_resource_variable(x):
            return True
    # Anything else qualifies only through the explicit marker attribute.
    return hasattr(x, "_should_act_as_ref_variable")
示例#33
0
def _GradientsHelper(ys,
                     xs,
                     grad_ys=None,
                     name="gradients",
                     colocate_gradients_with_ops=False,
                     gate_gradients=False,
                     aggregation_method=None,
                     stop_gradients=None,
                     unconnected_gradients=UnconnectedGradients.NONE,
                     src_graph=None):
  """Implementation of gradients().

  Walks the ops between `ys` and `xs` backwards using a work queue, calls each
  op's gradient function (or builds a symbolic gradient for function call ops)
  and accumulates the results in a local `grads` dictionary.

  Args:
    ys: A tensor or list of tensors; normalized with `_AsList` and converted
      with `ops.convert_n_to_tensor_or_indexed_slices`.
    xs: A tensor/variable or list of them. Resource variables are replaced by
      their `handle` before conversion.
    grad_ys: Optional initial gradients for `ys`. When None, a list of `None`
      of length `len(ys)` is handed to `_DefaultGradYs`.
    name: Name passed to `ops.name_scope` (default scope name "gradients").
    colocate_gradients_with_ops: Forwarded to `_DefaultGradYs`,
      `_PendingCount` and `_maybe_colocate_with`.
    gate_gradients: If truthy, whenever an op yields more than one non-None
      input gradient they are passed through `control_flow_ops.tuple`.
    aggregation_method: Forwarded to `_AggregatedGrads`.
    stop_gradients: Optional tensor or list of tensors whose ops are handed to
      `_StopOps`; no gradient function is looked up for ops in the resulting
      stop set.
    unconnected_gradients: Value convertible to `UnconnectedGradients`;
      forwarded to `_GetGrad` when the result list is built.
    src_graph: Graph used for function lookups and `_original_op`. Defaults to
      `ops.get_default_graph()`.

  Returns:
    A list with one entry per element of `xs`, each produced by
    `_GetGrad(grads, x, unconnected_gradients)`.

  Raises:
    RuntimeError: If called while eager execution is enabled.
    ValueError: If `unconnected_gradients` is not a valid
      `UnconnectedGradients` value, or if a computed input gradient has a
      shape incompatible with the corresponding op input.
    LookupError: If an op has no registered gradient function and is not a
      function call.
  """
  if context.executing_eagerly():
    raise RuntimeError("tf.gradients is not supported when eager execution "
                       "is enabled. Use tf.GradientTape instead.")
  if src_graph is None:
    src_graph = ops.get_default_graph()
  try:
    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
  except ValueError:
    raise ValueError(
        "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

  # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
  # ancestor graphs. This is necessary for correctly handling captured values.
  func_graphs = []
  curr_graph = src_graph
  while _IsFunction(curr_graph):
    func_graphs.append(curr_graph)
    if isinstance(curr_graph, FuncGraph):
      curr_graph = curr_graph.outer_graph
    else:
      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access

  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    # Get a uid for this call to gradients that can be used to help
    # cluster ops for compilation.
    gradient_uid = ops.get_default_graph().unique_name("uid")
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    # Swap resource variables for their handle tensors before conversion.
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                             gradient_uid)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    reachable_to_ops, pending_count, loop_state = _PendingCount(
        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into a Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      ready = (pending_count[op] == 0)
      if ready and op not in to_ops_set and op in reachable_to_ops:
        to_ops_set.add(op)
        queue.append(op)

    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
                                     aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        func_call = None
        is_partitioned_call = _IsPartitionedCall(op)
        # pylint: disable=protected-access
        is_func_call = (
            src_graph._is_function(op.type) or is_partitioned_call)
        # pylint: enable=protected-access
        # An output gradient counts if it is a Tensor or any other truthy
        # value.
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        if has_out_grads and (op not in stop_ops):
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            if is_func_call:
              if is_partitioned_call:
                func_call = src_graph._get_function(  # pylint: disable=protected-access
                    compat.as_bytes(op.get_attr("f").name))
              else:
                func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
              # Note that __defun is not set if the graph is
              # imported. If it's set, we prefer to access the original
              # defun.
              func_call = getattr(op, "__defun", func_call)
              grad_fn = func_call.python_grad_func
            else:
              raise LookupError(
                  "No gradient defined for operation '%s' (op type: %s)" %
                  (op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)

        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
        # unless it's within the context of a single iteration (i.e. the
        # gradient is wrt to the loop parameter in the body function, not wrt or
        # through the initial value). This means if we're in a while loop
        # context, we should never see a switch node from this context.
        # pylint: disable=protected-access
        if (control_flow_util.IsSwitch(op) and
            op._control_flow_context is not None and
            op._control_flow_context.IsWhileContext() and
            op._control_flow_context ==
            ops.get_default_graph()._get_control_flow_context()):
          _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
        # pylint: enable=protected-access

        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLike(op, i)
              else:
                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with src_graph._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              # With gating on and several non-None input gradients, pass
              # them through control_flow_ops.tuple.
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
                      None,
                      gradient_uid,
                      ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(_NonEagerInputs(op, xs))
        # Record every non-None input gradient; gradients that are Tensors
        # for non-resource inputs are first checked against the input's shape
        # via set_shape().
        for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
                                                in_grads)):
          if in_grad is not None:
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient.  Forward operation: %s.  Input index: %d. "
                    "Original input shape: %s.  "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
                                    xs)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
示例#34
0
 def var_to_tensor(var):
     """Returns the tensor that stands for the variable `var`.

     Resource variables yield their `handle`; reference variables are
     converted with `as_ref=True`.

     Raises:
       ValueError: If `var` is neither a resource nor a reference variable.
     """
     if not resource_variable_ops.is_resource_variable(var):
         if utils.is_reference_variable(var):
             return tf_ops.internal_convert_to_tensor(var, as_ref=True)
         raise ValueError('%s is not a recognized variable type.' % str(var))
     return var.handle
示例#35
0
def gradients(ys,
              xs,
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None,
              stop_gradients=None):
    """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

    `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
    is a list of `Tensor`, holding the gradients received by the
    `ys`. The list must be the same length as `ys`.

    `gradients()` adds ops to the graph to output the derivatives of `ys` with
    respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
    each tensor is the `sum(dy/dx)` for y in `ys`.

    `grad_ys` is a list of tensors of the same length as `ys` that holds
    the initial gradients for each y in `ys`.  When `grad_ys` is None,
    we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
    user can provide their own initial `grad_ys` to compute the
    derivatives using a different initial gradient for each y (e.g., if
    one wanted to weight the gradient differently for each value in
    each y).

    `stop_gradients` is a `Tensor` or a list of tensors to be considered
    constant with respect to all `xs`. These tensors will not be backpropagated
    through, as though they had been explicitly disconnected using
    `stop_gradient`.  Among other things, this allows computation of partial
    derivatives as opposed to total derivatives. For example:

    ```python
    a = tf.constant(0.)
    b = 2 * a
    g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
    ```

    Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
    total derivatives `tf.gradients(a + b, [a, b])`, which take into account
    the influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the
    above is equivalent to:

    ```python
    a = tf.stop_gradient(tf.constant(0.))
    b = tf.stop_gradient(2 * a)
    g = tf.gradients(a + b, [a, b])
    ```

    `stop_gradients` provides a way of stopping gradient after the graph has
    already been constructed, as compared to `tf.stop_gradient` which is used
    during graph construction.  When the two approaches are combined,
    backpropagation stops at both `tf.stop_gradient` nodes and nodes in
    `stop_gradients`, whichever is encountered first.

    Args:
      ys: A `Tensor` or list of tensors to be differentiated.
      xs: A `Tensor` or list of tensors to be used for differentiation.
        Resource variables are replaced by their `handle` before use.
      grad_ys: Optional. A `Tensor` or list of tensors the same size as
        `ys` and holding the gradients computed for each y in `ys`.
      name: Optional name to use for grouping all the gradient ops together.
        defaults to 'gradients'.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      gate_gradients: If True, add a tuple around the gradients returned
        for an operation.  This avoids some race conditions.
      aggregation_method: Specifies the method used to combine gradient terms.
        Accepted values are constants defined in the class
        `AggregationMethod`.
      stop_gradients: Optional. A `Tensor` or list of tensors not to
        differentiate through.

    Returns:
      A list of `sum(dy/dx)` for each x in `xs`.

    Raises:
      LookupError: if one of the operations between `x` and `y` does not
        have a registered gradient function.
      ValueError: if the arguments are invalid, or if a computed input
        gradient has a shape incompatible with the corresponding op input.
      RuntimeError: if called in Eager mode.

    """
    if context.in_eager_mode():
        raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
                           "functions in tf.contrib.eager.backprop instead.")
    # Normalize every argument through _AsList so the code below can iterate
    # over them uniformly.
    ys = _AsList(ys)
    xs = _AsList(xs)
    stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
    if grad_ys is None:
        grad_ys = [None] * len(ys)
    else:
        grad_ys = _AsList(grad_ys)

    with ops.name_scope(
            name, "gradients",
            list(ys) + list(xs) + list(stop_gradients) +
            list(grad_ys)) as grad_scope:
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
        # Differentiate with respect to a resource variable's handle tensor
        # rather than the variable object itself.
        xs = [
            x.handle if resource_variable_ops.is_resource_variable(x) else x
            for x in xs
        ]
        xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs,
                                                                name="x",
                                                                as_ref=True)
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)

        # The approach we take here is as follows: Create a list of all ops in the
        # subgraph between the ys and xs.  Visit these ops in reverse order of ids
        # to ensure that when we visit an op the gradients w.r.t its outputs have
        # been collected.  Then aggregate these gradients if needed, call the op's
        # gradient function, and add the generated gradients to the gradients for
        # its input.

        # Initialize the pending count for ops in the connected subgraph from ys
        # to the xs.
        if len(ys) > 1:
            # With several ys, any y that already has consumers is wrapped in an
            # identity op (presumably so each y gets its own op to seed with an
            # initial gradient -- TODO confirm).
            ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
        to_ops = [t.op for t in ys]
        from_ops = [t.op for t in xs]
        stop_gradient_ops = [t.op for t in stop_gradients]
        pending_count, loop_state = _PendingCount(ops.get_default_graph(),
                                                  to_ops, from_ops,
                                                  colocate_gradients_with_ops)

        # Iterate over the collected ops.
        #
        # grads: op => list of gradients received on each output endpoint of the
        # op.  The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint we
        # aggregate the list of received gradients into a Add() Operation if there
        # is more than one.
        grads = {}

        # Add the initial gradients for the ys.
        for y, grad_y in zip(ys, grad_ys):
            _SetGrad(grads, y, grad_y)

        # Initialize queue with to_ops.
        queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        to_ops_set = set()
        for op in to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            # pylint: disable=protected-access
            ready = (pending_count[op._id] == 0)
            if ready and op._id not in to_ops_set:
                to_ops_set.add(op._id)
                queue.append(op)
            # pylint: enable=protected-access

        if loop_state:
            # Seed trainable unused loop exits with zero gradients and enqueue
            # their ops as well.
            loop_exits = loop_state.ProcessUnusedLoopExits(
                pending_count, to_ops_set)
            for y in loop_exits:
                if _IsTrainable(y):
                    _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
                    queue.append(y.op)

        stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
        while queue:
            # generate gradient subgraph for op.
            op = queue.popleft()
            with _maybe_colocate_with(op, colocate_gradients_with_ops):
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=True)
                out_grads = _AggregatedGrads(grads, op, loop_state,
                                             aggregation_method)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=True)

                grad_fn = None
                # pylint: disable=protected-access
                func_call = None
                is_func_call = ops.get_default_graph()._is_function(op.type)
                # True when at least one output received a gradient: either a
                # Tensor or any other truthy value.
                has_out_grads = any(
                    isinstance(g, ops.Tensor) or g for g in out_grads)
                # A gradient function is only looked up for ops that received
                # gradients and are not in stop_ops.
                if has_out_grads and (op._id not in stop_ops):
                    if is_func_call:
                        func_call = ops.get_default_graph()._get_function(
                            op.type)
                        grad_fn = func_call.python_grad_func
                        # pylint: enable=protected-access
                    else:
                        # A grad_fn must be defined, either as a function or as None
                        # for ops that do not have gradients.
                        try:
                            grad_fn = ops.get_gradient_function(op)
                        except LookupError:
                            raise LookupError(
                                "No gradient defined for operation '%s' (op type: %s)"
                                % (op.name, op.type))
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=False)
                if (grad_fn or is_func_call) and has_out_grads:
                    # NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    # output, it means that the cost does not depend on output[i],
                    # therefore dC/doutput[i] is 0.
                    for i, out_grad in enumerate(out_grads):
                        if (not isinstance(out_grad, ops.Tensor)
                                and not out_grad) and (
                                    (not grad_fn and is_func_call)
                                    or _IsTrainable(op.outputs[i])):
                            # Only trainable outputs or outputs for a function call that
                            # will use SymbolicGradient get a zero gradient. Gradient
                            # functions should ignore the gradient for other outputs.
                            # TODO(apassos) gradients of resource handles might be an
                            # issue here because of zeros.
                            if loop_state:
                                out_grads[i] = loop_state.ZerosLike(op, i)
                            else:
                                out_grads[
                                    i] = control_flow_ops.ZerosLikeOutsideLoop(
                                        op, i)
                    with ops.name_scope(op.name + "_grad"):
                        # pylint: disable=protected-access
                        with ops.get_default_graph()._original_op(op):
                            # pylint: enable=protected-access
                            if grad_fn:
                                # If grad_fn was found, do not use SymbolicGradient even for
                                # functions.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: grad_fn(op, *out_grads))
                            else:
                                # For function call ops, we add a 'SymbolicGradient'
                                # node to the graph to compute gradients.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: _SymGrad(op, out_grads))
                            in_grads = _AsList(in_grads)
                            _VerifyGeneratedGradients(in_grads, op)
                            # Gate only when more than one non-None input
                            # gradient was produced.
                            if gate_gradients and len(
                                [x for x in in_grads if x is not None]) > 1:
                                with ops.device(None):
                                    with ops.colocate_with(
                                            None, ignore_existing=True):
                                        in_grads = control_flow_ops.tuple(
                                            in_grads)
                    _LogOpGradients(op, out_grads, in_grads)
                else:
                    # If no grad_fn is defined or none of out_grads is available,
                    # just propagate a list of None backwards.
                    in_grads = [None] * len(op.inputs)
                # Record each produced gradient against the op input it
                # belongs to.
                for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
                    if in_grad is not None:
                        # Tensor gradients of non-resource inputs take on the
                        # input's static shape; a mismatch is reported with the
                        # offending op and input index.
                        if (isinstance(in_grad, ops.Tensor)
                                and t_in.dtype != dtypes.resource):
                            try:
                                in_grad.set_shape(t_in.get_shape())
                            except ValueError:
                                raise ValueError(
                                    "Incompatible shapes between op input and calculated "
                                    "input gradient.  Forward operation: %s.  Input index: %d. "
                                    "Original input shape: %s.  "
                                    "Calculated input gradient shape: %s" %
                                    (op.name, i, t_in.shape, in_grad.shape))
                        _SetGrad(grads, t_in, in_grad)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=False)

            # Update pending count for the inputs of op and enqueue ready ops.
            _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count,
                                          loop_state)

    if loop_state:
        loop_state.PostProcessing()
    return [_GetGrad(grads, x) for x in xs]
def _is_variable(x):
    """Tell whether `x` is a `variables.Variable` or a resource variable."""
    if isinstance(x, variables.Variable):
        return True
    return resource_variable_ops.is_resource_variable(x)
def restore_captures(concrete_function, inputs):
    """Re-attach captured tensors to a deserialized concrete function.

    Used at deserialization time.  Saved model restores the objects that
    tensors were captured from, but a traced function only knows about its
    tensors -- the object information is destroyed by tracing.  This function
    extracts the tensors from the restored objects and binds them to the
    function's trailing capture inputs.

    Args:
      concrete_function: the concrete function for which to restore captures
      inputs: a list of tensors or other Python objects (such as variables)
        which contain tensors that were originally captured by the function
    """
    tensors = [get_tensor_from_node(node) for node in inputs]
    # TODO(b/205010575): This is only injecting the captured inputs into the
    # concrete function, note that we did not modify the FuncGraph
    # itself.
    concrete_function.set_variables([
        node for node in inputs
        if isinstance(node, (variables_lib.Variable,
                             resource_variable_ops.BaseResourceVariable))
    ])

    external_captures = []
    if tensors:
        # The captures are the last len(tensors) inputs of the function.
        placeholders = concrete_function.inputs[-len(tensors):]
        for tensor, placeholder in zip(tensors, placeholders):
            # Distributed inputs have special logic for capturing, so we call
            # their custom restoration methods
            if hasattr(tensor, "__tf_experimental_restore_capture__"):
                external_captures.append(
                    tensor.__tf_experimental_restore_capture__(
                        concrete_function, placeholder))
                continue

            external_captures.append(tensor)
            concrete_function.graph.replace_capture(tensor, placeholder)
            if placeholder.dtype == dtypes.resource:
                if not resource_variable_ops.is_resource_variable(tensor):
                    # TODO(b/213451747): Remove need to call copy_handle_data
                    handle_data_util.copy_handle_data(tensor, placeholder)
                else:
                    try:
                        handle = tensor.handle
                    except ValueError:
                        # For mirrored variables we'll copy handle data for
                        # components as they get captured.
                        pass
                    else:
                        handle_data_util.copy_handle_data(handle, placeholder)
            # Setting "captures" first means "capture" won't create a new
            # placeholder for this input.
            concrete_function.graph.capture(tensor)

    if any(captured is None for captured in external_captures):
        warnings.warn(
            "Trying to load ShardedVariables using tf.saved_model.load. "
            "This won't work if using a tf.distribute.Strategy, and may "
            "use excess memory if not using a Strategy. Ignore this "
            "warning if using tf.keras.models.load_model.")
    concrete_function.set_external_captures(external_captures)
示例#38
0
  def gradient(self, target, sources, output_gradients=None):
    """Differentiates `target` against `sources` using this tape's record.

    Args:
      target: Tensor (or list of tensors) to be differentiated.
      sources: a list or nested structure of Tensors or Variables. `target`
        will be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of
        target. Defaults to None.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
       than once on a non-persistent tape.
      ValueError: if called on variable target.
    """
    if self._tape is None:
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")

    if self._recording:
      if self._persistent:
        # A persistent tape keeps recording, so the gradient ops themselves
        # end up on the tape; warn once about the cost.
        logging.log_first_n(logging.WARN,
                            "Calling GradientTape.gradient on a persistent "
                            "tape inside it's context is significantly less "
                            "efficient than calling it outside the context (it "
                            "causes the gradient ops to be recorded on the "
                            "tape, leading to increased CPU and memory usage). "
                            "Only call GradientTape.gradient inside the "
                            "context if you actually want to trace the "
                            "gradient in order to compute higher order "
                            "derrivatives.", 1)
      else:
        self._pop_tape()

    flat_targets = nest.flatten(target)
    if any(resource_variable_ops.is_resource_variable(candidate)
           for candidate in flat_targets):
      raise ValueError("GradientTape.gradient is not supported for variable "
                       "targets.")

    # Resource variables among the sources are replaced by their handles.
    flat_sources = [_handle_or_self(src) for src in nest.flatten(sources)]

    if output_gradients is not None:
      output_gradients = [
          seed if seed is None else ops.convert_to_tensor(seed)
          for seed in nest.flatten(output_gradients)
      ]

    flat_grad = imperative_grad.imperative_grad(
        self._tape,
        flat_targets,
        flat_sources,
        output_gradients=output_gradients)

    # A non-persistent tape is single-use: drop it once it has been consumed.
    if not self._persistent:
      self._tape = None

    return nest.pack_sequence_as(sources, flat_grad)
示例#39
0
def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f(*args)` under a `GradientTape`, then routes its outputs, inputs and
  watched variables through an `IdentityN` op whose gradient is overridden by
  the `grad_fn` that `f` returned.

  Args:
    f: function called as `f(*args)` that returns a `(result, grad_fn)` pair.
    *args: positional inputs to `f`; each is converted to a tensor first.
    **kwargs: must be empty; keyword arguments are rejected in graph mode.

  Returns:
    The outputs of the `IdentityN` op that correspond to `result`, packed back
    into the structure of `result`.

  Raises:
    ValueError: if any keyword arguments are passed, or (when the gradient is
      evaluated) if `grad_fn` returns a different number of variable gradients
      than there are variables.
    TypeError: if `f` creates a non-resource variable, if variables are used
      but `grad_fn` accepts neither `variables` nor `**kwargs`, or if `grad_fn`
      accepts `variables`, none were captured, and the enclosing variable scope
      does not have `use_resource=True`.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  # Unique op-type name under which the gradient function is registered below.
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  # Variables that appeared in the scope while `f` ran.
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  # NOTE(review): this goes through a `set`, so the order of `variables` is
  # not tied to the order in which the tape watched them.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # grad_fn can receive `variables` either as a named argument or via **kwargs.
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  # Layout of the IdentityN inputs: results first, then inputs, then variables.
  # tape_grad_fn below returns its gradients in the same order.
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Only the leading gradients belong to `result`; the rest correspond to the
    # inputs and variables that were also passed through IdentityN.
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  # Create the IdentityN op while its gradient is mapped to the function
  # registered under `name` above.
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  # Record the identity on the tape with tape_grad_fn as its backward function.
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  # Only the leading outputs correspond to `result`; restore its structure.
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
示例#40
0
def gradients(ys,
              xs,
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None,
              stop_gradients=None):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`.  When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`.  Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  ```python
  a = tf.constant(0.)
  b = 2 * a
  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  ```

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the above is
  equivalent to:

  ```python
  a = tf.stop_gradient(tf.constant(0.))
  b = tf.stop_gradient(2 * a)
  g = tf.gradients(a + b, [a, b])
  ```

  `stop_gradients` provides a way of stopping gradient after the graph has
  already been constructed, as compared to `tf.stop_gradient` which is used
  during graph construction.  When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
      Resource variables are replaced by their `handle` before use.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid, or if a computed input gradient
      has a shape incompatible with the corresponding op input.
    RuntimeError: if called in Eager mode.

  """
  if context.in_eager_mode():
    raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
                       "functions in tf.contrib.eager.backprop instead.")
  # Normalize every argument through _AsList so the code below can iterate
  # over them uniformly.
  ys = _AsList(ys)
  xs = _AsList(xs)
  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)

  with ops.name_scope(
      name, "gradients",
      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    # Differentiate with respect to a resource variable's handle tensor rather
    # than the variable object itself.
    xs = [
        x.handle if resource_variable_ops.is_resource_variable(x) else x
        for x in xs
    ]
    xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
        xs, name="x", as_ref=True)
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    if len(ys) > 1:
      # With several ys, any y that already has consumers is wrapped in an
      # identity op (presumably so each y gets its own op to seed with an
      # initial gradient -- TODO confirm).
      ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    stop_gradient_ops = [t.op for t in stop_gradients]
    pending_count, loop_state = _PendingCount(
        ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into a Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      # 'ready' handles the case where one output gradient relies on
      # another output's gradient.
      # pylint: disable=protected-access
      ready = (pending_count[op._id] == 0)
      if ready and op._id not in to_ops_set:
        to_ops_set.add(op._id)
        queue.append(op)
      # pylint: enable=protected-access

    if loop_state:
      # Seed trainable unused loop exits with zero gradients and enqueue their
      # ops as well.
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
        if _IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with _maybe_colocate_with(op, colocate_gradients_with_ops):
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=True)
        out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=True)

        grad_fn = None
        # pylint: disable=protected-access
        func_call = None
        is_func_call = ops.get_default_graph()._is_function(op.type)
        # True when at least one output received a gradient: either a Tensor
        # or any other truthy value.
        has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
        # A gradient function is only looked up for ops that received gradients
        # and are not in stop_ops.
        if has_out_grads and (op._id not in stop_ops):
          if is_func_call:
            func_call = ops.get_default_graph()._get_function(op.type)
            grad_fn = func_call.python_grad_func
            # pylint: enable=protected-access
          else:
            # A grad_fn must be defined, either as a function or as None
            # for ops that do not have gradients.
            try:
              grad_fn = ops.get_gradient_function(op)
            except LookupError:
              raise LookupError(
                  "No gradient defined for operation '%s' (op type: %s)" %
                  (op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)
        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
                (not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              # TODO(apassos) gradients of resource handles might be an
              # issue here because of zeros.
              if loop_state:
                out_grads[i] = loop_state.ZerosLike(op, i)
              else:
                out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with ops.get_default_graph()._original_op(op):
              # pylint: enable=protected-access
              if grad_fn:
                # If grad_fn was found, do not use SymbolicGradient even for
                # functions.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: grad_fn(op, *out_grads))
              else:
                # For function call ops, we add a 'SymbolicGradient'
                # node to the graph to compute gradients.
                in_grads = _MaybeCompile(grad_scope, op, func_call,
                                         lambda: _SymGrad(op, out_grads))
              in_grads = _AsList(in_grads)
              _VerifyGeneratedGradients(in_grads, op)
              # Gate only when more than one non-None input gradient was
              # produced.
              if gate_gradients and len([x for x in in_grads
                                         if x is not None]) > 1:
                with ops.device(None):
                  with ops.colocate_with(None, ignore_existing=True):
                    in_grads = control_flow_ops.tuple(in_grads)
          _LogOpGradients(op, out_grads, in_grads)
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagate a list of None backwards.
          in_grads = [None] * len(op.inputs)
        # Record each produced gradient against the op input it belongs to.
        for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
          if in_grad is not None:
            # Tensor gradients of non-resource inputs take on the input's
            # static shape; a mismatch is reported with the offending op and
            # input index.
            if (isinstance(in_grad, ops.Tensor) and
                t_in.dtype != dtypes.resource):
              try:
                in_grad.set_shape(t_in.get_shape())
              except ValueError:
                raise ValueError(
                    "Incompatible shapes between op input and calculated "
                    "input gradient.  Forward operation: %s.  Input index: %d. "
                    "Original input shape: %s.  "
                    "Calculated input gradient shape: %s" %
                    (op.name, i, t_in.shape, in_grad.shape))
            _SetGrad(grads, t_in, in_grad)
        if loop_state:
          loop_state.ExitGradWhileContext(op, before=False)

      # Update pending count for the inputs of op and enqueue ready ops.
      _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state)

  if loop_state:
    loop_state.PostProcessing()
  return [_GetGrad(grads, x) for x in xs]
示例#41
0
def _is_variable_like(obj):
    """Reports whether `obj` is a `variables.Variable` or a resource variable."""
    if isinstance(obj, variables.Variable):
        return True
    # Not a plain Variable instance: defer to the resource-variable check.
    return resource_variable_ops.is_resource_variable(obj)
示例#42
0
def _is_variable(x):
  """Tells whether `x` is a `variables.Variable` or a resource variable."""
  if isinstance(x, variables.Variable):
    return True
  # Fall back to the resource-variable predicate for everything else.
  return resource_variable_ops.is_resource_variable(x)
示例#43
0
def _handle_or_self(x):
  """Returns `x.handle` when `x` is a resource variable, otherwise `x` itself."""
  return x.handle if resource_variable_ops.is_resource_variable(x) else x
示例#44
0
  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    As a side effect, every mapped tensor is recorded in
    `self.captured_tensor_node_ids`, and for each copyable captured tensor of a
    concrete function that is not yet registered there, a `_CapturedConstant`
    node is appended to `self.nodes` (with matching `self.node_ids` entries).

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.

    Raises:
      ValueError: If a concrete function's graph is not saveable, or if a
        function captures a tensor whose constant value cannot be determined.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.CapturableResource):
        # Re-create the resource on the device the object asks for.
        # pylint: disable=protected-access
        with ops.device(obj._resource_device):
          new_resource = obj._create_resource()
        # pylint: enable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif ds_values.is_distributed_variable(obj):
        # Put both the distributed variable and component variable handles in
        # `captured_tensor_node_ids`.
        # Also create a new distributed variable for `object_map` with newly
        # created component variables.
        new_vars = []
        for v in obj.values:
          new_variable = resource_variable_ops.copy_to_graph_uninitialized(v)
          object_map[v] = new_variable
          new_vars.append(new_variable)
          resource_map[v.handle] = new_variable.handle
          self.captured_tensor_node_ids[v.handle] = node_id
        object_map[obj] = obj._clone_with_new_values(new_vars)  # pylint: disable=protected-access
        self.captured_tensor_node_ids[obj] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.Asset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        # Only the header is a format template. The saving errors are appended
        # verbatim so that any "{" or "}" they contain cannot be misread as
        # replacement fields and mask the real error with a formatting failure.
        raise ValueError(
            "Unable to save function {name} for the following reason(s):\n"
            .format(name=concrete_function.name) +
            "\n".join(concrete_function.graph.saving_errors))
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            raise ValueError(
                ("Attempted to save a function {} which references a symbolic "
                 "Tensor {} that is not a simple constant. This is not "
                 "supported.").format(concrete_function.name, capture))
          copied_tensor = constant_op.constant(capture_constant_value)
          # The new node takes the next free index in `self.nodes`.
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
示例#45
0
def _handle_or_self(x):
    """Unwraps a resource variable to its handle; other values pass through."""
    if not resource_variable_ops.is_resource_variable(x):
        return x
    return x.handle
示例#46
0
def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
    """Helper function for creating a slot variable.

    The slot is created through `variable_scope.get_variable` with the current
    scope's partitioner temporarily cleared; the partitioner is restored
    afterwards, even if variable creation raises.

    Args:
      primary: The variable the slot is created for. Its kind (resource
        variable, `RefVariable`, or other) selects the `use_resource` value,
        and its save-slice info, if any, is propagated to the slot.
      val: Initializer handed to `get_variable`; may be a callable.
      scope: Passed as the first positional argument of `get_variable`.
      validate_shape: Forwarded to `get_variable`.
      shape: Shape forwarded to `get_variable` only when `val` is callable;
        otherwise `None` is passed instead.
      dtype: Forwarded to `get_variable`.
      copy_xla_sharding: If true, the value returned is the result of
        `xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)`.

    Returns:
      The created slot variable (or the result of `copy_sharding` when
      `copy_xla_sharding` is set).
    """
    # When init from val instead of callable initializer, the shape is expected to
    # be None, not <unknown> or any fully defined shape.
    shape = shape if callable(val) else None
    if resource_variable_ops.is_resource_variable(primary):
        use_resource = True
    elif isinstance(primary, variables.RefVariable):
        use_resource = False
    else:
        use_resource = None

    # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
    # scope.
    current_partitioner = variable_scope.get_variable_scope().partitioner
    variable_scope.get_variable_scope().set_partitioner(None)
    try:
        slot = variable_scope.get_variable(scope,
                                           initializer=val,
                                           trainable=False,
                                           use_resource=use_resource,
                                           shape=shape,
                                           dtype=dtype,
                                           validate_shape=validate_shape)
    finally:
        # Restore the caller's partitioner even when variable creation fails,
        # so an error here does not leave the scope with its partitioner
        # silently cleared.
        variable_scope.get_variable_scope().set_partitioner(
            current_partitioner)

    # pylint: disable=protected-access
    if isinstance(primary, variables.Variable) and primary._save_slice_info:
        # Primary is a partitioned variable, so we need to also indicate that
        # the slot is a partitioned variable.  Slots have the same partitioning
        # as their primaries.
        # For examples when using AdamOptimizer in linear model, slot.name
        # here can be "linear//weights/Adam:0", while primary.op.name is
        # "linear//weight". We want to get 'Adam' as real_slot_name, so we
        # remove "'linear//weight' + '/'" and ':0'.
        real_slot_name = slot.name[len(primary.op.name + "/"):-2]
        slice_info = primary._save_slice_info
        # support slot's shape not same as primary's shape
        # example: primary's shape = [10, 20, 30], slot's shape =
        # None, [], [10], [10, 20] or [10, 20, 30] is allowed
        # slot's shape = None or [10, 20, 30], set slot's slice_info same as primary
        # slot's shape = [], don't set slot's slice_info
        # slot's shape = [10] or [10, 20], set slot's slice_info according to ndims
        n = slot.shape.ndims
        if n is None or n > 0:
            slot._set_save_slice_info(
                variables.Variable.SaveSliceInfo(
                    slice_info.full_name + "/" + real_slot_name,
                    slice_info.full_shape[:n], slice_info.var_offset[:n],
                    slice_info.var_shape[:n]))
    # pylint: enable=protected-access

    # Copy XLA sharding attributes from primary.
    if copy_xla_sharding:
        slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
    return slot
示例#47
0
    def gradient(self,
                 target,
                 sources,
                 output_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
        """Computes the gradient using operations recorded in context of this tape.

        Args:
          target: a list or nested structure of Tensors or Variables to be
            differentiated.
          sources: a list or nested structure of Tensors or Variables. `target`
            will be differentiated against elements in `sources`.
          output_gradients: a list of gradients, one for each element of
            target. Defaults to None.
          unconnected_gradients: a value which can either hold 'none' or 'zero'
            and alters the value which will be returned if the target and
            sources are unconnected. The possible values and effects are
            detailed in 'UnconnectedGradients' and it defaults to 'none'.

        Returns:
          a list or nested structure of Tensors (or IndexedSlices, or None),
          one for each element in `sources`. Returned structure is the same as
          the structure of `sources`.

        Raises:
          RuntimeError: if the tape has already been released, i.e. `gradient`
            was called more than once on a non-persistent tape.
          ValueError: if any source is a packed EagerTensor.
        """
        if self._tape is None:
            raise RuntimeError(
                "GradientTape.gradient can only be called once on "
                "non-persistent tapes.")
        if self._recording:
            if self._persistent:
                # Still recording on a persistent tape: allowed, but the
                # gradient ops themselves end up on the tape, so warn once.
                logging.log_first_n(
                    logging.WARN,
                    "Calling GradientTape.gradient on a persistent "
                    "tape inside its context is significantly less "
                    "efficient than calling it outside the context (it "
                    "causes the gradient ops to be recorded on the "
                    "tape, leading to increased CPU and memory usage). "
                    "Only call GradientTape.gradient inside the "
                    "context if you actually want to trace the "
                    "gradient in order to compute higher order "
                    "derivatives.", 1)
            else:
                self._pop_tape()

        flat_targets = []
        for tgt in nest.flatten(target, expand_composites=True):
            if not backprop_util.IsTrainable(tgt):
                logging.vlog(
                    logging.WARN, "The dtype of the target tensor must be "
                    "floating (e.g. tf.float32) when calling GradientTape.gradient, "
                    "got %r", tgt.dtype)
            if resource_variable_ops.is_resource_variable(tgt):
                # Read the variable under this tape so the read is recorded.
                with self:
                    tgt = ops.convert_to_tensor(tgt)
            flat_targets.append(tgt)

        # Keep the untouched sources next to their handle-or-self versions.
        flat_sources_raw = nest.flatten(sources, expand_composites=True)
        flat_sources = [_handle_or_self(src) for src in flat_sources_raw]
        for src in flat_sources_raw:
            if not backprop_util.IsTrainable(src):
                logging.vlog(
                    logging.WARN, "The dtype of the source tensor must be "
                    "floating (e.g. tf.float32) when calling GradientTape.gradient, "
                    "got %r", src.dtype)
            if getattr(src, "is_packed", False):
                raise ValueError(
                    "GradientTape.gradient is not supported on packed EagerTensors yet."
                )

        if output_gradients is not None:
            converted_gradients = []
            for out_grad in nest.flatten(output_gradients,
                                         expand_composites=True):
                if out_grad is None:
                    converted_gradients.append(None)
                else:
                    converted_gradients.append(
                        ops.convert_to_tensor(out_grad))
            output_gradients = converted_gradients

        flat_grad = imperative_grad.imperative_grad(
            self._tape,
            flat_targets,
            flat_sources,
            output_gradients=output_gradients,
            sources_raw=flat_sources_raw,
            unconnected_gradients=unconnected_gradients)

        if not self._persistent:
            # Keep track of watched variables before setting tape to None
            self._watched_variables = self._tape.watched_variables()
            self._tape = None

        # Re-nest the flat gradients to mirror the structure of `sources`.
        return nest.pack_sequence_as(sources,
                                     flat_grad,
                                     expand_composites=True)
def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
                     gate_gradients, aggregation_method, stop_gradients):
    """Implementation of gradients().

    Builds the gradient subgraph by visiting, from the ops that produce `ys`
    backwards, every op in the subgraph between `ys` and `xs`, aggregating the
    gradients received on each op's outputs and calling its gradient function.

    Args:
      ys: A tensor or list of tensors to differentiate; normalized with
        `_AsList`.
      xs: A tensor or list of tensors to differentiate with respect to.
        Resource variables are replaced by their handles.
      grad_ys: Optional initial gradients, one per element of `ys`. When None,
        a list of `None` with the same length as `ys` is handed to
        `_DefaultGradYs`.
      name: Name for the enclosing name scope; "gradients" is the default.
      colocate_gradients_with_ops: Forwarded to `_DefaultGradYs`,
        `_PendingCount` and `_maybe_colocate_with`.
      gate_gradients: If truthy, and an op yields more than one non-None input
        gradient, those gradients are grouped with `control_flow_ops.tuple`.
      aggregation_method: Forwarded to `_AggregatedGrads`.
      stop_gradients: Optional tensor or list of tensors whose ops are handed
        to `_StopOps`; None is treated as an empty list.

    Returns:
      A list with one entry per element of `xs`, each obtained with `_GetGrad`.

    Raises:
      RuntimeError: If eager execution is enabled.
      LookupError: If no gradient function can be looked up for an op that has
        incoming gradients and is not a function call or a stop op.
      ValueError: If a computed input gradient's shape is incompatible with
        the shape of the corresponding op input.
    """
    if context.executing_eagerly():
        raise RuntimeError("tf.gradients not supported when eager execution "
                           "is enabled. Use tf.contrib.eager.GradientTape "
                           "instead.")
    # Normalize every argument to a list so the rest of the function can treat
    # single tensors and lists uniformly.
    ys = _AsList(ys)
    xs = _AsList(xs)
    stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
    if grad_ys is None:
        grad_ys = [None] * len(ys)
    else:
        grad_ys = _AsList(grad_ys)

    with ops.name_scope(
            name, "gradients",
            list(ys) + list(xs) + list(stop_gradients) +
            list(grad_ys)) as grad_scope:
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
        # Differentiate with respect to the handle of a resource variable
        # rather than the variable object itself.
        xs = [
            x.handle if resource_variable_ops.is_resource_variable(x) else x
            for x in xs
        ]
        xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs,
                                                                name="x",
                                                                as_ref=True)
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)

        # The approach we take here is as follows: Create a list of all ops in the
        # subgraph between the ys and xs.  Visit these ops in reverse order of ids
        # to ensure that when we visit an op the gradients w.r.t its outputs have
        # been collected.  Then aggregate these gradients if needed, call the op's
        # gradient function, and add the generated gradients to the gradients for
        # its input.

        # Initialize the pending count for ops in the connected subgraph from ys
        # to the xs.
        if len(ys) > 1:
            # With several ys, any y that already has consumers is wrapped in
            # an identity op.
            ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
        to_ops = [t.op for t in ys]
        from_ops = [t.op for t in xs]
        stop_gradient_ops = [t.op for t in stop_gradients]
        pending_count, loop_state = _PendingCount(ops.get_default_graph(),
                                                  to_ops, from_ops,
                                                  colocate_gradients_with_ops)

        # Iterate over the collected ops.
        #
        # grads: op => list of gradients received on each output endpoint of the
        # op.  The gradients for each endpoint are initially collected as a list.
        # When it is time to call the op's gradient function, for each endpoint we
        # aggregate the list of received gradients into a Add() Operation if there
        # is more than one.
        grads = {}

        # Add the initial gradients for the ys.
        for y, grad_y in zip(ys, grad_ys):
            _SetGrad(grads, y, grad_y)

        # Initialize queue with to_ops.
        queue = collections.deque()
        # Add the ops in 'to_ops' into the queue.
        # Despite the name, `to_ops_set` holds op ids; it de-duplicates ops
        # that produce more than one of the ys.
        to_ops_set = set()
        for op in to_ops:
            # 'ready' handles the case where one output gradient relies on
            # another output's gradient.
            # pylint: disable=protected-access
            ready = (pending_count[op._id] == 0)
            if ready and op._id not in to_ops_set:
                to_ops_set.add(op._id)
                queue.append(op)
            # pylint: enable=protected-access

        if loop_state:
            # Seed trainable unused loop exits with zero gradients and enqueue
            # their ops as well.
            loop_exits = loop_state.ProcessUnusedLoopExits(
                pending_count, to_ops_set)
            for y in loop_exits:
                if _IsTrainable(y):
                    _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
                    queue.append(y.op)

        stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
        while queue:
            # generate gradient subgraph for op.
            op = queue.popleft()
            with _maybe_colocate_with(op, colocate_gradients_with_ops):
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=True)
                out_grads = _AggregatedGrads(grads, op, loop_state,
                                             aggregation_method)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=True)

                grad_fn = None
                # pylint: disable=protected-access
                func_call = None
                is_func_call = ops.get_default_graph()._is_function(op.type)
                # A Tensor counts as present regardless of truthiness; any
                # other entry counts only if it is truthy.
                has_out_grads = any(
                    isinstance(g, ops.Tensor) or g for g in out_grads)
                if has_out_grads and (op._id not in stop_ops):
                    if is_func_call:
                        func_call = ops.get_default_graph()._get_function(
                            op.type)
                        grad_fn = func_call.python_grad_func
                        # pylint: enable=protected-access
                    else:
                        # A grad_fn must be defined, either as a function or as None
                        # for ops that do not have gradients.
                        try:
                            grad_fn = ops.get_gradient_function(op)
                        except LookupError:
                            raise LookupError(
                                "No gradient defined for operation '%s' (op type: %s)"
                                % (op.name, op.type))
                if loop_state:
                    loop_state.EnterGradWhileContext(op, before=False)
                if (grad_fn or is_func_call) and has_out_grads:
                    # NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    # output, it means that the cost does not depend on output[i],
                    # therefore dC/doutput[i] is 0.
                    for i, out_grad in enumerate(out_grads):
                        if (not isinstance(out_grad, ops.Tensor)
                                and not out_grad) and (
                                    (not grad_fn and is_func_call)
                                    or _IsTrainable(op.outputs[i])):
                            # Only trainable outputs or outputs for a function call that
                            # will use SymbolicGradient get a zero gradient. Gradient
                            # functions should ignore the gradient for other outputs.
                            # TODO(apassos) gradients of resource handles might be an
                            # issue here because of zeros.
                            if loop_state:
                                out_grads[i] = loop_state.ZerosLike(op, i)
                            else:
                                out_grads[
                                    i] = control_flow_ops.ZerosLikeOutsideLoop(
                                        op, i)
                    with ops.name_scope(op.name + "_grad"):
                        # pylint: disable=protected-access
                        with ops.get_default_graph()._original_op(op):
                            # pylint: enable=protected-access
                            if grad_fn:
                                # If grad_fn was found, do not use SymbolicGradient even for
                                # functions.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: grad_fn(op, *out_grads))
                            else:
                                # For function call ops, we add a 'SymbolicGradient'
                                # node to the graph to compute gradients.
                                in_grads = _MaybeCompile(
                                    grad_scope, op, func_call,
                                    lambda: _SymGrad(op, out_grads))
                            in_grads = _AsList(in_grads)
                            _VerifyGeneratedGradients(in_grads, op)
                            # Gating only applies when more than one input
                            # gradient was actually produced.
                            if gate_gradients and len(
                                [x for x in in_grads if x is not None]) > 1:
                                with ops.device(None):
                                    with ops.colocate_with(
                                            None, ignore_existing=True):
                                        in_grads = control_flow_ops.tuple(
                                            in_grads)
                    _LogOpGradients(op, out_grads, in_grads)
                else:
                    # If no grad_fn is defined or none of out_grads is available,
                    # just propagate a list of None backwards.
                    in_grads = [None] * len(op.inputs)
                for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
                    if in_grad is not None:
                        # Shape checking is skipped for resource-typed inputs
                        # and for gradients that are not plain Tensors.
                        if (isinstance(in_grad, ops.Tensor)
                                and t_in.dtype != dtypes.resource):
                            try:
                                in_grad.set_shape(t_in.get_shape())
                            except ValueError:
                                raise ValueError(
                                    "Incompatible shapes between op input and calculated "
                                    "input gradient.  Forward operation: %s.  Input index: %d. "
                                    "Original input shape: %s.  "
                                    "Calculated input gradient shape: %s" %
                                    (op.name, i, t_in.shape, in_grad.shape))
                        _SetGrad(grads, t_in, in_grad)
                if loop_state:
                    loop_state.ExitGradWhileContext(op, before=False)

            # Update pending count for the inputs of op and enqueue ready ops.
            _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count,
                                          loop_state)

    if loop_state:
        loop_state.PostProcessing()
    return [_GetGrad(grads, x) for x in xs]