Example #1
def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
  """Captures a Tensor while building a graph mode function.

  Arguments:
    value: A Tensor object.
    dtype: The datatype of the value produced by the node in the graph.
    name:  Name of the node in the graph.
    as_ref: Ignored (required by register_tensor_conversion_function).

  Returns:
    A constant (the current value of the tensor) if capturing is not
    enabled; otherwise, a placeholder which will have the value of the
    tensor at runtime.
  """
  if context.in_eager_mode():
    return value
  _ = as_ref
  tensor_map = _scoped_captures.tensors
  if tensor_map is None:
    # Capturing is not enabled.
    return constant_op.constant(value.numpy())
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value], [],
                        lambda x: x)
  return captured_value
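
The pattern above keys the capture map by ops.tensor_id(value) rather than by the tensor object itself, so a tensor that is captured more than once still maps to a single placeholder. Below is a minimal, self-contained sketch of that bookkeeping in plain Python; FakeTensor, make_placeholder, and the use of id() in place of ops.tensor_id are illustrative assumptions, not TensorFlow APIs.

class FakeTensor(object):
  """Illustrative stand-in for a tensor; only its identity matters here."""

  def __init__(self, value):
    self.value = value


def make_placeholder(tensor):
  """Illustrative stand-in for graph_placeholder()."""
  return "placeholder_for_%d" % id(tensor)


def capture(tensor_map, tensor):
  # Key by the tensor's id so each tensor is captured (and given a
  # placeholder) at most once, mirroring the tensor_map logic above.
  captured = tensor_map.get(id(tensor))
  if captured is None:
    placeholder = make_placeholder(tensor)
    tensor_map[id(tensor)] = (tensor, placeholder)
    return placeholder
  return captured[1]


captures = {}
t = FakeTensor(3.0)
assert capture(captures, t) == capture(captures, t)  # same placeholder twice
assert len(captures) == 1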
Example #2
def _watch_with_tape(tape, resource_variable):
  """Wraps a watched Tensor and keeps track of it in the implicit tape."""
  tensor = resource_variable.handle
  w = _watch_with_tape_internal(tape, tensor)
  if ag_core.isnode(tape):
    tape.value.variables[ops.tensor_id(tensor)] = resource_variable
    tape.value.tensors[ops.tensor_id(tensor)] = w
Example #3
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      handle_data = value._handle_data  # pylint: disable=protected-access
      captured_value._handle_data = handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim]
                  if not s.unknown_rank else None for s in shapes]
        with errors.raise_exception_on_not_ok_status() as status:
          pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              captured_value._op._graph._c_graph,  # pylint: disable=protected-access
              captured_value._as_tf_output(),  # pylint: disable=protected-access
              shapes,
              ranks,
              types,
              status)

    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
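
The shape/rank flattening in the resource branch above is easy to misread, so here is a small runnable sketch of just that conversion. Dim and ShapeProto are hypothetical stand-ins for the shape protos carried in handle_data (the real protos come from TensorFlow's protobuf definitions, which are not used here).

class Dim(object):
  def __init__(self, size):
    self.size = size


class ShapeProto(object):
  """Illustrative stand-in for a tensor shape proto."""

  def __init__(self, dims=None, unknown_rank=False):
    self.dim = [Dim(d) for d in (dims or [])]
    self.unknown_rank = unknown_rank


def flatten_shapes(shape_protos):
  # Mirrors the comprehensions above: rank -1 / shape None for unknown ranks.
  ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shape_protos]
  shapes = [[d.size for d in s.dim] if not s.unknown_rank else None
            for s in shape_protos]
  return shapes, ranks


shapes, ranks = flatten_shapes([ShapeProto([2, 3]), ShapeProto(unknown_rank=True)])
assert shapes == [[2, 3], None] and ranks == [2, -1]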
Example #4
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
  """Captures a tfe Tensor while building a graph mode function.

  Creates a placeholder to pass the tensor as an argument.

  Arguments:
    value: A tfe.Tensor object
    dtype: The datatype of the value produced by the node in the graph.
    name:  Name of the node in the graph.
    as_ref: Ignored (required by register_tensor_conversion_function).

  Returns:
    A placeholder which will, at runtime, have the value of this tensor.

  Raises:
    ValueError: if called outside a defun context.
  """
  if context.in_eager_mode():
    return value
  _ = as_ref
  tensor_map = _scoped_captures.tensors
  if tensor_map is None:
    raise ValueError(
        "Trying to use tfe.Tensor objects in a graph outside graph mode. "
        "To build a graph use tfe.defun or tfe.make_template.")
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  return captured_value
Example #5
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
Example #6
def _ensure_unique_tensor_objects(parameter_positions, args):
    """Make each of the parameter_positions in args a unique ops.Tensor object.

  Ensure that each parameter is treated independently.
  For example:

  def f(x, y): return x * y
  g = gradients_function(f)
  one = tf.constant(1.)

  g(one, one) should return [1., 1.]
  (even though the two arguments are the same Tensor object).

  Args:
    parameter_positions: List of indices into args defining the arguments to
      differentiate against.
    args: A list of arguments to the function to be differentiated.

  Returns:
    args, possibly edited in-place.
  """
    s = set()
    for (i, t) in enumerate(args):
        if i in parameter_positions:
            tid = ops.tensor_id(t)
            if tid in s:
                args[i] = gen_array_ops.identity(args[i])
            else:
                s.add(tid)
    return args
Example #7
def _ensure_unique_tensor_objects(parameter_positions, args):
  """Make each of the parameter_positions in args a unique ops.Tensor object.

  Ensure that each parameter is treated independently.
  For example:

  def f(x, y): return x * y
  g = gradients_function(f)
  one = tf.constant(1.)

  g(one, one) should return [1., 1.]
  (even though the two arguments are the same Tensor object).

  Args:
    parameter_positions: List of indices into args defining the arguments to
      differentiate against.
    args: A list of arguments to the function to be differentiated.

  Returns:
    args, possibly edited in-place.
  """
  s = set()
  for (i, t) in enumerate(args):
    if i in parameter_positions:
      tid = ops.tensor_id(t)
      if tid in s:
        args[i] = gen_array_ops.identity(args[i])
      else:
        s.add(tid)
  return args
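
To see why _ensure_unique_tensor_objects matters, consider the g(one, one) case from the docstring: both positions alias the same object, so any per-argument bookkeeping keyed by identity would collapse them. The following plain-Python sketch reproduces the dedup logic under stated assumptions; make_copy stands in for gen_array_ops.identity, and none of the names below are TensorFlow APIs.

def ensure_unique(parameter_positions, args, make_copy):
  # Replace repeated (aliased) arguments with distinct copies so each
  # parameter position can be differentiated independently.
  seen = set()
  for i, arg in enumerate(args):
    if i in parameter_positions:
      if id(arg) in seen:
        args[i] = make_copy(arg)
      else:
        seen.add(id(arg))
  return args


one = [1.0]  # a mutable stand-in for tf.constant(1.)
args = ensure_unique([0, 1], [one, one], make_copy=list)
assert args[0] is not args[1]  # the alias is broken...
assert args[0] == args[1]      # ...but the value is preserved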
Example #8
def execute(op_name, num_outputs, inputs, attrs=None, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch.
                 (Explicitly provided instead of being inferred for performance
                 reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    name: Customized name for the operation.

  Returns:
    None if there are no outputs, a single Tensor object if there is one output
    and a list of Tensor objects if there are multiple outputs.

  Raises:
    An exception on error.
  """
  ctx = context.get_default_context()
  # TODO(apassos) move this to convert_to_tensor
  inputs = [ag_core.getval(x) for x in inputs]
  # pylint: disable=protected-access
  input_handles = [c._handle for c in inputs]
  device_name = ctx.device_name
  try:
    outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                            str(op_name), input_handles, attrs,
                                            num_outputs)
    # pylint: enable=protected-access
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    raise core._status_to_exception(e.code, message)  # pylint: disable=protected-access
  # pylint: enable=protected-access

  tensors = [tensor._tensor_from_handle(x) for x in outh]  # pylint: disable=protected-access
  # TODO(alive, cais): Use the execution callback mechanism.
  if core.active_trace() is not None:
    trace_name = name if name else op_name
    for t in tensors:
      # pylint: disable=protected-access
      core.active_trace().record_tensor(trace_name,
                                        ops.tensor_id(t),
                                        t._device_name(),
                                        t.shape.num_elements())
      # pylint: enable=protected-access

  # TODO(cais): Optimize this, perhaps by replacing this execute function with
  # a different one when there are execution callback(s).
  for callback in ctx.post_execution_callbacks:
    callback(op_name, name, attrs, inputs, tensors)

  return tensors
Example #9
  def add_capture(self, tensor, placeholder):
    """Capture a specific tensor and utilize the provided placeholder.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
    self.inputs.append(placeholder)
Example #10
def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources):
    """Filters the tape to only include relevant entries and counts tensor usages.

  Args:
    target: the target to optimize.
    tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
    op_to_entry: Map from op id to a tape.TapeEntry object
    id_sources: the ids of the sources wrt which the gradient is being taken.

  Returns:
    usage counts (how many entries downstream from a tensor use it)
    op_to_entry_map: entry map (a filtered tape, with only the relevant
     entries),
    missing: map from tensor id to how many downstream gradients still need
     to be computed before this tensor's gradient can be computed.
  """
    if isinstance(target, (ops.Tensor)):
        tensor_stack = [ops.tensor_id(target)]
    else:
        tensor_stack = list([ops.tensor_id(x) for x in target])
    tensor_usage_counts = {}
    o_to_e = {}  # Copy of just the bits we need from op_to_entry
    while tensor_stack:
        t = tensor_stack.pop()
        op = tensor_to_op[t]
        # op is None if the tensor is a source (i.e. was watched directly)
        if op is None or op in o_to_e:
            continue
        op_trace = op_to_entry[op]
        o_to_e[op] = op_trace
        for i in op_trace.inputs:
            it = ops.tensor_id(i)
            if it in tensor_usage_counts:
                tensor_usage_counts[it] += 1
            else:
                tensor_usage_counts[it] = 1
                if it not in id_sources and it in tensor_to_op:
                    tensor_stack.append(it)
    op_missing_tensor_counts = collections.defaultdict(int)
    for t in tensor_usage_counts:
        if t in tensor_to_op and tensor_to_op[t] is not None:
            op_missing_tensor_counts[tensor_to_op[t]] += 1
    return tensor_usage_counts, o_to_e, op_missing_tensor_counts
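
A compact way to check the intent of _prepare_backprop is to run the same traversal over a toy tape. The sketch below is illustrative only: plain dicts replace the tape, op_inputs stands in for the TapeEntry records, and integer ids stand in for ops.tensor_id values.

import collections


def prepare_backprop(target_ids, tensor_to_op, op_inputs, source_ids):
  # op_inputs maps an op id to the ids of its input tensors; tensor_to_op
  # maps a tensor id to the op that produced it, or None for watched sources.
  usage_counts = {}
  relevant_ops = {}
  stack = list(target_ids)
  while stack:
    t = stack.pop()
    op = tensor_to_op.get(t)
    if op is None or op in relevant_ops:
      continue
    relevant_ops[op] = op_inputs[op]
    for it in op_inputs[op]:
      usage_counts[it] = usage_counts.get(it, 0) + 1
      if usage_counts[it] == 1 and it not in source_ids and it in tensor_to_op:
        stack.append(it)
  missing = collections.defaultdict(int)
  for t in usage_counts:
    if tensor_to_op.get(t) is not None:
      missing[tensor_to_op[t]] += 1
  return usage_counts, relevant_ops, missing


# Toy tape for y = f(x); z = g(y, y): tensor 1 is the source, 3 the target.
counts, relevant, missing = prepare_backprop(
    target_ids=[3],
    tensor_to_op={1: None, 2: "f", 3: "g"},
    op_inputs={"f": [1], "g": [2, 2]},
    source_ids={1})
assert counts == {2: 2, 1: 1}
assert missing == {"f": 1}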
Example #11
def _make_inputs_match(branch_graphs, branch_inputs):
    """Modifies branch_graphs so they have the same input signature.

  This method reorders and/or adds parameters to each graph in branch_graphs so
  they have the same input signature, and updates the 'inputs' and 'captured'
  fields of each graph accordingly. It uses the input tensors from the outer
  graph to avoid duplicating shared arguments.

  Args:
    branch_graphs: a `list` of `FuncGraph`
    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
      inputs for the corresponding graph in `branch_graphs`.

  Returns:
    A new list of Tensors from the outer graph that are the new inputs for each
    branch_graph. This is a deduped version of `sum(branch_inputs)`.
  """
    assert len(branch_graphs) == len(branch_inputs)
    added_inputs = set()
    new_inputs = []
    for branch_in in branch_inputs:
        for tensor in branch_in:
            tensor_id = ops.tensor_id(tensor)
            if tensor_id not in added_inputs:
                added_inputs.add(tensor_id)
                new_inputs.append(tensor)

    for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
        input_ids = [ops.tensor_id(t) for t in branch_in]
        branch_input_to_param = dict(zip(input_ids, branch_graph.inputs))
        input_list = []
        for in_t in new_inputs:
            param = branch_input_to_param.get(ops.tensor_id(in_t))
            if param is None:
                param = _create_dummy_input(branch_graph, in_t)
            input_list.append(param)

        branch_graph.inputs = input_list

        # Rewrite the FuncGraphs' state to reflect the new inputs.
        branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))

    return new_inputs
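
The first loop above is just an order-preserving dedup keyed by ops.tensor_id. A tiny stand-alone sketch of that step, in plain Python with id() in place of ops.tensor_id:

def dedupe_branch_inputs(branch_inputs):
  # Order-preserving dedup across all branches, keyed by object identity.
  seen = set()
  new_inputs = []
  for branch_in in branch_inputs:
    for tensor in branch_in:
      if id(tensor) not in seen:
        seen.add(id(tensor))
        new_inputs.append(tensor)
  return new_inputs


a, b, c = object(), object(), object()
assert dedupe_branch_inputs([[a, b], [b, c]]) == [a, b, c]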
Example #12
def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources):
  """Filters the tape to only include relevant entries and counts tensor usages.

  Args:
    target: the target to optimize.
    tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
    op_to_entry: Map from op id to a tape.TapeEntry object
    id_sources: the ids of the sources wrt which the gradient is being taken.

  Returns:
    usage counts (how many entries downstream from a tensor use it)
    op_to_entry_map: entry map (a filtered tape, with only the relevant
     entries),
    missing: map from tensor id to how many downstream gradients still need
     to be computed before this tensor's gradient can be computed.
  """
  if isinstance(target, (ops.Tensor)):
    tensor_stack = [ops.tensor_id(target)]
  else:
    tensor_stack = list([ops.tensor_id(x) for x in target])
  tensor_usage_counts = {}
  o_to_e = {}  # Copy of just the bits we need from op_to_entry
  while tensor_stack:
    t = tensor_stack.pop()
    op = tensor_to_op[t]
    # op is None if the tensor is a source (i.e. was watched directly)
    if op is None or op in o_to_e:
      continue
    op_trace = op_to_entry[op]
    o_to_e[op] = op_trace
    for it in op_trace.input_ids:
      if it in tensor_usage_counts:
        tensor_usage_counts[it] += 1
      else:
        tensor_usage_counts[it] = 1
        if it not in id_sources and it in tensor_to_op:
          tensor_stack.append(it)
  op_missing_tensor_counts = collections.defaultdict(int)
  for t in tensor_usage_counts:
    if t in tensor_to_op and tensor_to_op[t] is not None:
      op_missing_tensor_counts[tensor_to_op[t]] += 1
  return tensor_usage_counts, o_to_e, op_missing_tensor_counts
Example #13
  def _capture_helper(self, tensor, name):
    capture = self._captures.get(ops.tensor_id(tensor))
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype)
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]
    tape.record_operation("captured_value", [placeholder], [tensor],
                          lambda x: [x])
    return placeholder
Example #14
  def testAssignMirroredVarCrossDeviceContext(self, distribution):
    def var_fn():
      return variable_scope.variable(1.0, name="foo")

    with distribution.scope():
      mirrored_var = distribution.extended.call_for_each_replica(var_fn)
      self.assertIsInstance(mirrored_var, values.MirroredVariable)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(1.0, self.evaluate(mirrored_var))
      self.assertIsNotNone(ops.tensor_id(mirrored_var))
      mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
      self.assertEqual(6.0, mirrored_var_result)
Example #15
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data  # pylint: disable=protected-access
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # pylint: enable=protected-access
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim]
                  if not s.unknown_rank else None for s in shapes]
        pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
            captured_value._op._graph._c_graph,  # pylint: disable=protected-access
            captured_value._as_tf_output(),  # pylint: disable=protected-access
            shapes, ranks, types)

    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
Example #16
def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data  # pylint: disable=protected-access
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # pylint: enable=protected-access
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim]
                  if not s.unknown_rank else None for s in shapes]
        pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
            captured_value._op._graph._c_graph,  # pylint: disable=protected-access
            captured_value._as_tf_output(),  # pylint: disable=protected-access
            shapes, ranks, types)

    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
Example #17
def _initial_gradients(target, output_gradients, tensor_usage_counts):
    """Computes the initial gradients for each Tensor."""
    # Initialize the backprop stack
    gradients = collections.defaultdict(list)
    if isinstance(target, ops.Tensor):
        if output_gradients is not None:
            output_gradient = output_gradients
        else:
            output_gradient = array_ops.ones_like(target)
        gradients[ops.tensor_id(target)].append(output_gradient)
    else:
        for i, t in enumerate(target):
            if ops.tensor_id(t) in tensor_usage_counts:
                # Can't provide a gradient of something we're trying to differentiate
                assert output_gradients is None or output_gradients[i] is None
            else:
                if output_gradients is None or output_gradients[i] is None:
                    out_grad = array_ops.ones_like(t)
                else:
                    out_grad = output_gradients[i]
                gradients[ops.tensor_id(t)].append(out_grad)
    return gradients
Example #18
def _initial_gradients(target, output_gradients, tensor_usage_counts):
  """Computes the initial gradients for each Tensor."""
  # Initialize the backprop stack
  gradients = collections.defaultdict(list)
  if isinstance(target, ops.Tensor):
    if output_gradients is not None:
      output_gradient = output_gradients
    else:
      output_gradient = array_ops.ones_like(target)
    gradients[ops.tensor_id(target)].append(output_gradient)
  else:
    for i, t in enumerate(target):
      if ops.tensor_id(t) in tensor_usage_counts:
        # Can't provide a gradient of something we're trying to differentiate
        assert output_gradients is None or output_gradients[i] is None
      else:
        if output_gradients is None or output_gradients[i] is None:
          out_grad = array_ops.ones_like(t)
        else:
          out_grad = output_gradients[i]
        gradients[ops.tensor_id(t)].append(out_grad)
  return gradients
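
A small plain-Python model of how _initial_gradients seeds the gradient map can make the branch structure clearer. In the sketch below, ones_like stands in for array_ops.ones_like and integer ids stand in for ops.tensor_id values; this is an illustration, not TensorFlow code.

import collections


def initial_gradients(target_ids, output_gradients, usage_counts, ones_like):
  gradients = collections.defaultdict(list)
  for i, tid in enumerate(target_ids):
    if tid in usage_counts:
      # This target feeds other recorded ops, so its seed gradient arrives
      # from downstream rather than being provided here.
      assert output_gradients is None or output_gradients[i] is None
    else:
      grad = None if output_gradients is None else output_gradients[i]
      gradients[tid].append(ones_like(tid) if grad is None else grad)
  return gradients


grads = initial_gradients([10, 11], None, usage_counts={11: 1},
                          ones_like=lambda _: 1.0)
assert dict(grads) == {10: [1.0]}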
Example #19
  def capture_eager_tensor(self, tensor, name):
    capture = self._captures.get(ops.tensor_id(tensor))
    if capture is None:
      # We clear all control dependencies and place the Const op on the same
      # device as the source tensor. The device placement may be relaxed at
      # a later date.
      with ops.control_dependencies(None), self.device(tensor.device):
        graph_const = constant_op.constant(tensor.numpy(), dtype=tensor.dtype,
                                           shape=tensor.shape, name=name)
      self.add_capture(tensor, graph_const)
    else:
      graph_const = capture[1]
    tape.record_operation("captured_value", [graph_const], [tensor],
                          lambda x: [x])
    return graph_const
Example #20
def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False):
    """Captures a tfe Tensor while building a graph mode function.

  Creates a placeholder to pass the tensor as an argument.

  Arguments:
    value: A tfe.Tensor object
    dtype: The datatype of the value produced by the node in the graph.
    name:  Name of the node in the graph.
    as_ref: Ignored (required by register_tensor_conversion_function).

  Returns:
    A placeholder which will, at runtime, have the value of this tensor.

  Raises:
    ValueError: if called outside a defun context.
  """
    if context.in_eager_mode():
        return value
    _ = as_ref
    tensor_map = _scoped_captures.tensors
    if tensor_map is None:
        raise ValueError(
            "Trying to use tfe.Tensor objects in a graph outside graph mode. "
            "To build a graph use tfe.defun or tfe.make_template.")
    captured_value = tensor_map.get(ops.tensor_id(value), None)
    if captured_value is None:
        captured_value = graph_placeholder(dtype=dtype or value.dtype,
                                           shape=value.shape,
                                           name=name)
        if captured_value.dtype == dtypes.resource:
            captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
        tensor_map[ops.tensor_id(value)] = (value, captured_value)
    else:
        captured_value = captured_value[1]
    return captured_value
Example #21
    def _backprop_call(self, args):
        """Calls the wrapped function and records the result on a tape."""
        all_args = args + self._extra_inputs
        signature = self._forward_fdef.definition.signature
        if context.in_graph_mode():
            g = ops.get_default_graph()
            g._add_function(self._forward_fdef)  # pylint: disable=protected-access
            unwrapped_args = [ag_core.getval(x) for x in all_args]
            op = g.create_op(
                signature.name,
                [ops.convert_to_tensor(x) for x in unwrapped_args],
                [dtypes.DType(x.type) for x in signature.output_arg],
                op_def=signature,
                name="FunctionCall",
                compute_shapes=False)
            outputs = op.outputs
            outputs = [outputs] if isinstance(outputs,
                                              (tensor.Tensor, ops.Tensor,
                                               type(None))) else list(outputs)
            for i, s in enumerate(self._output_shapes):
                outputs[i].set_shape(s)
        else:
            outputs = execute.execute(str(signature.name),
                                      num_outputs=len(signature.output_arg),
                                      inputs=all_args)
        real_outputs = outputs[:len(self._returns)]
        side_outputs = outputs[len(self._returns):]
        watched_extra_inputs = []
        for t in self._extra_inputs:
            tid = ops.tensor_id(t)
            for t in tape._tape_stack.stack:  # pylint: disable=protected-access
                w = t.value.tensors.get(tid, None)
                if w is not None:
                    watched_extra_inputs.append(w)
                    break
            else:  # Note: for-else here done on purpose
                watched_extra_inputs.append(t)

        def backward_function_wrapper(*outputs):
            outputs = outputs[len(real_outputs):]
            return self._backward_function(*outputs)

        real_outputs = tape.record_operation(real_outputs,
                                             (args + watched_extra_inputs),
                                             side_outputs,
                                             backward_function_wrapper)

        return self._build_call_outputs(self._returns, real_outputs)
Example #22
def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
    """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch.
                 (Explicitly provided instead of being inferred for performance
                 reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs

  Raises:
    An exception on error.
  """
    device_name = ctx.device_name
    # pylint: disable=protected-access
    try:
        tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                                   op_name, inputs, attrs,
                                                   num_outputs)
    except core._NotOkStatusException as e:
        if name is not None:
            message = e.message + " name: " + name
        else:
            message = e.message
        six.raise_from(core._status_to_exception(e.code, message), None)

    # TODO(alive, cais): Use the execution callback mechanism.
    if core.active_trace() is not None:
        for t in tensors:
            core.active_trace().record_tensor(op_name, ops.tensor_id(t),
                                              t.device, t.shape.num_elements())
    # pylint: enable=protected-access

    # TODO(cais): Optimize this, perhaps by replacing this execute function with
    # a different one when there are execution callback(s).
    for callback in ctx.post_execution_callbacks:
        callback(op_name, name, attrs, inputs, tensors)

    return tensors
Example #23
  def _backprop_call(self, args):
    """Calls the wrapped function and records the result on a tape."""
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.definition.signature
    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._forward_fdef)  # pylint: disable=protected-access
      unwrapped_args = [ag_core.getval(x) for x in all_args]
      op = g.create_op(
          signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      outputs = op.outputs
      outputs = [outputs] if isinstance(
          outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
      for i, s in enumerate(self._output_shapes):
        outputs[i].set_shape(s)
    else:
      outputs = execute.execute(
          signature.name,
          num_outputs=len(signature.output_arg),
          inputs=all_args)
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]
    watched_extra_inputs = []
    for t in self._extra_inputs:
      tid = ops.tensor_id(t)
      for t in tape._tape_stack.stack:  # pylint: disable=protected-access
        w = t.value.tensors.get(tid, None)
        if w is not None:
          watched_extra_inputs.append(w)
          break
      else:  # Note: for-else here done on purpose
        watched_extra_inputs.append(t)

    def backward_function_wrapper(*outputs):
      outputs = outputs[len(real_outputs):]
      return self._backward_function(*outputs)
    real_outputs = tape.record_operation(
        real_outputs,
        (args + watched_extra_inputs),
        side_outputs,
        backward_function_wrapper)

    return self._build_call_outputs(self._returns, real_outputs)
Example #24
 def _setup_functions_captures(self):
     """Setup captures and variables in restored functions."""
     concrete_functions = sorted(self._proto.concrete_functions.items())
     for name, proto in concrete_functions:
         concrete_function = self._concrete_functions[name]
         bound_inputs = [
             self._get_tensor_from_node(node_id)
             for node_id in proto.bound_inputs
         ]
         bound_variables = [
             self._nodes[node_id] for node_id in proto.bound_inputs
             if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
         ]
         # TODO(andresp): This is only injecting the captured inputs into the
         # concrete function, note that we did not modify the FuncGraph
         # itself.
         concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
         concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access
         if bound_inputs:
             for bound_input, internal_capture in zip(
                     bound_inputs,
                     concrete_function.inputs[-len(bound_inputs):]):
                 if ds_values.is_distributed_variable(bound_input):
                     concrete_function.graph.capture_distributed_variable(
                         bound_input, internal_capture)
                 else:
                     concrete_function.graph._captures[ops.tensor_id(
                         bound_input)] = (  # pylint: disable=protected-access
                             bound_input, internal_capture)
                     if internal_capture.dtype == dtypes.resource:
                         if resource_variable_ops.is_resource_variable(
                                 bound_input):
                             try:
                                 handle = bound_input.handle
                             except ValueError:
                                 # For mirrored variables we'll copy handle data for components
                                 # as they get captured.
                                 pass
                             else:
                                 custom_gradient.copy_handle_data(
                                     handle, internal_capture)
                         else:
                             custom_gradient.copy_handle_data(
                                 bound_input, internal_capture)
                     # Setting "captures" first means "capture" won't create a new
                     # placeholder for this input.
                     concrete_function.graph.capture(bound_input)
Example #25
    def _process_switch(self, switch_op, ops_which_must_run,
                        last_write_to_resource, merge_for_resource):
        """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor is created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to last op updating
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
        # pylint: disable=protected-access
        inp = switch_op.inputs[0]
        input_id = ops.tensor_id(inp)
        if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
            self._process_switch(inp.op, ops_which_must_run,
                                 last_write_to_resource, merge_for_resource)
        output = switch_op.outputs[0]
        output_id = ops.tensor_id(output)
        if output_id in merge_for_resource:
            return
        new_merge = control_flow_ops.merge(switch_op.outputs,
                                           name="artificial_merge")
        new_merge[0].op._control_flow_context = (
            switch_op._control_flow_context.outer_context)
        # Ensures the merge always runs
        ops_which_must_run.add(new_merge[0].op)
        if input_id in last_write_to_resource:
            # Ensures the switch executes after the previous op using the resource.
            switch_op._add_control_input(last_write_to_resource[input_id])
        # Ensure the next op outside the cond happens after the merge.
        last_write_to_resource[input_id] = new_merge[0].op
        if input_id in merge_for_resource:
            merge_for_resource[input_id]._add_control_input(new_merge[0].op)
        for o in switch_op.outputs:
            # Ensures the merge will execute after all ops inside the cond
            merge_for_resource[ops.tensor_id(o)] = new_merge[0].op
Example #26
def function_def_to_graph(fdef,
                          structured_input_signature=None,
                          structured_outputs=None,
                          input_shapes=None):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
  The input tensors are represented as placeholders.

  Note: `FuncGraph.inputs` and `FuncGraph.captures` are not set and may be set
  by the caller.

  Args:
    fdef: FunctionDef.
    structured_input_signature: Optional. The structured input signature to
      use for initializing the FuncGraph. See the docstring for FuncGraph for
      more information.
    structured_outputs: Optional. The structured outputs to use for
      initializing the FuncGraph. See the docstring for FuncGraph for more
      information.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute. If
      specified, its length must match length of `fdef.signature.input_arg`. If
      a shape is None, the corresponding input placeholder will have unknown
      shape.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name,
                         structured_input_signature=structured_input_signature,
                         structured_outputs=structured_outputs)
  if input_shapes is None:
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      raw_input_shapes = input_shapes_attr.list.shape

      # Replace resource handle shapes in the inputs to disable shape inference.
      # Setting the shape to either the variable handle shape (which is always
      # `[]`) or the variable shape can cause shape inference issues.
      input_shapes = []
      for input_shape, arg_def in zip(raw_input_shapes,
                                      fdef.signature.input_arg):
        if arg_def.type == types_pb2.DT_RESOURCE and arg_def.handle_data:
          input_shapes.append(None)
        else:
          input_shapes.append(input_shape)

  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def_for_function(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

    _set_handle_data(func_graph, fdef)

    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # _output_shapes for functions can sometimes be too long because the
        # output-intermediates-for-gradients version of the function was
        # substituted before saving. We'll accept that here. (See b/133666530).
        for output_index, shape in enumerate(
            output_shapes.list.shape[:len(op.outputs)]):
          op.outputs[output_index].set_shape(shape)
    output_names = {}
    for ret_arg_def, tensor_name in zip(
        fdef.signature.output_arg, output_tensor_names):
      output_names[ops.tensor_id(
          func_graph.get_tensor_by_name(tensor_name))] = (
              ret_arg_def.name)
    func_graph._output_names = output_names  # pylint: disable=protected-access
  return func_graph
Example #27
  def __exit__(self, unused_type, unused_value, unused_traceback):
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    # pylint: disable=protected-access
    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False
    # pylint: enable=protected-access

    # map from resource tensor to the last op which used it
    last_op_using_resource_tensor = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_op_using_resource_tensor).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run
      if op_def_registry.get(op.type) is None or op_is_stateful(op):
        ops_which_must_run.add(op)
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)  # pylint: disable=protected-access
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_op_using_resource_tensor:
              last_op_using_resource_tensor[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_op_using_resource_tensor.
      for inp in op.inputs:
        if inp.dtype != dtypes_module.resource:
          continue

        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_op_using_resource_tensor,
                               merge_for_resource)
        # Ensure uses of resources are serialized
        if input_id in last_op_using_resource_tensor:
          if (last_op_using_resource_tensor[input_id]._control_flow_context  # pylint: disable=protected-access
              is op._control_flow_context):  # pylint: disable=protected-access
            control_inputs.add(last_op_using_resource_tensor[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[input_id] = op

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):  # pylint: disable=protected-access
        if None in last_op_using_resource_tensor:
          op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
        last_op_using_resource_tensor[None] = op
      control_inputs = [c for c in control_inputs
                        if c._control_flow_context is op._control_flow_context]  # pylint: disable=protected-access
      op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
      if self.ops_which_must_run:
        r.op._add_control_inputs(  # pylint: disable=protected-access
            [o for o in self.ops_which_must_run
             if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access
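
The core of the loop above is the last_op_using_resource_tensor map: every op that touches a resource gets a control input on the previous op that touched the same resource, which serializes their execution. The toy sketch below shows only that chaining and deliberately ignores control flow contexts, switches, and merges; op names and integer resource ids are hypothetical, and nothing here is a TensorFlow API.

def serialize_resource_uses(ops_in_order):
  # ops_in_order: (op_name, resource_ids) pairs in graph-construction order.
  # Returns (op_name, control_inputs) pairs chaining ops that share a resource.
  last_op_using_resource = {}
  plan = []
  for op_name, resource_ids in ops_in_order:
    control_inputs = set()
    for rid in set(resource_ids):  # skip duplicate inputs, as the real loop does
      if rid in last_op_using_resource:
        control_inputs.add(last_op_using_resource[rid])
      last_op_using_resource[rid] = op_name
    plan.append((op_name, sorted(control_inputs)))
  return plan


plan = serialize_resource_uses([("assign1", [7]), ("read1", [7]),
                                ("assign2", [7]), ("other", [8])])
assert plan == [("assign1", []), ("read1", ["assign1"]),
                ("assign2", ["read1"]), ("other", [])]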
Example #28
def function_def_to_graph(fdef, input_shapes=None, copy_functions=True):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
  The input tensors are represented as placeholders.

  Note: `FuncGraph.inputs` and `FuncGraph.captures` are not set and may be set
  by the caller.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute. If
      specified, its length must match length of `fdef.signature.input_arg`. If
      a shape is None, the corresponding input placeholder will have unknown
      shape.
    copy_functions: Whether to copy all functions that exist in the default
      graph (whether they are used or not) to the created FuncGraph. Functions
      required for graph import will be copied regardless.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name)
  if input_shapes is None:
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      input_shapes = input_shapes_attr.list.shape
  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes, copy_functions)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def_for_function(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]
    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # _output_shapes for functions can sometimes be too long because the
        # output-intermediates-for-gradients version of the function was
        # substituted before saving. We'll accept that here. (See b/133666530).
        for output_index, shape in enumerate(
            output_shapes.list.shape[:len(op.outputs)]):
          op.outputs[output_index].set_shape(shape)
    output_names = {}
    for ret_arg_def, tensor_name in zip(
        fdef.signature.output_arg, output_tensor_names):
      output_names[ops.tensor_id(
          func_graph.get_tensor_by_name(tensor_name))] = (
              ret_arg_def.name)
    func_graph._output_names = output_names  # pylint: disable=protected-access
  return func_graph
Example #29
def imperative_grad(target, sources, output_gradients=None):
    """Computes gradients from the imperatively defined tape on top of the stack.

  Works by filtering the tape, computing how many downstream usages are of each
  tensor and entry, and repeatedly applying backward functions until we have
  gradients for all sources.

  Args:
   target: either a Tensor or list of Tensors to be differentiated.
   sources: list of Tensors for which we want gradients
   output_gradients: if not None, a list of gradients provided for each target,
    or None if we are to use the target's computed downstream gradient.

  Returns:
   the gradient wrt each of the sources.

  Raises:
    RuntimeError: if something goes wrong.
    ValueError: if there is no sequence of differentiable operations connecting
     a source and any target Tensor. This can happen either if the target is
     not computed based on the source, if the tracing was set up incorrectly,
     or if only non-differentiable functions of the source were used in the
     computation of target.
  """
    if not tape._tape_stack.stack:  # pylint: disable=protected-access
        raise RuntimeError("Computing a gradient with no tape present")
    bp_tape = tape.pop_tape()
    tensor_to_op, op_to_entry = bp_tape.export()
    # This overwrites the op_to_entry variable, which will release all memory used
    # to keep traces that are irrelevant to the gradient computation we're doing
    # here.
    id_sources = [ops.tensor_id(t) for t in sources]
    tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
        target, tensor_to_op, op_to_entry, id_sources)
    ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
    gradients = _initial_gradients(target, output_gradients,
                                   tensor_usage_counts)
    # Now exhaust the backprop stack
    while ready_ops:
        op = ready_ops.pop()
        op_trace = op_to_entry.pop(op)
        out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
        for i in range(len(out_gradients)):
            if out_gradients[i] is None:
                # TODO(apassos) this should be in the right device
                none_indices = _grad_fn_accepts_none_for_indices.get(
                    op_trace.op_type, None)
                if none_indices is None or i not in none_indices:
                    out_gradients[i] = array_ops.zeros(
                        *op_trace.output_shape_and_dtype[i])
            else:
                out_gradients[i] = _aggregate_grads(out_gradients[i])

        in_gradients = op_trace.backward_function(*(out_gradients +
                                                    op_trace.side_outputs))
        in_gradients = ([in_gradients] if isinstance(
            in_gradients,
            (ops.Tensor, ops.IndexedSlices, type(None))) else in_gradients)
        for i, t in enumerate(op_trace.input_ids):
            if in_gradients[i] is not None:
                gradients[t].append(in_gradients[i])
            if tensor_usage_counts.get(t, 0) > 0:
                tensor_usage_counts[t] -= 1
                if (t in tensor_to_op and tensor_usage_counts[t] == 0
                        and t not in id_sources):
                    in_op = tensor_to_op[t]
                    if in_op is None:
                        continue
                    if op_missing_tensor.get(in_op, 0) > 0:
                        op_missing_tensor[in_op] -= 1
                        if op_missing_tensor.get(in_op, 0) == 0:
                            ready_ops.append(in_op)
    result = []
    for i, s in enumerate(sources):
        g = gradients.get(ops.tensor_id(s), None)
        if g is None:
            # TODO(apassos): figure out a way to summarize why sources and targets are
            # not connected.
            raise ValueError(
                "There is no sequence of operations connecting source "
                "tensor %s (%s) to any of the target Tensors. This is "
                "commonly caused by the tape not recording all "
                "operations in the forward pass or if by mistake a "
                "source was only used in non-differentiable operations." %
                (i, s))
        result.append(_aggregate_grads(g))
    return result
Example #30
def _watch_value_from_tape(tensor):
  for t in tape._tape_stack.stack:  # pylint: disable=protected-access
    w = t.value.tensors.get(tf_ops.tensor_id(tensor), None)
    if w is not None:
      return w
  return tensor
Example #31
def _watch_with_tape(tape, tensor):
    """Wraps a watched Tensor and keeps track of it in the implicit tape."""
    w = _watch_with_tape_internal(tape, tensor)
    if ag_core.isnode(tape):
        tape.value.tensors[ops.tensor_id(tensor)] = w
    return w
Example #32
    def _capture_helper(self, tensor, name):
        if (tensor.graph is not self._forward_graph
                or any(tensor is t for t in self._forward_graph.inputs)
                or any(tensor is t for t in self._forward_graph.outputs)):
            return super(_CondGradFuncGraph,
                         self)._capture_helper(tensor, name)

        if control_flow_util.GraphOrParentsInXlaContext(
                ops.get_default_graph()):
            # XLA does not yet support optionals, so capture intermediates directly.
            # TODO(skyewm,jpienaar): can XLA support optionals?
            if all(tensor is not capture
                   for capture in self.external_captures):
                self.xla_intermediates.append(tensor)
                self.op_needs_rewrite = True
            return super(_CondGradFuncGraph,
                         self)._capture_helper(tensor, name)

        tensor_id = ops.tensor_id(tensor)
        captured_tensor = self._indirect_captures.get(tensor_id)
        if captured_tensor is not None:
            return captured_tensor

        # 'tensor' is an uncaptured intermediate in the forward graph.
        # If it is not a resource, we wrap it in an optional in the forward graph
        # and capture the optional normally. We then unwrap the captured optional
        # value in the gradient graph to get the raw intermediate value.
        # If it is a resource, we trace the resource up to the input in the forward
        # graph and capture that.

        if tensor.dtype == dtypes.resource:
            # Index of the forward graph input corresponding to the resource tensor.
            index = util.resource_input_index(
                tensor.name, [t.name for t in self._forward_graph.inputs], {
                    op.name: op.node_def
                    for op in self._forward_graph.get_operations()
                }, self._forward_graph._functions)
            # This gets mapped to the corresponding If op input in
            # `_resolve_grad_inputs`.
            captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
                self._forward_graph.inputs[index], name)
        else:
            if tensor_id not in self._wrapped_intermediates:
                # If the gradient has already been computed for this If op, 'tensor' may
                # already be wrapped.
                for consumer in tensor.consumers():
                    if (consumer.type == "OptionalFromValue" and any(
                            consumer.outputs[0] is output
                            for output in self._forward_graph.outputs)):
                        optional = consumer.outputs[0]
                        break
                else:
                    # 'tensor' hasn't been wrapped, do it now.
                    with self._forward_graph.as_default():
                        optional = gen_dataset_ops.optional_from_value(
                            [tensor])
                    self.op_needs_rewrite = True
                self._wrapped_intermediates[tensor_id] = optional

            optional = self._wrapped_intermediates[tensor_id]
            captured_optional = super(_CondGradFuncGraph,
                                      self)._capture_helper(optional, name)
            captured_tensor = gen_dataset_ops.optional_get_value(
                captured_optional, [tensor.dtype], [tensor.shape])[0]

        self._indirect_captures[tensor_id] = captured_tensor
        return captured_tensor
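
A hedged reading of the control flow above: a forward-graph intermediate is wrapped in an optional at most once, and the unwrapped capture is memoized per tensor id so repeated requests return the same tensor. The following is a self-contained sketch of just that memoization; plain Python objects stand in for graphs and optionals, id() stands in for ops.tensor_id, and no name here is a TensorFlow API.

# Minimal sketch (not TensorFlow): memoize an indirect capture per tensor so
# the wrap/unwrap work happens only on the first request. Wrapped and its
# .value play the roles of optional_from_value / optional_get_value.
class Wrapped:
  def __init__(self, value):
    self.value = value

class CondGradCaptures:
  def __init__(self):
    self._wrapped_intermediates = {}  # id(tensor) -> Wrapped
    self._indirect_captures = {}      # id(tensor) -> unwrapped capture

  def capture(self, tensor):
    key = id(tensor)
    captured = self._indirect_captures.get(key)
    if captured is not None:
      return captured
    if key not in self._wrapped_intermediates:
      # "Wrap" the intermediate once, as optional_from_value does above.
      self._wrapped_intermediates[key] = Wrapped(tensor)
    captured = self._wrapped_intermediates[key].value  # "unwrap"
    self._indirect_captures[key] = captured
    return captured

caps = CondGradCaptures()
t = object()
assert caps.capture(t) is caps.capture(t)  # second call comes from the memo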
Example #34
 def capture_distributed_variable(self, variable, placeholder):
   """Add given distributed variable to captures with given placeholder."""
   self._captures[ops.tensor_id(variable)] = (variable, placeholder)
   tape.record_operation("captured_value", [placeholder], [variable],
                         lambda x: [x])
Example #35
  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Within the automatic control dependency context, the default graph"
          f" cannot change. Upon entry it was {self._graph}, but on exit it"
          f" changed to {ops.get_default_graph()}")

    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call, so we keep them separate. (The
    # general idea of the maps is the same; in the future we'll need to
    # correctly thread the control output outside.)
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]
    first_use_for_res = {}
    resources_by_op = {}

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()

      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if ((op_def_registry.get(op.type) is None) or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs))) or
          (op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS)):
        ops_which_must_run.add(op)

      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)

        do_record = (
            self.record_initial_resource_uses and
            input_id not in first_use_for_res)

        if is_read:
          reads_list = reads_since_last_write_to_resource[input_id]
          reads_list.append(op)

          if do_record:
            # Note: this will track the entire list that
            # reads_since_last_write_to_resource maintains. Updates to it will
            # and should be tracked, until the first write is encountered. At
            # that point, reads_since_last_write_to_resource will contain a new
            # empty list. This logic relies on that behavior.
            first_use_for_res[input_id] = reads_list

        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

          if do_record:
            first_use_for_res[input_id] = [op]

      if self.record_initial_resource_uses and op_is_stateful(op):
        if resource_inputs:
          resources_by_op[op] = tuple(resource_inputs)
        else:
          if None not in first_use_for_res:
            first_use_for_res[None] = [op]
          resources_by_op[op] = (None,)

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Record the ops which first use resources touched by "ops which must run".
    if self.record_initial_resource_uses:
      first_uses_by_output_ops = {}
      for op in ops_which_must_run:
        if op not in resources_by_op:
          # This may happen with Merge/Switch nodes which are special cased
          # above.
          continue
        for r in resources_by_op[op]:
          if op not in first_uses_by_output_ops:
            first_uses_by_output_ops[op] = set()
          first_uses_by_output_ops[op].update(first_use_for_res[r])
      # For each "op which must run", set a private attr indicating the ops that
      # used the same resources it did.
      for op in first_uses_by_output_ops:
        others = [
            other.name.encode() for other in first_uses_by_output_ops[op]
        ]
        l = attr_value_pb2.AttrValue.ListValue(s=others)
        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op
          # to chain the control dependencies between stateful ops and
          # function outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._set_attr("_acd_function_control_output",
                                        attr_value_pb2.AttrValue(b=True))
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used
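
Stripped of the switch, merge, and collective handling above, the core serialization rule is roughly: every use of a resource gets a control edge on the last write to that resource, and a write additionally gets edges on every read since that write. The sketch below shows that rule alone over a toy op list; it is plain Python, with id() standing in for ops.tensor_id and all names illustrative.

# Minimal sketch (not TensorFlow): serialize uses of a resource. Reads depend
# on the last write; a write depends on the last write and on all reads since
# that write.
import collections

def serialize_resource_uses(ops_in_graph_order):
  """ops_in_graph_order: iterable of (op_name, resource, is_read) triples."""
  last_write = {}                                     # id(resource) -> op_name
  reads_since_write = collections.defaultdict(list)   # id(resource) -> [op_name]
  control_edges = []                                  # (runs_before, runs_after)
  for op_name, resource, is_read in ops_in_graph_order:
    key = id(resource)
    if key in last_write:
      control_edges.append((last_write[key], op_name))
    if is_read:
      reads_since_write[key].append(op_name)
    else:
      for read in reads_since_write[key]:
        control_edges.append((read, op_name))
      reads_since_write[key] = []
      last_write[key] = op_name
  return control_edges

v = object()  # stands in for a resource tensor
edges = serialize_resource_uses(
    [("assign_1", v, False), ("read_1", v, True), ("assign_2", v, False)])
assert edges == [("assign_1", "read_1"), ("assign_1", "assign_2"),
                 ("read_1", "assign_2")]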
Example #36
    def __exit__(self, unused_type, unused_value, unused_traceback):
        # pylint: disable=protected-access
        if context.executing_eagerly():
            return

        if self._graph is not ops.get_default_graph():
            raise RuntimeError(
                "Graph changed while trying to add control dependencies.")

        if hasattr(self._graph, "outer_graph"):
            outer_val = self._graph.outer_graph._add_control_dependencies
            self._graph._add_control_dependencies = outer_val
        else:
            self._graph._add_control_dependencies = False

        # map from resource tensor to the last op which wrote to it
        last_write_to_resource = {}
        # map from resource tensor to the list of reads from it since the last
        # write or since the beginning of the function.
        reads_since_last_write_to_resource = collections.defaultdict(list)
        # CollectiveManager manager_ids within a particular function call should not
        # be needed outside of that function call, so we keep them separate. (The
        # general idea of the maps is the same; in the future we'll need to
        # correctly thread the control output outside.)
        # Map from collective manager scope to the last op which used it
        collective_manager_scopes_opened = {}
        collective_manager_scopes_used = {}
        # set of conditional and loop exits
        ops_which_must_run = set()
        # merge which must depend on ops which use this resource
        merge_for_resource = {}

        new_operations = self._graph.get_operations()[self._n_operations:]

        # Ensures that uses of resource tensors get serialized properly and all
        # execute. This is done by keeping a map from resource tensor to the last op
        # in graph-construction order which used it (last_write_to_resource).
        #
        # Conditionals are written in TensorFlow such that every external tensor
        # accessed in the conditional goes through a switch op and every return
        # tensor (it's guaranteed that there will be at least one) goes through a
        # merge op.
        #
        # To handle conditionals, switches are handled in a special way (see
        # comments for _process_switch). Merge nodes created by TF's conditional
        # logic (as opposed to by _process_switch) are forced to run and also get a
        # control dependency added to them to ensure all stateful ops inside their
        # control flow context run.
        #
        # We also ensure that if an op is using a resource output by a switch node
        # (that is, a resource tensor for which there's a value in
        # merge_for_resource) this op will run before the merge for that resource.
        #
        # We try to add control inputs to nodes respecting their control flow
        # contexts to avoid dead nodes propagating everywhere and leading to
        # "retval[0] doesn't have value" errors. If a node gets a control dependency
        # on a dead node (i.e. a node from an untaken control flow branch) that node
        # will be marked as dead unless it's a merge node.
        #
        # TODO(apassos): serialize non-resource-taking stateful ops as well, and
        # test that it works. Support while loops. Support init_scope escaping from
        # this.
        for op in new_operations:
            # TODO(apassos) make this code safely support while loops.
            if control_flow_util.IsInWhileLoop(op):
                continue
            control_inputs = set()
            # Ensure stateful ops run
            if (op_def_registry.get(op.type) is None
                    or (op_is_stateful(op)
                        and op.type not in utils.RESOURCE_READ_OPS)):
                # TODO(srbs): Do not add functional ops to `ops_which_must_run` if
                # they only have variable reads and are otherwise stateless.
                ops_which_must_run.add(op)
            # Make a note of all opened manager_ids.
            if op.type == "NoOp":
                try:
                    collective_manager_scopes_opened[op.get_attr(
                        "_collective_manager_id")] = op
                except ValueError:
                    pass
            # Ignore switches (they're handled separately)
            if op.type == "Switch" and op.inputs[
                    0].dtype == dtypes_module.resource:
                continue
            # Make merges trigger all other computation which must run
            if op.type == "Merge":
                for o in ops_which_must_run:
                    op._add_control_input(o)
                    for inp in o.inputs:
                        input_id = ops.tensor_id(inp)
                        if input_id in last_write_to_resource:
                            last_write_to_resource[input_id] = op
                ops_which_must_run = set([op])
                continue

            resource_inputs = set()
            # Check for any resource inputs. If we find any, we update control_inputs
            # and last_write_to_resource.
            for inp, resource_type in _get_resource_inputs(op):
                is_read = resource_type == ResourceType.READ_ONLY
                input_id = ops.tensor_id(inp)

                # If the op receives the same resource tensor twice as an input, we skip
                # to avoid the op getting a control dependency on itself.
                if input_id in resource_inputs:
                    continue

                resource_inputs.add(input_id)
                # Deal with switches, finally.
                if inp.op.type == "Switch":
                    self._process_switch(inp.op, ops_which_must_run,
                                         last_write_to_resource,
                                         merge_for_resource)
                is_building_function = op.graph.building_function
                # Ensure uses of resources are serialized
                if input_id in last_write_to_resource:
                    if is_building_function or (
                            last_write_to_resource[input_id].
                            _control_flow_context is op._control_flow_context):
                        control_inputs.add(last_write_to_resource[input_id])
                # Ensure merges happen after the closing of a cond block
                if input_id in merge_for_resource:
                    merge_for_resource[input_id]._add_control_input(op)
                if is_read:
                    reads_since_last_write_to_resource[input_id].append(op)
                else:
                    control_inputs.update(
                        reads_since_last_write_to_resource[input_id])
                    reads_since_last_write_to_resource[input_id] = []
                    last_write_to_resource[input_id] = op

            if (op_is_stateful(op) and not resource_inputs
                    and op._control_flow_context is None):
                if None in last_write_to_resource:
                    op._add_control_input(last_write_to_resource[None])
                last_write_to_resource[None] = op

            # Ensure ordering of collective ops
            manager_ids = collective_manager_ids_from_op(op)
            for manager_id in manager_ids:
                if manager_id in collective_manager_scopes_opened:
                    # Chain this function call if the scope was opened.
                    op._add_control_input(
                        collective_manager_scopes_opened[manager_id])
                    collective_manager_scopes_opened[manager_id] = op
                else:
                    # If this op is in a scope not created here, create a chain starting
                    # at this op.
                    if manager_id in collective_manager_scopes_used:
                        op._add_control_input(
                            collective_manager_scopes_used[manager_id])
                    collective_manager_scopes_used[manager_id] = op

            if control_inputs and not is_building_function:
                control_inputs = [
                    c for c in control_inputs
                    if c._control_flow_context is op._control_flow_context
                ]

            op._add_control_inputs(control_inputs)

        # Ensure all ops which must run do run
        self.ops_which_must_run.update(ops_which_must_run)
        for r in nest.flatten(list(self._returned_tensors),
                              expand_composites=True):
            if self.ops_which_must_run:
                updated_ops_which_must_run = []
                if r.graph.building_function:
                    updated_ops_which_must_run = self.ops_which_must_run
                else:
                    updated_ops_which_must_run = [
                        o for o in self.ops_which_must_run if
                        o._control_flow_context is r.op._control_flow_context
                    ]
                r.op._add_control_inputs(updated_ops_which_must_run)

        self.collective_manager_ids_used = collective_manager_scopes_used
Example #37
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Build frozen sets so that we do not have linear time lookups in
  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
  # may get updated during gradient computation because we add accumulators to
  # the forward op. However, they are not loop invariants and so do not affect
  # the output of `_is_loop_invariant`. Also we would never attempt to capture
  # those accumulators so `_is_loop_invariant` should never receive those new
  # tensors as args.
  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)

  args = [counter, maximum_iterations, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         maximum_iterations, while_op,
                                         body_graph_inputs, body_graph_outputs))

  # Update the list of outputs with tensors corresponding to the captured
  # tensors. We capture 3 types of tensors when building the grad fn:
  # 1. Accumulators for forward graph intermediates which are not loop
  #    invariants. The outputs corresponding to these are populated in
  #    `popped_tensor_lists` by `_WhileBodyGradFuncGraph`.
  # 2. Resources, which are output as is.
  # 3. Forward graph loop invariants, which are output as is.
  for external_capture, internal_capture in grad_func_graph.captures:
    if ops.tensor_id(internal_capture) in grad_func_graph.popped_tensor_lists:
      new_output = grad_func_graph.popped_tensor_lists[ops.tensor_id(
          internal_capture)]
    elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant(
        external_capture, body_graph_inputs, body_graph_outputs)):
      new_output = internal_capture
    else:
      raise ValueError("Tensor %s which captures %s is in list of "
                       "internal_captures but is not a resource, is not in "
                       "popped_tensor_lists and does not capture a loop "
                       "invariant." %
                       (str(internal_capture), str(external_capture)))
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args
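
As a hedged restatement of the output-resolution loop above: each capture either has an accumulator output recorded under its internal capture's tensor id, or it is a resource or loop invariant that is passed through unchanged; anything else is rejected. The sketch below shows just that classification in plain Python; id() stands in for ops.tensor_id and the predicates and names are illustrative, not TensorFlow APIs.

# Minimal sketch (not TensorFlow) of the capture-resolution loop above: choose
# the accumulator output when one was recorded for the internal capture's id,
# pass resources and loop invariants through, and reject anything else.
def resolve_outputs(captures, popped_by_id, is_resource, is_loop_invariant):
  """captures: list of (external, internal) pairs; the predicates are callables."""
  outputs = []
  for external, internal in captures:
    if id(internal) in popped_by_id:
      outputs.append(popped_by_id[id(internal)])
    elif is_resource(internal) or is_loop_invariant(external):
      outputs.append(internal)
    else:
      raise ValueError("capture %r is neither accumulated, a resource, nor "
                       "a loop invariant" % (internal,))
  return outputs

ext, internal, acc = object(), object(), object()
outs = resolve_outputs([(ext, internal)], {id(internal): acc},
                       is_resource=lambda t: False,
                       is_loop_invariant=lambda t: False)
assert outs == [acc]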
Example #38
 def reset_captures(self, capture_list):
   """Set the captures with the provided list of captures & placeholder."""
   self._captures = py_collections.OrderedDict()
   for tensor, placeholder in capture_list:
     self._captures[ops.tensor_id(tensor)] = (tensor, placeholder)
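
The capture map itself is the simplest of these structures: an ordered dict from tensor identity to an (external tensor, placeholder) pair. Below is a minimal, TensorFlow-free sketch of that shape, with id() in place of ops.tensor_id; the class and method names are illustrative only.

# Minimal sketch (not TensorFlow): a capture map keyed by identity, which
# sidesteps any custom hashing or equality on the stored objects.
import collections

class Captures:
  def __init__(self):
    self._captures = collections.OrderedDict()  # id -> (external, placeholder)

  def reset_captures(self, capture_list):
    self._captures = collections.OrderedDict()
    for tensor, placeholder in capture_list:
      self._captures[id(tensor)] = (tensor, placeholder)

  def placeholder_for(self, tensor):
    return self._captures[id(tensor)][1]

external, placeholder = object(), object()
c = Captures()
c.reset_captures([(external, placeholder)])
assert c.placeholder_for(external) is placeholder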
Example #39
 def mut_add(implicit_tape):
     resource_variable = tape.value.variables[ops.tensor_id(tensor)]
     implicit_tape.gradients.append((g, resource_variable))
     return implicit_tape
Example #41
def any_tape_has(tensor):
  for t in _tape_stack.stack:
    if ops.tensor_id(tensor) in t.value.tensors:
      return True
  return False
Example #42
  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # Do not accumulate loop invariants.
    if (any(tensor is t for t in self._forward_graph.inputs) and
        any(tensor is t for t in self._forward_graph.outputs)):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `popped_tensor_lists` so that this gets added to the list of
      # outputs.
      # TODO(srbs): Rename popped_tensor_lists.
      self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # No need to accumulate loop invariants. Capture them directly.
    # The captured tensor gets resolved to the corresponding while output in
    # `_resolve_grad_captures`.
    if _is_loop_invariant(tensor, self._forward_graph_inputs,
                          self._forward_graph_outputs):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      return captured_tensor

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.popped_tensor_lists[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor
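
The while-loop variant above replaces optionals with accumulators: the forward body pushes each needed intermediate onto a per-tensor list, and the gradient body pops it back off, memoizing the popped value per tensor id. The toy sketch below shows that push/pop-with-memo shape in plain Python; lists stand in for the TensorList ops, id() stands in for ops.tensor_id, and all names are illustrative.

# Minimal sketch (not TensorFlow): accumulate a forward intermediate by pushing
# it onto a per-tensor list, then pop it in the "gradient" phase exactly once.
class WhileGradCaptures:
  def __init__(self):
    self._accumulators = {}       # id(tensor) -> list of pushed values
    self._indirect_captures = {}  # id(tensor) -> popped value

  def forward_push(self, tensor, value):
    # tensor_list_push_back analogue: one accumulator per forward intermediate.
    self._accumulators.setdefault(id(tensor), []).append(value)

  def backward_capture(self, tensor):
    key = id(tensor)
    if key in self._indirect_captures:
      return self._indirect_captures[key]
    value = self._accumulators[key].pop()  # tensor_list_pop_back analogue
    self._indirect_captures[key] = value
    return value

caps = WhileGradCaptures()
x = object()
caps.forward_push(x, "value_from_last_iteration")
assert caps.backward_capture(x) == "value_from_last_iteration"
assert caps.backward_capture(x) == "value_from_last_iteration"  # memoized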
Example #45
def imperative_grad(
    target,
    sources,
    output_gradients=None):
  """Computes gradients from the imperatively defined tape on top of the stack.

  Works by filtering the tape, computing how many downstream usages each
  tensor and entry has, and repeatedly applying backward functions until we
  have gradients for all sources.

  Args:
   target: either a Tensor or list of Tensors to be differentiated.
   sources: list of Tensors for which we want gradients
   output_gradients: if not None, a list of gradients, one for each target;
    if None, the target's computed downstream gradient is used.

  Returns:
   the gradient wrt each of the sources.

  Raises:
    RuntimeError: if something goes wrong.
    ValueError: if there is no sequence of differentiable operations connecting
     a source and any target Tensor. This can happen either if the target is
     not computed based on the source, if the tracing was set up incorrectly,
     or if only non-differentiable functions of the source were used in the
     computation of target.
  """
  if not tape._tape_stack.stack:  # pylint: disable=protected-access
    raise RuntimeError("Computing a gradient with no tape present")
  bp_tape = tape.pop_tape()
  tensor_to_op, op_to_entry = bp_tape.export()
  # This overwrites the op_to_entry variable, which will release all memory used
  # to keep traces that are irrelevant to the gradient computation we're doing
  # here.
  id_sources = [ops.tensor_id(t) for t in sources]
  tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
      target, tensor_to_op, op_to_entry, id_sources)
  ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
  gradients = _initial_gradients(target, output_gradients,
                                 tensor_usage_counts)
  gradients_size = dict()
  # Now exhaust the backprop stack
  while ready_ops:
    op = ready_ops.pop()
    op_trace = op_to_entry.pop(op)
    out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
    for i in range(len(out_gradients)):
      if out_gradients[i] is None:
        # TODO(apassos) this should be in the right device
        none_indices = _grad_fn_accepts_none_for_indices.get(
            op_trace.op_type, None)
        if none_indices is None or i not in none_indices:
          out_gradients[i] = array_ops.zeros(
              *op_trace.output_shape_and_dtype[i])
      else:
        out_gradients[i] = _aggregate_grads(out_gradients[i])

    in_gradients = op_trace.backward_function(
        *(out_gradients + op_trace.side_outputs))
    in_gradients = ([in_gradients]
                    if isinstance(in_gradients, (ops.Tensor,
                                                 ops.IndexedSlices,
                                                 type(None)))
                    else in_gradients)
    for i, t in enumerate(op_trace.input_ids):
      if in_gradients[i] is not None:
        _add_new_grads(gradients, gradients_size, t, in_gradients[i])
      if tensor_usage_counts.get(t, 0) > 0:
        tensor_usage_counts[t] -= 1
        if (t in tensor_to_op
            and tensor_usage_counts[t] == 0
            and t not in id_sources):
          in_op = tensor_to_op[t]
          if in_op is None:
            continue
          if op_missing_tensor.get(in_op, 0) > 0:
            op_missing_tensor[in_op] -= 1
            if op_missing_tensor.get(in_op, 0) == 0:
              ready_ops.append(in_op)
  result = []
  for i, s in enumerate(sources):
    g = gradients.get(ops.tensor_id(s), None)
    if g is None:
      # TODO(apassos): figure out a way to summarize why sources and targets are
      # not connected.
      raise ValueError("There is no sequence of operations connecting source "
                       "tensor %s (%s) to any of the target Tensors. This is "
                       "commonly caused by the tape not recording all "
                       "operations in the forward pass or if by mistake a "
                       "source was only used in non-differentiable operations."
                       % (i, s))
    result.append(_aggregate_grads(g))
  return result
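
The last few lines of imperative_grad reduce to a dictionary lookup keyed by source tensor identity, followed by aggregation of whatever gradients accumulated there and an error when nothing did. The sketch below shows only that final step in plain Python; id() stands in for ops.tensor_id, sum() stands in for _aggregate_grads, and the names are illustrative.

# Minimal sketch (not TensorFlow): look up each source's accumulated gradients
# by identity, aggregate them, and raise if backpropagation never reached it.
def collect_source_grads(gradients_by_id, sources):
  result = []
  for i, source in enumerate(sources):
    grads = gradients_by_id.get(id(source), None)
    if grads is None:
      raise ValueError("No sequence of differentiable operations connects "
                       "source %d (%r) to any target." % (i, source))
    result.append(sum(grads))  # stands in for _aggregate_grads
  return result

x, y = object(), object()
assert collect_source_grads({id(x): [1.0, 2.0], id(y): [0.5]}, [x, y]) == [3.0, 0.5]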