Example #1
  def initialize(self, table):
    """Initializes the table from a text file.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init",
                        (table.table_ref,)) as scope:
      filename = ops.convert_to_tensor(
          self._filename, dtypes.string, name="asset_filepath")
      # pylint: disable=protected-access
      init_op = gen_lookup_ops._initialize_table_from_text_file_v2(
          table.table_ref,
          filename,
          self._key_index,
          self._value_index,
          -1 if self._vocab_size is None else self._vocab_size,
          self._delimiter,
          name=scope)
      # pylint: enable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    # If the filename tensor is anything other than a string constant (e.g., if
    # it is a placeholder) then it does not make sense to track it as an asset.
    if context.in_graph_mode() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op
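
The comment above hinges on `constant_op.is_constant` telling graph-building-time constants apart from tensors whose value is only known at run time. A minimal sketch, not from the original source, assuming a graph-mode TF build where the private `constant_op.is_constant` helper is available ("vocab.txt" is just an illustrative filename):

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import constant_op

tf.disable_eager_execution()
with tf.Graph().as_default():
  const_filename = tf.constant("vocab.txt", dtype=tf.string)   # a Const node
  fed_filename = tf.placeholder(tf.string, shape=[])           # value supplied at run time

  assert constant_op.is_constant(const_filename)       # would be tracked as an asset above
  assert not constant_op.is_constant(fed_filename)     # skipped: not a string constant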
Example #2
File: tfutil.py  Project: shawwn/tfimg
def to_np(value, eager=False, session=None, deep=True):
  if isinstance(value, tf.TensorShape):
    value = value.as_list()
  if isinstance(value, tf.Dimension):
    value = value.value
  if torch:
    if isinstance(value, torch.Size):
      value = list(value)
  #if is_tf(value, strict=True) and not pyobj(value) and constant_op.is_constant(value):
  if hasattr(value, 'type') and constant_op.is_constant(value):
    value = tensor_util.constant_value(value)
  if deep and pylist(value):
    value = [to_np(x, eager=eager, session=session, deep=deep) for x in value]
  if deep and pydict(value):
    value = {k: to_np(v, eager=eager, session=session, deep=deep) for k, v in value.items()}
  if is_tf_variable(value):
    if not eager:
      raise ValueError("to_np called on tensorflow variable {} but eager is False".format(value))
    # TODO: batch multiple nested reads.
    result = value.eval(session=session)
    return to_np(result, eager=eager, session=session, deep=deep)
  if is_torch_tensor(value):
    if not eager:
      raise ValueError("to_np called on torch tensor {} but eager is False".format(value))
    result = value.numpy()
    return to_np(result, eager=eager, session=session, deep=deep)
  assert is_np(value)
  return value
Example #3
    def GetRealValue(self, value):
        """Get the real value of `value`.

        If backprop "uses" a value produced by forward inference, an accumulator
        is added in the forward loop to accumulate its values.  We use the
        accumulated value. This method must be called in the grad loop context.
        `value` must be in forward and needed for backprop.

        Args:
          value: A tensor to be captured.

        Returns:
          The same tensor obtained from the saved history.
        """
        assert value.op.type not in ["Variable", "VariableV2"]
        real_value = self._history_map.get(value.name)
        if real_value is None:
            cur_value = value
            cur_grad_state = self
            while True:
                enter_op = util.GetLoopConstantEnter(cur_value)
                if enter_op:
                    # Special case: cur_value comes from a constant Enter node.
                    cur_value = enter_op.inputs[0]
                    cur_grad_state = cur_grad_state.outer_grad_state
                    if cur_grad_state is None:
                        # We are now outside all nested loops for this gradient(),
                        # so `value` is a loop invariant and there is no need to
                        # save the history of value. Just make cur_value to enter
                        # the right control flow context.
                        real_value = self._grad_context.AddValue(cur_value)
                        break
                elif constant_op.is_constant(cur_value):
                    # If the value to be forwarded is a constant, clone the constant in
                    # the gradient loop rather than using a stack.
                    # TODO(phawkins): consider hoisting the constant out of the loop
                    # instead.
                    real_value = constant_op.constant(
                        tensor_util.constant_value(cur_value),
                        dtype=cur_value.dtype)
                    break
                else:
                    # Record the history of this value in forward_ctxt.
                    self._grad_context.Exit()
                    history_value = cur_grad_state.AddForwardAccumulator(
                        cur_value)
                    self._grad_context.Enter()
                    break

            if real_value is None:
                # Add the stack pop op in the grad context.
                real_value = cur_grad_state.AddBackpropAccumulatedValue(
                    history_value, cur_value)
                if cur_grad_state != self:
                    real_value = self._grad_context.AddValue(real_value)
            self._history_map[value.name] = real_value
        return real_value
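
For orientation, a small TF 1.x graph-mode sketch, not from the original source, of the situation `GetRealValue` serves: differentiating through a v1 `tf.while_loop` needs forward values, which are either cloned (constants), re-entered (loop invariants), or read back from an accumulator:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
x = tf.constant(2.0)
_, v = tf.while_loop(lambda i, v: i < 3,
                     lambda i, v: (i + 1, v * x),    # x is loop-invariant, v * x an intermediate
                     [tf.constant(0), tf.constant(1.0)])
(dx,) = tf.gradients(v, x)                           # backprop through the loop
with tf.Session() as sess:
  print(sess.run(dx))                                # 12.0, i.e. d(x**3)/dx at x = 2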
Example #4
  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # Do not accumulate loop invariants.
    if (any(tensor is t for t in self._forward_graph.inputs) and
        any(tensor is t for t in self._forward_graph.outputs)):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `popped_tensor_lists` so that this gets added to the list of
      # outputs.
      # TODO(srbs): Rename popped_tensor_lists.
      self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # No need to accumulate loop invariants. Capture them directly.
    # The captured tensor gets resolved to the corresponding while output in
    # `_resolve_grad_captures`.
    if _is_loop_invariant(tensor, self._forward_graph_inputs,
                          self._forward_graph_outputs):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      return captured_tensor

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the  forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.popped_tensor_lists[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor
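
A small TF 2.x sketch, not from the original source, of when this capture helper runs: taking the gradient of a functional `tf.while_loop` (inside `tf.function`) builds a backward body graph that captures forward constants, loop invariants, and accumulated intermediates as described above:

import tensorflow as tf

@tf.function
def cube_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    _, v = tf.while_loop(lambda i, v: i < 3,
                         lambda i, v: (i + 1, v * x),   # x: loop invariant; v * x: intermediate
                         [tf.constant(0), tf.constant(1.0)])
  return tape.gradient(v, x)

print(cube_grad(tf.constant(2.0)).numpy())   # 12.0 = 3 * 2.0**2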
Example #5
def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
  """Gradient for concat op.

  Args:
    op: An operation.
    grad: `Tensor` or `IndexedSlices` representing the gradients with respect
      to each output of the op.
    start_value_index: An integer index of the first value in the op.inputs.
    end_value_index: An integer index of the last value in the op.inputs.
    dim_index: An integer index of concat_dim or axis parameter in op.inputs.

  Returns:
    Tensors representing the partial gradients with respect to each input
    of the op.

  Raises:
    ValueError: if concat_dim/axis is not statically known.
  """

  def _CreateDenseMaskAndBegin(sizes, concat_dim):
    """Create variables for iteratively slicing a dense gradients tensor."""
    # Since shape is 1-D, shape_of_shape = [rank-of-inputs]
    shape_of_shape = array_ops.shape(sizes[0])
    # Make a vector of length equal to the input's dimensions,
    # with 0's everywhere and 1 in the concat dim position.
    # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
    mask = array_ops.concat([
        array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
        array_ops.fill(shape_of_shape - concat_dim - 1, 0)
    ], 0)
    begin = array_ops.fill(shape_of_shape, 0)
    return mask, begin

  def _ExtractInputShapes(inputs):
    """Extract the shapes of a set of input tensors."""
    if context.executing_eagerly():
      return array_ops.shape_n(inputs)
    sizes = []
    fully_known = True
    for x in inputs:
      input_shape = array_ops.shape(x)
      if not isinstance(input_shape,
                        ops.Tensor) or input_shape.op.type != "Const":
        fully_known = False
        break
      sizes.append(input_shape)

    if fully_known:
      return sizes
    else:
      return array_ops.shape_n(inputs)

  # Degenerate concatenation, just return grad.
  if len(op.inputs) == 2:
    return grad + [None] if end_value_index <= dim_index else [None] + grad

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index:end_value_index]

  out_grads = []
  if isinstance(grad, ops.Tensor):
    if context.executing_eagerly():
      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = (
          concat_dim._numpy().item(0) % input_values[0]._rank())  # pylint: disable=protected-access
      # All inputs are guaranteed to be EagerTensors in eager mode
      sizes = pywrap_tensorflow.TFE_Py_TensorShapeSlice(input_values,
                                                        non_neg_concat_dim)
      out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
    else:
      if constant_op.is_constant(concat_dim):
        # If concat_dim is a constant defined in a different context,
        # then we duplicate it in the current context to avoid passing it
        # through an Enter node.
        # This is a small optimization in general, but it is required when
        # compiling with XLA, as XLA needs the concat input to be folded into a
        # constant.
        grad_context = control_flow_util.GetOutputContext(grad.op)
        dim_context = control_flow_util.GetOutputContext(concat_dim.op)
        if dim_context != grad_context:
          value = tensor_util.constant_value(concat_dim)
          concat_dim = constant_op.constant(value=value, dtype=concat_dim.dtype)

      # Using mod here for convenience since concat_dim is already verified
      # in concat implementation to be within the allowed [-rank, rank) range.
      non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])

      # Get the inputs' tensor shapes
      sizes = _ExtractInputShapes(input_values)
      # The magic number of 16 was found through benchmarking a range of sizes
      # on CPUs and a Maxwell TitanX.  A speedup was seen in a large majority of
      # cases when switching implementations at N=16, but it is possible that
      # there will be a small number of performance regressions.
      if len(sizes) > 16:
        # extract the size of each input along the concat dimension
        sizes = array_ops.squeeze(
            array_ops.slice(
                array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
                [1, -1]))
        out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
      else:
        offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes)
        for (begin, size) in zip(offset, sizes):
          out_grads.append(array_ops.slice(grad, begin, size))
  elif isinstance(grad, ops.IndexedSlices):
    # Using mod here for convenience since concat_dim is already verified
    # in concat implementation to be within the allowed [-rank, rank) range.
    non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0])
    concat_dim_static = tensor_util.constant_value(concat_dim)
    if concat_dim_static is None:
      raise ValueError("Can only compute IndexedSlices gradient with "
                       "statically-known concat_dim")
    if concat_dim_static < 0:
      rank = tensor_util.constant_value(array_ops.rank(input_values[0]))
      if rank is None:
        raise ValueError("Can only compute IndexedSlices gradient with "
                         "negative concat_dim when first value rank is "
                         "statically-known.")
      concat_dim_static %= rank
    # Get the inputs' tensor shapes
    sizes = [array_ops.shape(x) for x in input_values]
    if concat_dim_static > 0:
      # IndexedSlices, non_neg_concat_dim > 0. Each input gets IndexedSlices
      # gradients with all the indices, but with grad.values sliced accordingly.
      # This is like the Tensor case, except shape(grad.values)[0] is not equal
      # to shape(sizes[i])[0], since only a subset of the dim-0 values are
      # stored.
      mask, begin = _CreateDenseMaskAndBegin(sizes, non_neg_concat_dim)
      for size in sizes:
        new_values = array_ops.slice(
            grad.values, begin,
            array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
        out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
        # Lint complains begin = begin + ...
        begin = math_ops.add(begin, size * mask)
    else:
      # IndexedSlices, concat_dim == 0. Each input gets IndexedSlices gradients
      # only for the relevant indices.
      start = constant_op.constant(0, dtype=grad.indices.dtype)
      for size in sizes:
        size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
        if size_concat_dim.dtype != grad.indices.dtype:
          size_concat_dim = math_ops.cast(
              size_concat_dim, dtype=grad.indices.dtype)
        end = start + size_concat_dim
        # Compute the 1-D Tensor of indices relevant for this input.
        indices_to_select = array_ops.squeeze(
            array_ops.where(
                math_ops.logical_and(grad.indices >= start,
                                     grad.indices < end)),
            axis=[1])
        new_indices = array_ops.gather(grad.indices, indices_to_select) - start
        new_values = array_ops.gather(grad.values, indices_to_select)
        out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
        start = end
  else:
    raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))

  return (out_grads + [None]
          if end_value_index <= dim_index else [None] + out_grads)
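
For context, a minimal graph-mode sketch, not from the original source, of what this helper computes: the upstream gradient of a `tf.concat` is split back into one slice per input along the concat axis:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
a = tf.ones([2, 3])
b = tf.ones([2, 5])
y = tf.concat([a, b], axis=1)                      # shape [2, 8]
da, db = tf.gradients(tf.reduce_sum(y), [a, b])    # concat's registered grad splits along axis 1
with tf.Session() as sess:
  ga, gb = sess.run([da, db])
print(ga.shape, gb.shape)                          # (2, 3) (2, 5)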
Example #6
  def checker_dis_fn(inputs, _):
    """Discriminator that checks that it only sees pooled Tensors."""
    self.assertFalse(constant_op.is_constant(inputs))
    return inputs
Example #7
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        any(tensor is t for t in self._forward_graph.inputs) or
        any(tensor is t for t in self._forward_graph.outputs)):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    tensor_id = ops.tensor_id(tensor)

    # If `tensor` is a graph-building time constant, we create a constant with
    # the same value in the backward graph instead of capturing it.
    if tensor_id in self._captured_constants:
      return self._captured_constants[tensor_id]
    elif constant_op.is_constant(tensor):
      self._captured_constants[tensor_id] = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      return self._captured_constants[tensor_id]

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if all(tensor is not capture for capture in self.external_captures):
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor_id)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor_id not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              any(consumer.outputs[0] is output
                  for output in self._forward_graph.outputs)):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor_id] = optional

      optional = self._wrapped_intermediates[tensor_id]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor_id] = captured_tensor
    return captured_tensor
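
A small TF 2.x sketch, not from the original source, of when this runs: differentiating a functional `tf.cond` (inside `tf.function`) builds a gradient graph that captures forward intermediates, wrapping non-constant ones in optionals as described above:

import tensorflow as tf

@tf.function
def branch_grad(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.cond(x > 0, lambda: x * x, lambda: -x)
  return tape.gradient(y, x)

print(branch_grad(tf.constant(3.0)).numpy())   # 6.0: gradient of the taken (x * x) branch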
Example #8
def is_tf_constant(value):
    try:
        return constant_op.is_constant(value)
    except AttributeError:
        # TODO: is there a cleaner way to accomplish this?
        return False
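
A hypothetical usage sketch, not from the original source. It assumes a graph-mode TF build where `constant_op.is_constant` exists; the `AttributeError` guard is then presumed to make the wrapper safe to call on arbitrary Python values, as well as on TF builds that have dropped the private helper:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
with tf.Graph().as_default():
  print(is_tf_constant(tf.constant(1.0)))              # True: graph-building-time Const node
  print(is_tf_constant(tf.placeholder(tf.float32)))    # False: value only known when fed
  print(is_tf_constant(np.zeros(3)))                   # False: assumed to hit the AttributeError path
  print(is_tf_constant("hello"))                       # False: plain Python value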