Example #1
def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
                       out_grads, skip_input_indices, forward_pass_name_scope):
  """Calls the gradient function of the op.

  Args:
    op_name: the name of the op to be differentiated.
    attr_tuple: the attrs, as a tuple.
    num_inputs: the number of inputs to the op.
    inputs: inputs to the original operation.
    outputs: outputs of the original operation.
    out_grads: gradients of the operation wrt its outputs.
    skip_input_indices: a tuple that is passed to the gradient function,
      indicating which inputs to skip calculating the gradient for.
    forward_pass_name_scope: the name scope of the op in the forward pass.

  Returns:
    The gradients with respect to the inputs of the function, as a list.
  """
  mock_op = _MockOp(attr_tuple, inputs, outputs, op_name, skip_input_indices)
  grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
  if grad_fn is None:
    return [None] * num_inputs

  # This does not work with v1 TensorArrays.
  if (ops.executing_eagerly_outside_functions() or
      control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
    gradient_name_scope = "gradient_tape/"
    if forward_pass_name_scope:
      gradient_name_scope += forward_pass_name_scope + "/"
    with ops.name_scope(gradient_name_scope):
      return grad_fn(mock_op, *out_grads)
  else:
    return grad_fn(mock_op, *out_grads)
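
This helper is internal TensorFlow machinery; the public surface that ends up calling it is gradient computation under tf.GradientTape. A minimal sketch (assuming TF 2.x with eager execution) of that path:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)                # x is a constant, so it must be watched explicitly
  y = tf.square(x)             # forward op whose registered gradient is looked up
dy_dx = tape.gradient(y, x)    # dispatches to the op's gradient function
print(dy_dx.numpy())           # 6.0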
Example #2
    def _assign_moving_average(self, variable, value, momentum):
        with ops.name_scope(None, 'AssignMovingAvg',
                            [variable, value, momentum]) as scope:
            # TODO(b/120571621): We want to avoid colocating the variables here
            # since TPUStrategy does not implement replica local variables.
            # Remove this hack once we support TPULocalVariables.
            is_tpu_strategy = False
            if distribution_strategy_context.has_distribution_strategy():
                distribute = (
                    distribution_strategy_context.get_distribution_strategy())
                if distribute.__class__.__name__ == 'TPUStrategy':
                    is_tpu_strategy = True

            # TODO(apassos,srbs,skyewm): the colocation constraints here are disabled
            # because of a bug which causes cond_v2/while_v2 to skip rewriting
            # them, creating conflicts.
            if (control_flow_util.EnableControlFlowV2(ops.get_default_graph())
                    or is_tpu_strategy):
                cm = contextlib.contextmanager(lambda: (yield))()
            else:
                cm = ops.colocate_with(variable)
            with cm:
                decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                return state_ops.assign_sub(variable, update_delta, name=scope)
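
Stripped of the strategy and colocation handling, the update this method performs is variable -= (variable - value) * (1 - momentum). A minimal sketch of that arithmetic with public TF 2.x ops (the variable names here are illustrative):

import tensorflow as tf

momentum = 0.99
variable = tf.Variable(0.0)
value = tf.constant(5.0)

decay = 1.0 - momentum
variable.assign_sub((variable - value) * decay)  # move a small step toward `value`
print(variable.numpy())  # 0.05 after one update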
Example #3
def _ProcessNewOps(graph):
    """Processes the newly-added TF_Operations in `graph`."""
    # Maps from a node to the names of the ops it's colocated with, if colocation
    # is specified in the attributes.
    colocation_pairs = {}

    for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
        original_device = new_op.device
        new_op._set_device('')  # pylint: disable=protected-access
        colocation_names = _GetColocationNames(new_op)
        if colocation_names:
            colocation_pairs[new_op] = colocation_names
            # Don't set a device for this op, since colocation constraints override
            # device functions and the original device. Note that this op's device may
            # still be set by the loop below.
            # TODO(skyewm): why does it override the original device?
        else:
            with _MaybeDevice(original_device):
                graph._apply_device_functions(new_op)  # pylint: disable=protected-access

    # The following loop populates the device field of ops that are colocated
    # with another op.  This is implied by the colocation attribute, but we
    # propagate the device field for completeness.
    for op, coloc_op_list in colocation_pairs.items():
        coloc_device = None
        # Find any device in the list of colocated ops that have a device, if it
        # exists.  We assume that if multiple ops have devices, they refer to the
        # same device.  Otherwise, a runtime error will occur since the colocation
        # property cannot be guaranteed.  Note in TF2 colocations have been removed
        # from the public API and will be considered a hint, so there is no runtime
        # error.
        #
        # One possible improvement is to try to check for compatibility of all
        # devices in this list at import time here, which would require
        # implementing a compatibility function for device specs in python.
        for coloc_op_name in coloc_op_list:
            try:
                coloc_op = graph._get_operation_by_name_unsafe(coloc_op_name)  # pylint: disable=protected-access
            except KeyError:
                # Do not error in TF2 if the colocation cannot be guaranteed
                if (tf2.enabled() or
                        control_flow_util.EnableControlFlowV2(graph)):
                    continue

                raise ValueError('Specified colocation to an op that '
                                 'does not exist during import: %s in %s' %
                                 (coloc_op_name, op.name))
            if coloc_op.device:
                coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
                break
        if coloc_device:
            op._set_device(coloc_device)  # pylint: disable=protected-access
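
_ProcessNewOps runs during graph import; the public entry point whose implementation processes newly created operations this way is import_graph_def. A hedged sketch (TF 2.x, explicit tf.Graph construction) of that path:

import tensorflow as tf

# Build a GraphDef in one graph, then import it into another.
g1 = tf.Graph()
with g1.as_default():
  tf.constant(1.0, name='c')

g2 = tf.Graph()
with g2.as_default():
  tf.graph_util.import_graph_def(g1.as_graph_def(), name='imported')

print(g2.get_operation_by_name('imported/c'))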
Example #4
def control_flow_v2_enabled():  # pylint: disable=invalid-name
    """Returns `True` if v2 control flow is enabled.

  Note: v2 control flow is always enabled inside of tf.function.
  """
    return control_flow_util.EnableControlFlowV2(ops.get_default_graph())
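
This is the body behind the public tf.compat.v1.control_flow_v2_enabled API. A short usage sketch (behavior described for TF 2.x defaults):

import tensorflow as tf

# v2 control flow is the default in TF 2.x, so this is expected to print True.
print(tf.compat.v1.control_flow_v2_enabled())

# Related toggles; they have no effect inside tf.function, where v2 control
# flow is always on:
#   tf.compat.v1.enable_control_flow_v2()
#   tf.compat.v1.disable_control_flow_v2()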
Example #5
  def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
    """Construct a new TensorArray or wrap an existing TensorArray handle.

    A note about the parameter `name`:

    The name of the `TensorArray` (even if passed in) is uniquified: each time
    a new `TensorArray` is created at runtime it is assigned its own name for
    the duration of the run.  This avoids name collisions if a `TensorArray`
    is created within a `while_loop`.

    Args:
      dtype: (required) data type of the TensorArray.
      size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
        Required if handle is not provided.
      dynamic_size: (optional) Python bool: If true, writes to the TensorArray
        can grow the TensorArray past its initial size.  Default: False.
      clear_after_read: Boolean (optional, default: True).  If True, clear
        TensorArray values after reading them.  This disables read-many
        semantics, but allows early release of memory.
      tensor_array_name: (optional) Python string: the name of the TensorArray.
        This is used when creating the TensorArray handle.  If this value is
        set, handle should be None.
      handle: (optional) A `Tensor` handle to an existing TensorArray.  If this
        is set, tensor_array_name should be None. Only supported in graph mode.
      flow: (optional) A float `Tensor` scalar coming from an existing
        `TensorArray.flow`. Only supported in graph mode.
      infer_shape: (optional, default: True) If True, shape inference
        is enabled.  In this case, all elements must have the same shape.
      element_shape: (optional, default: None) A `TensorShape` object specifying
        the shape constraints of each of the elements of the TensorArray.
        Need not be fully defined.
      colocate_with_first_write_call: If `True`, the TensorArray will be
        colocated on the same device as the Tensor used on its first write
        (write operations include `write`, `unstack`, and `split`).  If `False`,
        the TensorArray will be placed on the device determined by the
        device context available during its initialization.
      name: A name for the operation (optional).

    Raises:
      ValueError: if both handle and tensor_array_name are provided.
      TypeError: if handle is provided but is not a Tensor.
    """
    if context.executing_eagerly():
      implementation = _EagerTensorArray
    else:
      if control_flow_util.EnableControlFlowV2(ops.get_default_graph()):
        implementation = _GraphTensorArrayV2
      else:
        implementation = _GraphTensorArray
    self._implementation = implementation(
        dtype,
        size=size,
        dynamic_size=dynamic_size,
        clear_after_read=clear_after_read,
        tensor_array_name=tensor_array_name,
        handle=handle,
        flow=flow,
        infer_shape=infer_shape,
        element_shape=element_shape,
        colocate_with_first_write_call=colocate_with_first_write_call,
        name=name)

    self._implementation.parent = weakref.ref(self)
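
A short usage sketch of the public tf.TensorArray API this constructor backs (in TF 2.x eager mode it resolves to _EagerTensorArray, per the branch above):

import tensorflow as tf

# Write three scalars, then stack them back into a single tensor.
ta = tf.TensorArray(dtype=tf.float32, size=3)
ta = ta.write(0, 10.0)  # write() returns the TensorArray to use for subsequent ops
ta = ta.write(1, 20.0)
ta = ta.write(2, 30.0)
print(ta.stack().numpy())  # [10. 20. 30.]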