Example 1
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.if_op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.if_op_needs_rewrite = True
        self._wrapped_intermediates[tensor] = optional

      optional = self._wrapped_intermediates[tensor]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor
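
The resource branch above comes into play when a branch reads a tf.Variable: instead of wrapping the resource in an optional, the gradient graph captures the corresponding forward-graph input. A minimal sketch of that situation, assuming TF 2.x with tf.function (the variable and function names are illustrative):

import tensorflow as tf

v = tf.Variable(2.0)

@tf.function
def cond_with_resource(x):
  # The true branch reads the variable's resource handle; differentiating
  # through the resulting If op routes that capture back to the op's input.
  return tf.cond(x > 0.0, lambda: v * x, lambda: tf.zeros_like(x))

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = cond_with_resource(x)
print(tape.gradient(y, x))  # 2.0, i.e. d(v * x)/dx = v when the true branch runs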
Example 2
    def _capture_helper(self, tensor, name):
        if (tensor.graph is not self._forward_graph
                or tensor in self._forward_graph.inputs
                or tensor in self._forward_graph.outputs):
            return super(_CondGradFuncGraph,
                         self)._capture_helper(tensor, name)

        if control_flow_util.InXlaContext(ops.get_default_graph()):
            # XLA does not yet support optionals, so capture intermediates directly.
            # TODO(skyewm,jpienaar): can XLA support optionals?
            if tensor not in self.captures:
                self.xla_intermediates.append(tensor)
                self.if_op_needs_rewrite = True
            return super(_CondGradFuncGraph,
                         self)._capture_helper(tensor, name)

        captured_tensor = self._indirect_captures.get(tensor)
        if captured_tensor is not None:
            return captured_tensor

        # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
        # an optional in the forward graph and capture the optional normally. We
        # then unwrap the captured optional value in the gradient graph to get the
        # raw intermediate value.

        if tensor not in self._wrapped_intermediates:
            # If the gradient has already been computed for this If op, 'tensor' may
            # already be wrapped.
            for consumer in tensor.consumers():
                if (consumer.type == "OptionalFromValue" and
                        consumer.outputs[0] in self._forward_graph.outputs):
                    optional = consumer.outputs[0]
                    break
            else:
                # 'tensor' hasn't been wrapped, do it now.
                with self._forward_graph.as_default():
                    optional = gen_dataset_ops.optional_from_value([tensor])
                self.if_op_needs_rewrite = True

            self._wrapped_intermediates[tensor] = optional

        optional = self._wrapped_intermediates[tensor]
        captured_optional = super(_CondGradFuncGraph,
                                  self)._capture_helper(optional, name)
        captured_tensor = gen_dataset_ops.optional_get_value(
            captured_optional, [tensor.dtype], [tensor.shape])[0]
        self._indirect_captures[tensor] = captured_tensor
        return captured_tensor
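
The wrap/unwrap pair used above (optional_from_value / optional_get_value) has a public counterpart in tf.experimental.Optional. A minimal eager-mode sketch of the same pattern, assuming TF 2.3 or later (the variable names are illustrative):

import tensorflow as tf

intermediate = tf.constant([1.0, 2.0, 3.0])

# Forward side: wrap the intermediate so a branch can emit it even when the
# other branch produces no corresponding value.
wrapped = tf.experimental.Optional.from_value(intermediate)
none_optional = tf.experimental.Optional.empty(wrapped.element_spec)

# Gradient side: unwrap only when a value is present (i.e. the branch ran).
if wrapped.has_value():
  print(wrapped.get_value().numpy())      # [1. 2. 3.]
print(none_optional.has_value().numpy())  # False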
Example 3
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of an If op produced by cond_v2."""
    # Get the if operator (this logic handles the case where op is a MockOp)
    if_op = op.outputs[0].op
    true_graph, false_graph = _get_func_graphs(if_op)
    # Note: op.graph != ops.get_default_graph() when we are computing the gradient
    # of a nested cond.
    assert true_graph.outer_graph == if_op.graph
    assert false_graph.outer_graph == if_op.graph

    # Create grad functions that compute the gradient of the true/false forward
    # graphs. These functions will capture tensors from the forward pass
    # functions.
    true_grad_graph = _create_grad_func(
        true_graph, grads, util.unique_grad_fn_name(true_graph.name))
    false_grad_graph = _create_grad_func(
        false_graph, grads, util.unique_grad_fn_name(false_graph.name))

    assert ([t.dtype for t in true_grad_graph.outputs] ==
            [t.dtype for t in false_grad_graph.outputs])

    if (true_grad_graph.if_op_needs_rewrite
            or false_grad_graph.if_op_needs_rewrite):
        # Modify 'op' to output the intermediates needed by the grad functions. Note
        # that all needed intermediates are wrapped in optionals. Each optional
        # intermediate output will have a value iff its corresponding branch is
        # taken.
        # NOTE(skyewm): if there are any active sessions, this modification to `op`
        # may make them unrunnable!

        if control_flow_util.InXlaContext(ops.get_default_graph()):
            # XLA does not yet support optionals, so output intermediates directly and
            # make them match via FakeParams, which can be converted to zeros in XLA.
            # TODO(skyewm,jpienaar): can XLA support optionals?
            true_intermediates = true_grad_graph.xla_intermediates
            false_intermediates = false_grad_graph.xla_intermediates
            extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
                true_graph, false_graph, true_intermediates,
                false_intermediates)
        else:
            true_intermediates = true_grad_graph.wrapped_intermediates
            false_intermediates = false_grad_graph.wrapped_intermediates
            # Make outputs match by adding none optionals.
            extra_true_outputs, extra_false_outputs = _make_intermediates_match(
                true_graph, false_graph, true_intermediates,
                false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)
        # TODO(skyewm): indicate it's an internal bug if this fails.
        _check_same_outputs(true_graph, false_graph)

        true_graph.name += "_rewritten"
        false_graph.name += "_rewritten"

        if_op._set_func_attr("then_branch",
                             util.create_new_tf_function(true_graph))
        if_op._set_func_attr("else_branch",
                             util.create_new_tf_function(false_graph))
        if_op._set_type_list_attr("Tout", true_graph.output_types)
        if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
        if_op._add_outputs([t.dtype for t in extra_true_outputs],
                           [t.shape for t in extra_true_outputs])

    # Resolve references to forward graph tensors in grad graphs and ensure
    # they are in-scope, i.e., belong to one of the outer graphs of the grad graph.
    true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
    false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

    outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
                          true_grad_inputs, false_grad_inputs)

    # The predicate has no gradient.
    return [None] + outputs
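
To see this gradient from user code: inside a tf.function, tf.cond lowers to an If/StatelessIf op, and differentiating through it invokes the gradient above, which returns None for the predicate. A minimal sketch, assuming TF 2.x (the function name is illustrative):

import tensorflow as tf

@tf.function
def grad_of_cond(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.cond(x > 0.0, lambda: tf.square(x), lambda: -x)
  return tape.gradient(y, x)

print(grad_of_cond(tf.constant(3.0)))   # 6.0: true branch, d/dx x**2 = 2x
print(grad_of_cond(tf.constant(-3.0)))  # -1.0: false branch, d/dx (-x)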
Example 4
def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a Case op produced by tf.switch_case."""
    # Get the Case operator (this logic handles the case where op is a MockOp)
    case_op = op.outputs[0].op
    branch_graphs = get_func_graphs(case_op)
    assert branch_graphs
    # Note: op.graph != ops.get_default_graph() when we are computing the gradient
    # of a nested cond.
    for branch_graph in branch_graphs:
        assert branch_graph.outer_graph == case_op.graph

    # Create grad functions that compute the gradient of the branch forward
    # graphs. These functions will capture tensors from the forward pass
    # functions.
    branch_grad_graphs = []
    for branch_graph in branch_graphs:
        branch_grad_graphs.append(
            _create_grad_func(branch_graph, grads,
                              util.unique_grad_fn_name(branch_graph.name)))

    if any(g.op_needs_rewrite for g in branch_grad_graphs):
        # Modify 'op' to output the intermediates needed by the grad functions. Note
        # that all needed intermediates are wrapped in optionals. Each optional
        # intermediate output will have a value iff its corresponding branch is
        # taken.
        # NOTE(bjp): if there are any active sessions, this modification to `op`
        # may make them unrunnable!

        if control_flow_util.InXlaContext(ops.get_default_graph()):
            # XLA does not yet support optionals, so output intermediates directly and
            # make them match via FakeParams, which can be converted to zeros in XLA.
            # TODO(bjp,jpienaar): can XLA support optionals?
            branches_intermediates = [
                branch_grad_graph.xla_intermediates
                for branch_grad_graph in branch_grad_graphs
            ]
            extra_branch_outputs = _make_intermediates_match_xla(
                branch_graphs, branches_intermediates)
        else:
            branch_intermediates = [
                g.wrapped_intermediates for g in branch_grad_graphs
            ]
            # Make outputs match by adding none optionals.
            extra_branch_outputs = _make_intermediates_match(
                branch_graphs, branch_intermediates)

        for branch_graph, extra_outputs in zip(branch_graphs,
                                               extra_branch_outputs):
            branch_graph.outputs.extend(extra_outputs)
        # TODO(bjp): indicate it's an internal bug if this fails.
        _check_same_outputs(_CASE, branch_graphs)

        for branch_graph in branch_graphs:
            branch_graph.name += "_rewritten"

        case_op._set_func_list_attr("branches", [
            util.create_new_tf_function(branch_graph)
            for branch_graph in branch_graphs
        ])
        case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
        case_op._set_shape_list_attr("output_shapes",
                                     branch_graphs[0].output_shapes)
        case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
                             [t.shape for t in extra_branch_outputs[0]])

    # Resolve references to forward graph tensors in grad graphs and ensure
    # they are in-scope, i.e., belong to one of the outer graphs of the grad graph.
    branches_grad_inputs = [
        _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
        branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
    ]

    # This modifies the graphs in branch_grad_graphs.
    _make_output_composite_tensors_match(_CASE, branch_grad_graphs)

    outputs = _build_case(case_op.inputs[0],
                          branch_grad_graphs,
                          branches_grad_inputs,
                          name="gradient")

    # The predicate has no gradient.
    return [None] + outputs
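

# A minimal sketch of the user-facing path into _CaseGrad, assuming TF 2.x
# (the function name below is illustrative): inside a tf.function,
# tf.switch_case builds a Case op, and differentiating through it invokes the
# gradient above, which returns None for the branch index.
#
#   import tensorflow as tf
#
#   @tf.function
#   def grad_of_case(branch_index, x):
#       with tf.GradientTape() as tape:
#           tape.watch(x)
#           y = tf.switch_case(
#               branch_index,
#               branch_fns=[lambda: 2.0 * x, lambda: tf.square(x)])
#       return tape.gradient(y, x)
#
#   grad_of_case(tf.constant(0), tf.constant(3.0))  # 2.0 (branch 0)
#   grad_of_case(tf.constant(1), tf.constant(3.0))  # 6.0 (branch 1: 2x at x=3)

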
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

    Note that this modifies true_graph and false_graph to make the inputs match,
    and to output all intermediate values so they're available for the gradient
    computation.

    true_graph and false_graph need not have the same input types, but they must
    have the same output types.

    Args:
      pred: boolean Tensor
      true_graph: FuncGraph
      false_graph: FuncGraph
      true_inputs: a list of Tensors to be passed to true_graph as input.
      false_inputs: a list of Tensors to be passed to false_graph as input.
      name: the name for the If op.

    Returns:
      A list of Tensors which are the outputs of the If op. Does not include
      added intermediate outputs.
    """
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs,
                                     false_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation. Since the outputs of the two functions must match,
    # we wrap all the intermediates in optionals. Each intermediate output will
    # have a value iff its corresponding branch is taken.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    if control_flow_util.InXlaContext(ops.get_default_graph()):
        # XLA does not yet support optionals, so output intermediates directly and
        # make them match via FakeParams, which can be converted to zeros in XLA.
        # TODO(skyewm,jpienaar): can XLA support optionals?
        extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
            true_graph, false_graph, true_intermediates, false_intermediates)
    else:
        # Wrap intermediates in optionals.
        wrapped_true_intermediates = _wrap_intermediates(
            true_graph, true_intermediates)
        wrapped_false_intermediates = _wrap_intermediates(
            false_graph, false_intermediates)

        # Make outputs match by adding none optionals.
        extra_true_outputs, extra_false_outputs = _make_intermediates_match(
            true_graph, false_graph, wrapped_true_intermediates,
            wrapped_false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): somehow indicate it's a bug if this fails.
    _check_same_outputs(true_graph, false_graph)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)

    return tensors[:num_cond_outputs]
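
As a rough way to see what _build_cond produces, one can trace a tf.function containing a tf.cond and list the relevant ops in its graph. A minimal sketch, assuming TF 2.x (the exact op types and names printed may vary across versions):

import tensorflow as tf

@tf.function
def f(x):
  return tf.cond(x > 0.0, lambda: x + 1.0, lambda: x - 1.0)

concrete = f.get_concrete_function(tf.TensorSpec([], tf.float32))
for op in concrete.graph.get_operations():
  # The cond appears as an If/StatelessIf op; its results are returned through
  # Identity ops, matching the identities added at the end of _build_cond.
  if op.type in ("If", "StatelessIf", "Identity"):
    print(op.type, op.name)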