Example #1
def If(cond, inputs, then_branch, else_branch, name=None):
    r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the
      scalar is a numerical value, non-zero means True and zero means
      False; if the scalar is a string, non-empty means True and empty
      means False.
    inputs: A list of input tensors.
    then_branch: A function takes 'inputs' and returns a list of tensors,
        whose types are the same as what else_branch returns.
    else_branch: A function takes 'inputs' and returns a list of tensors.
        whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
    # pylint: disable=protected-access
    return gen_functional_ops._if(
        cond,
        inputs, [_.type for _ in then_branch.definition.signature.output_arg],
        then_branch,
        else_branch,
        name=name)
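For context, here is a minimal usage sketch of this internal `If` wrapper, assuming the 1.x-era internal modules the snippet is drawn from (`function.Defun` for the branch functions and `functional_ops` for the wrapper itself); the module layout and `Defun` behavior are assumptions that may not hold across TF versions.

import tensorflow as tf
from tensorflow.python.framework import function   # internal module (assumption: same layout as the snippet's TF version)
from tensorflow.python.ops import functional_ops   # module containing the If wrapper shown above

@function.Defun(tf.float32)
def then_branch(x):
  return x + 1.0

@function.Defun(tf.float32)
def else_branch(x):
  return x - 1.0

with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
  out = functional_ops.If(tf.constant(True), [tf.constant(3.0)],
                          then_branch, else_branch)
  print(sess.run(out))  # expected: [4.0], since then_branch is taken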
Example #2
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the
      scalar is a numerical value, non-zero means True and zero means
      False; if the scalar is a string, non-empty means True and empty
      means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
        whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
        whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  # pylint: disable=protected-access
  return gen_functional_ops._if(
      cond,
      inputs, [_.type for _ in then_branch.definition.signature.output_arg],
      then_branch,
      else_branch,
      name=name)
Example #3
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediates values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph,
                                   true_inputs, false_inputs)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors)
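The requirement that both branches have the same output types is visible at the `tf.cond` level as well; a small sketch (the exact exception type and message may vary by TF version):

import tensorflow as tf

@tf.function
def mismatched(pred):
  # The true branch returns float32 while the false branch returns int32,
  # which the output-type check described above rejects.
  return tf.cond(pred, lambda: tf.constant(1.0), lambda: tf.constant(1))

try:
  mismatched(tf.constant(True))
except (TypeError, ValueError) as e:
  print(type(e).__name__, e)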
Example #4
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of an If op produced by cond_v2."""
    true_graph, false_graph = _get_func_graphs(op)
    # Note: op.graph != ops.get_default_graph() when we are computing the gradient
    # of a nested cond.
    assert true_graph.outer_graph == op.graph
    assert false_graph.outer_graph == op.graph

    # Create grad functions that compute the gradient of the true/false forward
    # graphs. These functions will capture tensors from the forward pass
    # functions.
    true_grad_graph = _create_grad_func(true_graph, grads,
                                        _get_grad_fn_name(true_graph))
    false_grad_graph = _create_grad_func(false_graph, grads,
                                         _get_grad_fn_name(false_graph))

    assert ([t.dtype for t in true_grad_graph.outputs
             ] == [t.dtype for t in false_grad_graph.outputs])

    # Resolve references to forward graph tensors in grad graphs and ensure
    # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
    true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
    false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

    # Make the inputs to true_grad_graph and false_grad_graph match. Note that
    # this modifies true_grad_graph and false_grad_graph.
    grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
                                     true_grad_inputs, false_grad_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # higher-order gradient computations.

    true_grad_intermediates = _get_intermediates(true_grad_graph)
    false_grad_intermediates = _get_intermediates(false_grad_graph)

    # Save the original number of gradient outputs to return.
    num_grad_outputs = len(true_grad_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_grad_outputs, extra_false_grad_outputs = _pad_params(
        true_grad_graph, false_grad_graph, true_grad_intermediates,
        false_grad_intermediates)

    true_grad_graph.outputs.extend(extra_true_grad_outputs)
    false_grad_graph.outputs.extend(extra_false_grad_outputs)

    # Create the gradient If op.
    tensors = gen_functional_ops._if(
        op.inputs[0],
        grad_inputs, [t.dtype for t in true_grad_graph.outputs],
        _create_new_tf_function(true_grad_graph),
        _create_new_tf_function(false_grad_graph),
        output_shapes=_get_output_shapes(true_grad_graph.outputs,
                                         false_grad_graph.outputs))

    # The predicate has no gradient.
    return [None] + tensors[:num_grad_outputs]
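As a user-level illustration of what this gradient function provides, differentiating through `tf.cond` inside a `tf.function` (which builds a functional `If`/`StatelessIf` op in current TF releases):

import tensorflow as tf

@tf.function
def grad_of_cond(x):
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.cond(x > 1.0, lambda: tf.square(x), lambda: x)
  return tape.gradient(y, x)

print(grad_of_cond(tf.constant(3.0)))  # 6.0: gradient of the taken x**2 branch
print(grad_of_cond(tf.constant(0.5)))  # 1.0: gradient of the taken identity branch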
Example #5
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2."""
  true_graph, false_graph = _get_func_graphs(op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == op.graph
  assert false_graph.outer_graph == op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, _get_grad_fn_name(true_graph))
  false_grad_graph = _create_grad_func(
      false_graph, grads, _get_grad_fn_name(false_graph))

  assert ([t.dtype for t in true_grad_graph.outputs] ==
          [t.dtype for t in false_grad_graph.outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # Make the inputs to true_grad_graph and false_grad_graph match. Note that
  # this modifies true_grad_graph and false_grad_graph.
  grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
                                   true_grad_inputs, false_grad_inputs)

  # Add all intermediate tensors as function outputs so they're available for
  # higher-order gradient computations.

  true_grad_intermediates = _get_intermediates(true_grad_graph)
  false_grad_intermediates = _get_intermediates(false_grad_graph)

  # Save the original number of gradient outputs to return.
  num_grad_outputs = len(true_grad_graph.outputs)

  # Make the number/type of new intermediate outputs match.
  extra_true_grad_outputs, extra_false_grad_outputs = _pad_params(
      true_grad_graph, false_grad_graph,
      true_grad_intermediates, false_grad_intermediates)

  true_grad_graph.outputs.extend(extra_true_grad_outputs)
  false_grad_graph.outputs.extend(extra_false_grad_outputs)

  # Create the gradient If op.
  tensors = gen_functional_ops._if(
      op.inputs[0],
      grad_inputs, [t.dtype for t in true_grad_graph.outputs],
      _create_new_tf_function(true_grad_graph),
      _create_new_tf_function(false_grad_graph),
      output_shapes=_get_output_shapes(true_grad_graph.outputs,
                                       false_grad_graph.outputs))

  # The predicate has no gradient.
  return [None] + tensors[:num_grad_outputs]
Example #6
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ?

  then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is a
        numerical value, non-zero means True and zero means False; if the scalar
        is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
    return gen_functional_ops._if(
        cond, inputs, tlist, then_branch, else_branch, name=name)

  # We assume that `then_branch` is a ConcreteFunction here.
  then_out = then_branch.structured_outputs
  else_out = else_branch.structured_outputs

  # Ensure then/else are the same type of composites to avoid an invalid call
  # to pack_sequence_as later on.
  nest.assert_same_structure(then_out, else_out, expand_composites=True)

  tlist = nest.flatten(then_branch.output_dtypes)
  ret = gen_functional_ops._if(
      cond, inputs, tlist, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors
  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
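The `nest` calls above are TensorFlow's standard structure utilities; a short sketch with the public `tf.nest` API showing the flatten/repack round trip that restores nested (or composite) outputs:

import tensorflow as tf

then_out = {"sum": tf.constant(3.0), "flag": tf.constant(1)}
else_out = {"sum": tf.constant(0.0), "flag": tf.constant(0)}

tf.nest.assert_same_structure(then_out, else_out)    # both branches must match
flat = tf.nest.flatten(then_out)                      # tensors in a flat, deterministic order
repacked = tf.nest.pack_sequence_as(then_out, flat)   # restores the original dict structure
print(repacked)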
Example #7
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of an If op produced by cond_v2."""
    true_graph = op._true_graph
    false_graph = op._false_graph

    # Create grad functions that compute the gradient of the true/false forward
    # graphs. These functions will capture tensors from the forward pass
    # functions.
    true_grad_graph = _create_grad_func(true_graph, grads,
                                        "%sgrad" % true_graph.name)
    false_grad_graph = _create_grad_func(false_graph, grads,
                                         "%sgrad" % false_graph.name)

    assert ([t.dtype for t in true_grad_graph.outputs
             ] == [t.dtype for t in false_grad_graph.outputs])

    # Match up the captured grad function inputs with outputs of 'op' and other
    # external tensors.
    true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph)
    false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph)

    # Make the inputs to true_grad_graph and false_grad_graph match. Note that
    # this modifies true_grad_graph and false_grad_graph.
    grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
                                     true_grad_inputs, false_grad_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # higher-order gradient computations.

    true_grad_intermediates = _get_intermediates(true_grad_graph)
    false_grad_intermediates = _get_intermediates(false_grad_graph)

    # Save the original number of gradient outputs to return.
    num_grad_outputs = len(true_grad_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_grad_outputs, extra_false_grad_outputs = _pad_params(
        true_grad_graph, false_grad_graph, true_grad_intermediates,
        false_grad_intermediates)

    true_grad_graph.outputs.extend(extra_true_grad_outputs)
    false_grad_graph.outputs.extend(extra_false_grad_outputs)

    # Create the gradient If op.
    tensors = gen_functional_ops._if(
        op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs],
        _create_new_tf_function(true_grad_graph),
        _create_new_tf_function(false_grad_graph))
    tensors[0].op._true_graph = true_grad_graph
    tensors[0].op._false_graph = false_grad_graph

    # The predicate has no gradient.
    return [None] + tensors[:num_grad_outputs]
Example #8
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2."""
  true_graph, false_graph = _get_func_graphs(op)

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, _get_grad_fn_name(true_graph))
  false_grad_graph = _create_grad_func(
      false_graph, grads, _get_grad_fn_name(false_graph))

  assert ([t.dtype for t in true_grad_graph.outputs] ==
          [t.dtype for t in false_grad_graph.outputs])

  # Match up the captured grad function inputs with outputs of 'op' and other
  # external tensors.
  true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph)
  false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph)

  # Make the inputs to true_grad_graph and false_grad_graph match. Note that
  # this modifies true_grad_graph and false_grad_graph.
  grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
                                   true_grad_inputs, false_grad_inputs)

  # Add all intermediate tensors as function outputs so they're available for
  # higher-order gradient computations.

  true_grad_intermediates = _get_intermediates(true_grad_graph)
  false_grad_intermediates = _get_intermediates(false_grad_graph)

  # Save the original number of gradient outputs to return.
  num_grad_outputs = len(true_grad_graph.outputs)

  # Make the number/type of new intermediate outputs match.
  extra_true_grad_outputs, extra_false_grad_outputs = _pad_params(
      true_grad_graph, false_grad_graph,
      true_grad_intermediates, false_grad_intermediates)

  true_grad_graph.outputs.extend(extra_true_grad_outputs)
  false_grad_graph.outputs.extend(extra_false_grad_outputs)

  # Create the gradient If op.
  tensors = gen_functional_ops._if(
      op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs],
      _create_new_tf_function(true_grad_graph),
      _create_new_tf_function(false_grad_graph))

  # The predicate has no gradient.
  return [None] + tensors[:num_grad_outputs]
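The intermediate outputs added above are what make higher-order gradients possible; a brief sketch of a second derivative through `tf.cond`:

import tensorflow as tf

@tf.function
def second_order(x):
  with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
      inner.watch(x)
      y = tf.cond(x > 0.0, lambda: x ** 3, lambda: -x)
    dy = inner.gradient(y, x)
  return outer.gradient(dy, x)

print(second_order(tf.constant(2.0)))  # 12.0 = d2/dx2 of x**3 at x = 2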
Example #9
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    with ops.name_scope(name) as scope:
        true_graph = function.func_graph_from_py_func(true_fn, [], [],
                                                      name="%s_true" % scope)
        false_graph = function.func_graph_from_py_func(false_fn, [], [],
                                                       name="%s_false" % scope)
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.extra_inputs,
                                         false_graph.extra_inputs)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(pred,
                                         cond_inputs,
                                         [t.dtype for t in true_graph.outputs],
                                         _create_new_tf_function(true_graph),
                                         _create_new_tf_function(false_graph),
                                         name=scope)

        # TODO(b/79883549): if we could make Graphs from FunctionDefs, we wouldn't
        # need this extra state. Requiring extra state also prevents the ability to
        # take the gradient of deserialized If ops.
        tensors[0].op._true_graph = true_graph
        tensors[0].op._false_graph = false_graph

        return tensors[:num_cond_outputs]
Example #10
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    with ops.name_scope(name) as scope:
        true_graph = function.func_graph_from_py_func(true_fn, [], [],
                                                      name="%s_true" % scope)
        false_graph = function.func_graph_from_py_func(false_fn, [], [],
                                                       name="%s_false" % scope)
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.extra_inputs,
                                         false_graph.extra_inputs)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(pred,
                                         cond_inputs,
                                         [t.dtype for t in true_graph.outputs],
                                         _create_new_tf_function(true_graph),
                                         _create_new_tf_function(false_graph),
                                         name=scope)

        return tensors[:num_cond_outputs]
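For intuition only, here is a hypothetical sketch of the padding idea behind `_pad_params` (not the real helper): each branch appends placeholders standing in for the other branch's intermediates, so both functions end up with the same number and dtypes of outputs. The helper name and `make_placeholder` callback are illustrative assumptions.

def pad_outputs(true_intermediate_dtypes, false_intermediate_dtypes, make_placeholder):
  # Hypothetical: 'make_placeholder' would create an unused dummy tensor of the
  # given dtype inside the corresponding branch graph.
  extra_true = [make_placeholder(dt) for dt in false_intermediate_dtypes]
  extra_false = [make_placeholder(dt) for dt in true_intermediate_dtypes]
  return extra_true, extra_false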
Example #11
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  with ops.name_scope(name) as scope:
    true_graph = function.func_graph_from_py_func(true_fn, [], [],
                                                  name="%s_true" % scope)
    false_graph = function.func_graph_from_py_func(false_fn, [], [],
                                                   name="%s_false" % scope)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.extra_inputs,
                                     false_graph.extra_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    return tensors[:num_cond_outputs]
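The "single If op" claim in the docstring can be observed from the public API, assuming a TF 2.x runtime where `tf.cond` inside a `tf.function` builds the functional op:

import tensorflow as tf

@tf.function
def f(pred, x):
  return tf.cond(pred, lambda: x + 1.0, lambda: x - 1.0)

graph = f.get_concrete_function(
    tf.TensorSpec([], tf.bool), tf.TensorSpec([], tf.float32)).graph
print([op.type for op in graph.get_operations()
       if op.type in ("If", "StatelessIf")])  # a single functional If node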
Example #12
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, function.FuncGraph):
        graph = graph.outer_graph

      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    true_graph = function.func_graph_from_py_func(
        true_name, true_fn, [], {})
    false_graph = function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result
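A hedged way to inspect the lowering flag described in the comment above; whether the attribute is present, and which op type is emitted, depends on the TF version and on whether an XLA context is active:

import tensorflow as tf

@tf.function
def f(x):
  return tf.cond(x > 0.0, lambda: x, lambda: -x)

g = f.get_concrete_function(tf.TensorSpec([], tf.float32)).graph
if_op = [op for op in g.get_operations() if op.type in ("If", "StatelessIf")][0]
# True when the lowering attribute was set on the node, False otherwise.
print("_lower_using_switch_merge" in if_op.node_def.attr)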
Example #13
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result
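The `isinstance(pred, bool)` guard above is observable from the public API; a sketch (the exact error text may differ across versions):

import tensorflow as tf

@tf.function
def bad_pred(x):
  # A Python bool predicate is rejected; a boolean tensor must be used instead.
  return tf.cond(True, lambda: x, lambda: -x)

try:
  bad_pred(tf.constant(1.0))
except TypeError as e:
  print(e)  # e.g. "pred must not be a Python bool"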
Example #14
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    util.maybe_set_lowering_attr(tensors[0].op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
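The `pack_sequence_as` call above is what lets the branches return nested structures; a small example of the user-visible behavior:

import tensorflow as tf

@tf.function
def structured(pred, x):
  return tf.cond(pred,
                 lambda: {"doubled": 2.0 * x, "negated": -x},
                 lambda: {"doubled": x / 2.0, "negated": x})

print(structured(tf.constant(True), tf.constant(3.0)))
# {'doubled': 6.0, 'negated': -3.0}: the dict structure of the branch outputs is preserved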
Example #15
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    # Identify if there is a caller device, & get the innermost if possible.
    device_stack = ops.get_default_graph()._device_function_stack
    caller_device = device_stack[-1] if device_stack else None

    caller_colocation_stack = ops.get_default_graph()._colocation_stack
    caller_container = ops.get_default_graph()._container
    caller_collection_ref = ops.get_default_graph()._collections

    func_name_prefix = scope.replace("/", "_")

    true_graph = _function.func_graph_from_py_func(
        true_fn, [], [],
        name="%strue" % func_name_prefix,
        device=caller_device,
        colocation_stack=caller_colocation_stack,
        collections_ref=caller_collection_ref,
        container=caller_container)
    false_graph = _function.func_graph_from_py_func(
        false_fn, [], [],
        name="%sfalse" % func_name_prefix,
        device=caller_device,
        colocation_stack=caller_colocation_stack,
        collections_ref=caller_collection_ref,
        container=caller_container)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.extra_inputs,
                                     false_graph.extra_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))

    return tensors[:num_cond_outputs]
Example #16
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outer most graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, _function.FuncGraph):
        graph = graph.outer_graph

      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    true_graph = _function.func_graph_from_py_func(
        true_name, true_fn, [], {})
    false_graph = _function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    return tuple(tensors[:num_cond_outputs])
Example #17
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2."""
  # Get the if operator (this logic handles the case where op is a MockOp)
  if_op = op.outputs[0].op
  true_graph, false_graph = _get_func_graphs(if_op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == if_op.graph
  assert false_graph.outer_graph == if_op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, _get_grad_fn_name(true_graph))
  false_grad_graph = _create_grad_func(
      false_graph, grads, _get_grad_fn_name(false_graph))

  if (true_grad_graph.if_op_needs_rewrite or
      false_grad_graph.if_op_needs_rewrite):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      true_intermediates = true_grad_graph.xla_intermediates
      false_intermediates = false_grad_graph.xla_intermediates
      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
          true_graph, false_graph, true_intermediates, false_intermediates)
    else:
      true_intermediates = true_grad_graph.wrapped_intermediates
      false_intermediates = false_grad_graph.wrapped_intermediates
      # Make outputs match by adding none optionals.
      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
          true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): indicate it's an internal bug if this fails.
    _check_same_outputs(true_graph, false_graph)

    true_graph.name += "_rewritten"
    false_graph.name += "_rewritten"

    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
    if_op._set_func_attr("else_branch",
                         util.create_new_tf_function(false_graph))
    if_op._set_type_list_attr("Tout", true_graph.output_types)
    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
    if_op._add_outputs(
        [t.dtype for t in extra_true_outputs],
        [t.shape for t in extra_true_outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # This modifies true_grad_graph and false_grad_graph.
  _make_output_composite_tensors_match(true_grad_graph, false_grad_graph)

  outputs = _build_cond(if_op.inputs[0], true_grad_graph, false_grad_graph,
                        true_grad_inputs, false_grad_inputs)

  # The predicate has no gradient.
  return [None] + outputs
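The "wrapped in optionals" mechanism mentioned above is exposed in recent TF 2.x releases as `tf.experimental.Optional`; a brief sketch of the value-iff-present semantics, assuming that API is available in your version:

import tensorflow as tf

present = tf.experimental.Optional.from_value(tf.constant(1.0))
absent = tf.experimental.Optional.empty(tf.TensorSpec([], tf.float32))

print(present.has_value())  # True  -- like an intermediate from the taken branch
print(absent.has_value())   # False -- like an intermediate from the untaken branch
print(present.get_value())  # 1.0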
Example #18
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    # Identify if there is a caller device, & get the innermost if possible.
    device_stack = ops.get_default_graph()._device_function_stack
    caller_device = device_stack[-1] if device_stack else None

    caller_colocation_stack = ops.get_default_graph()._colocation_stack
    caller_container = ops.get_default_graph()._container
    caller_collection_ref = ops.get_default_graph()._collections

    func_name_prefix = scope.replace("/", "_")

    true_graph = function.func_graph_from_py_func(
        true_fn, [], [],
        name="%strue" % func_name_prefix,
        device=caller_device,
        colocation_stack=caller_colocation_stack,
        collections_ref=caller_collection_ref,
        container=caller_container)
    false_graph = function.func_graph_from_py_func(
        false_fn, [], [],
        name="%sfalse" % func_name_prefix,
        device=caller_device,
        colocation_stack=caller_colocation_stack,
        collections_ref=caller_collection_ref,
        container=caller_container)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.extra_inputs,
                                     false_graph.extra_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(
        pred, cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        name=scope)

    return tensors[:num_cond_outputs]
Example #19
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        # Identify if there is a caller device, & get the innermost if possible.
        device_stack = ops.get_default_graph()._device_function_stack
        caller_device = device_stack[-1] if device_stack else None

        caller_colocation_stack = ops.get_default_graph()._colocation_stack
        caller_container = ops.get_default_graph()._container
        caller_collection_ref = ops.get_default_graph()._collections

        func_name_prefix = scope.replace("/", "_")

        true_graph = _function.func_graph_from_py_func(
            true_fn, [], [],
            name="%strue" % func_name_prefix,
            device=caller_device,
            colocation_stack=caller_colocation_stack,
            collections_ref=caller_collection_ref,
            container=caller_container)
        false_graph = _function.func_graph_from_py_func(
            false_fn, [], [],
            name="%sfalse" % func_name_prefix,
            device=caller_device,
            colocation_stack=caller_colocation_stack,
            collections_ref=caller_collection_ref,
            container=caller_container)
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.extra_inputs,
                                         false_graph.extra_inputs)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(pred,
                                         cond_inputs,
                                         [t.dtype for t in true_graph.outputs],
                                         _create_new_tf_function(true_graph),
                                         _create_new_tf_function(false_graph),
                                         name=scope)

        # Set the flag to enable lowering on the `if` op if necessary
        # Lowering allows cond_v2 to avoid some of the limitations of Functions,
        # allowing users to specify devices & colocation inside of cond_v2 branches,
        # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
        # This brings cond_v2 closer to feature parity with tf.cond.
        #
        # However, we do not lower `If` in the XLA context because it is easier for
        # XLA to apply its own optimizations when dealing with un-lowered `If`
        # operators than with lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        if_op = tensors[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            if_op._set_attr("_lower_using_switch_merge",
                            attr_value_pb2.AttrValue(b=True))

        return tensors[:num_cond_outputs]
Example #20
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            # Find the outer most graph for uniquing function names.
            # TODO(jpienaar): Make this work in eager mode.
            graph = ops.get_default_graph()
            while isinstance(graph, function.FuncGraph):
                graph = graph.outer_graph

            true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
            false_name = graph.unique_name(
                ("%sfalse" % scope).replace("/", "_"))

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = util.in_defun()

        true_graph = function.func_graph_from_py_func(
            true_name,
            true_fn, [], {},
            func_graph=util.CondBranchFuncGraph(true_name),
            add_control_dependencies=add_control_dependencies)
        false_graph = function.func_graph_from_py_func(
            false_name,
            false_fn, [], {},
            func_graph=util.CondBranchFuncGraph(false_name),
            add_control_dependencies=add_control_dependencies)
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.external_captures,
                                         false_graph.external_captures)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(  # pylint: disable=protected-access
            pred,
            cond_inputs, [t.dtype for t in true_graph.outputs],
            util.create_new_tf_function(true_graph),
            util.create_new_tf_function(false_graph),
            output_shapes=_get_output_shapes(true_graph.outputs,
                                             false_graph.outputs),
            name=scope)

        # Set the flag to enable lowering on the `if` op if necessary
        # Lowering allows cond_v2 to avoid some of the limitations of Functions,
        # allowing users to specify devices & colocation inside of cond_v2 branches,
        # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
        # This brings cond_v2 closer to feature parity with tf.cond.
        #
        # However, we do not lower `If` in the XLA context because it is easier for
        # XLA to apply its own optimizations when dealing with un-lowered `If`
        # operators than with lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        if_op = tensors[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            # pylint: disable=protected-access
            if_op._set_attr("_lower_using_switch_merge",
                            attr_value_pb2.AttrValue(b=True))
            # pylint: enable=protected-access

        # Return identities for each output of the If op, rather than the output of
        # the If op directly. This makes pruning work if the output of cond() is
        # fetched: the lowering pass converts the If outputs into IdentityN outputs,
        # which if fetched will cause all ops in the taken branch to be run (since
        # it takes all merge ops as input). After lowering, each output identity op
        # will end up with only the appropriate merge op as input.
        # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
        # correct output structure
        tensors = tuple(array_ops.identity(t) for t in tensors)

        result = tuple(tensors[:num_cond_outputs])
        if len(result) == 1:
            return result[0]
        else:
            return result
Example #21
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediates values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph,
                                   true_inputs, false_inputs)

  # Add all intermediate tensors as function outputs so they're available for
  # the gradient computation. Since the outputs of the two functions must match,
  # we wrap all the intermediates in optionals. Each intermediate output will
  # have a value iff its corresponding branch is taken.

  true_intermediates = _get_intermediates(true_graph)
  false_intermediates = _get_intermediates(false_graph)

  # Save the original number of outputs to return to the caller.
  num_cond_outputs = len(true_graph.outputs)

  if control_flow_util.InXlaContext(ops.get_default_graph()):
    # XLA does not yet support optionals, so output intermediates directly and
    # make them match via FakeParams, which can be converted to zeros in XLA.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
        true_graph, false_graph, true_intermediates, false_intermediates)
  else:
    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding empty optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(
        true_graph, false_graph,
        wrapped_true_intermediates, wrapped_false_intermediates)

  true_graph.outputs.extend(extra_true_outputs)
  false_graph.outputs.extend(extra_false_outputs)
  # TODO(skyewm): somehow indicate it's a bug if this fails.
  _check_same_outputs(true_graph, false_graph)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)

  return tensors[:num_cond_outputs]
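
The optional-wrapping above goes through low-level dataset-optional ops inside _wrap_intermediates and _make_intermediates_match (not shown here). As a conceptual analogue only, the same idea can be illustrated with the public tf.experimental.Optional API available in recent TF 2.x releases:

import tensorflow as tf

value_spec = tf.TensorSpec(shape=[], dtype=tf.float32)

# The taken branch emits an optional that actually holds its intermediate ...
taken = tf.experimental.Optional.from_value(tf.constant(1.5))
# ... while the other branch pads its output list with an empty optional of the
# same element spec, so both branches keep matching output signatures.
untaken = tf.experimental.Optional.empty(value_spec)

print(taken.has_value())    # tf.Tensor(True, ...)
print(untaken.has_value())  # tf.Tensor(False, ...)
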
Example No. 22
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs,
                                     false_inputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return tensors
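
Every variant here also calls _get_output_shapes, which is not reproduced in these examples. A plausible sketch, assuming it simply pairs up the two branches' outputs and keeps the most specific shape compatible with both (the real helper may differ in its details):

def _get_output_shapes_sketch(true_graph_outputs, false_graph_outputs):
  """One TensorShape per If output, compatible with both branches' shapes."""
  return [
      t_out.shape.most_specific_compatible_shape(f_out.shape)
      for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
  ]
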
Example No. 23
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs,
                                     false_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation. Since the outputs of the two functions must match,
    # we wrap all the intermediates in optionals. Each intermediate output will
    # have a value iff its corresponding branch is taken.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    if control_flow_util.InXlaContext(ops.get_default_graph()):
        # XLA does not yet support optionals, so output intermediates directly and
        # make them match via FakeParams, which can be converted to zeros in XLA.
        # TODO(skyewm,jpienaar): can XLA support optionals?
        extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
            true_graph, false_graph, true_intermediates, false_intermediates)
    else:
        # Wrap intermediates in optionals.
        wrapped_true_intermediates = _wrap_intermediates(
            true_graph, true_intermediates)
        wrapped_false_intermediates = _wrap_intermediates(
            false_graph, false_intermediates)

        # Make outputs match by adding empty optionals.
        extra_true_outputs, extra_false_outputs = _make_intermediates_match(
            true_graph, false_graph, wrapped_true_intermediates,
            wrapped_false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): somehow indicate it's a bug if this fails.
    _check_same_outputs(true_graph, false_graph)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)

    return tensors[:num_cond_outputs]
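
Examples No. 21-23 delegate the lowering decision to util.maybe_set_lowering_attr, while the first variant in this group inlines the same check against the XLA context. Going by that inlined code, the helper plausibly looks like the sketch below (the name is suffixed to mark it as a sketch, not the library's own definition):

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.ops import control_flow_util


def maybe_set_lowering_attr_sketch(if_op):
  """Marks if_op for lowering to switch/merge unless it sits in an XLA context."""
  if not control_flow_util.IsInXLAContext(if_op):
    # pylint: disable=protected-access
    if_op._set_attr("_lower_using_switch_merge",
                    attr_value_pb2.AttrValue(b=True))
    # pylint: enable=protected-access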