Example #1
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could have
  # None incoming gradients for the TensorLists. If we pass Nones through, the
  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
  # the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as zero
  # gradient in certain cases. Consider replacing None gradients with Zeros
  # for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads,
      _get_unique_name("%s_grad" % body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = _get_unique_name("%s_grad_cond" % op.name)
  cond_grad_graph = function.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name=_get_unique_name("%s_grad" % op.name))

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
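A hedged usage sketch for Example #1: in TensorFlow, a gradient function like _WhileGrad is typically hooked into the gradient registry. The op type string "While" below is an assumption about how the forward op is registered; treat this as an illustration, not the library's actual registration code.

# Sketch: registering the gradient function for the "While" op type.
# `ops.RegisterGradient` is TensorFlow's standard gradient-registry decorator;
# the type name "While" is an assumption for illustration.
from tensorflow.python.framework import ops

ops.RegisterGradient("While")(_WhileGrad)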
Example #2
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could have
  # None incoming gradients for the TensorLists. If we pass Nones through, the
  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
  # the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as zero
  # gradient in certain cases. Consider replacing None gradients with Zeros
  # for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads,
      _get_unique_name("%s_grad" % body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  cond_grad_graph = function.func_graph_from_py_func(
      _get_unique_name("%s_grad_cond" % op.name),
      grad_cond, loop_vars, {})

  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      cond_v2._create_new_tf_function(cond_grad_graph),
      cond_v2._create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name=_get_unique_name("%s_grad" % op.name))

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
Example #3
def wrap_function(fn, signature, name=None):
    """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
    holder = VariableHolder(fn)
    fn = function.Function(function.func_graph_from_py_func(
        name,
        holder,
        args=None,
        kwargs=None,
        signature=signature,
        add_control_dependencies=False),
                           signature=signature)
    fn._variable_holder = holder
    return fn
Example #4
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  fn = function.Function(
      function.func_graph_from_py_func(
          name,
          holder,
          args=None, kwargs=None, signature=signature,
          add_control_dependencies=False),
      signature=signature)
  fn._variable_holder = holder
  return fn
Example #5
def _create_grad_func(func_graph, grads, name, while_op):
    """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    func_graph: FuncGraph for the forward body function.
    grads: The incoming grads for `func_graph`'s outputs.
    name: Name of the returned gradient function.
    while_op: The forward While op.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
    assert len(func_graph.outputs) == len(grads)

    loop_counter = constant_op.constant(0.)
    # TODO(srbs): For nested while loops, we will need to look up this value
    # from the accumulator of the enclosing while loop. For now, use it as is,
    # assuming there is no nesting.
    num_iters_t = while_op.outputs[0]

    args = [loop_counter, num_iters_t] + grads

    # Note: The returned function does not have `args` in the list of
    # `external_captures`.
    grad_func_graph = function.func_graph_from_py_func(
        name,
        lambda *args: _grad_fn(func_graph, args),
        args, {},
        func_graph=_WhileBodyGradFuncGraph(name, func_graph))

    # Add the popped accumulators to the list of outputs.
    for internal_capture in grad_func_graph.internal_captures:
        grad_func_graph.outputs.append(
            grad_func_graph.popped_tensor_lists[internal_capture])

    return grad_func_graph, args
Example #6
def _create_grad_func(func_graph, grads, name, while_op):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    func_graph: FuncGraph for the forward body function.
    grads: The incoming grads for `func_graph`'s outputs.
    name: Name of the returned gradient function.
    while_op: The forward While op.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(func_graph.outputs) == len(grads)

  loop_counter = constant_op.constant(0.)
  # TODO(srbs): For nested while loops, we will need to look up this value
  # from the accumulator of the enclosing while loop. For now, use it as is,
  # assuming there is no nesting.
  num_iters_t = while_op.outputs[0]

  args = [loop_counter, num_iters_t] + grads

  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = function.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(func_graph, args),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, func_graph))

  # Add the popped accumulators to the list of outputs.
  for internal_capture in grad_func_graph.internal_captures:
    grad_func_graph.outputs.append(
        grad_func_graph.popped_tensor_lists[internal_capture])

  return grad_func_graph, args
Example #7
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      # Find the outermost graph for uniquing function names.
      # TODO(jpienaar): Make this work in eager mode.
      graph = ops.get_default_graph()
      while isinstance(graph, function.FuncGraph):
        graph = graph.outer_graph

      true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
      false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))

    true_graph = function.func_graph_from_py_func(
        true_name, true_fn, [], {})
    false_graph = function.func_graph_from_py_func(
        false_name, false_fn, [], {})
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        _create_new_tf_function(true_graph),
        _create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result
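A hedged usage sketch for the cond_v2 defined above: both branch callables must return the same number and dtypes of outputs, and the predicate must be a tensor rather than a Python bool. The TF 1.x graph/session boilerplate is illustrative only.

# Sketch: calling cond_v2 with a tensor predicate and two branch functions.
import tensorflow as tf

x = tf.constant(2.0)
y = tf.constant(3.0)

# Both branches produce a single float32 scalar, so their outputs match.
result = cond_v2(tf.less(x, y),
                 lambda: tf.add(x, 1.0),       # taken when x < y
                 lambda: tf.subtract(y, 1.0))  # taken otherwise

with tf.Session() as sess:
  print(sess.run(result))  # 3.0, since 2.0 < 3.0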
Example #8
def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn."""
  return function.func_graph_from_py_func(
      name, lambda: _grad_fn(func_graph, grads), [], {})
Example #9
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
    """Like tf.while_loop, except emits a single While op."""
    flattened_loop_vars = nest.flatten(loop_vars)
    if shape_invariants is not None:
        nest.assert_same_structure(loop_vars, shape_invariants)
        flattened_shapes = nest.flatten(shape_invariants)
    else:
        flattened_shapes = [t.shape for t in flattened_loop_vars]

    del shape_invariants

    if not name:
        name = "while"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            cond_name = util.unique_fn_name(scope, "cond")
            body_name = util.unique_fn_name(scope, "body")

        num_outputs = len(flattened_loop_vars)

        # Add loop counter needed for computing gradients.
        flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                               ] + flattened_loop_vars

        flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

        # Build a `cond` wrapper that can handle the extra counter loop_var.
        def wrapped_cond(unused_loop_counter, *loop_vars):
            return cond(*loop_vars)

        signature = [
            tensor_spec.TensorSpec(shape, t.dtype)
            for shape, t in zip(flattened_shapes, flattened_loop_vars)
        ]
        cond_graph = function.func_graph_from_py_func(
            cond_name,
            wrapped_cond,
            flattened_loop_vars, {},
            signature=signature,
            func_graph=util.WhileCondFuncGraph(cond_name))

        # Add external_captures of cond to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
        flattened_shapes = flattened_shapes + [
            t.shape for t in cond_graph.external_captures
        ]

        def wrapped_body(loop_counter, *args):
            """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list of tensors the same length as args.
      """
            outputs = body(*args[:num_outputs])
            if not isinstance(outputs, collections.Sequence):
                outputs = [outputs]

            # Return the external_captures of cond_graph as is, i.e., treat them as
            # loop invariants.
            # TODO(srbs): Update lowering code to create _Enter nodes with
            # is_constant=True for inputs that are directly passed to outputs.
            return [loop_counter + 1] + list(outputs) + list(
                args[num_outputs:])

        signature = [
            tensor_spec.TensorSpec(shape, t.dtype)
            for shape, t in zip(flattened_shapes, flattened_loop_vars)
        ]
        body_graph = function.func_graph_from_py_func(
            body_name,
            wrapped_body,
            flattened_loop_vars, {},
            signature=signature,
            func_graph=util.WhileBodyFuncGraph(body_name))
        # Add external captures of body to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
        # TODO(srbs): Update lowering code to create _Enter nodes with
        # is_constant=True for inputs that are directly passed to outputs.
        body_graph.outputs.extend(body_graph.internal_captures)

        # Capture `external_captures` of `body_graph` in `cond_graph` so that it
        # expects to receive those as arguments.
        # TODO(srbs): Dedup tensors that are captured in both the cond and body.
        # This logic already exists in cond_v2.
        with cond_graph.as_default():
            for external_capture in body_graph.external_captures:
                cond_graph.capture(external_capture)

        # Export all tensors in the loop body that may be needed for gradient
        # computation. We do this by accumulating the intermediate values in
        # TensorLists.
        intermediate_tensors = _get_intermediates(body_graph)

        for intermediate_tensor in intermediate_tensors:
            # TODO(srbs): Cache and re-use empty tensor lists.
            tensor_list = list_ops.empty_tensor_list(
                element_dtype=intermediate_tensor.dtype,
                element_shape=_get_tensor_convertible_shape(
                    intermediate_tensor.shape))
            flattened_loop_vars.append(tensor_list)
            with cond_graph.as_default():
                # Add a placeholder to cond_graph's inputs corresponding to the
                # tensor_list.
                cond_graph.capture(tensor_list)
            with body_graph.as_default():
                # Push the intermediate tensor to the tensor list. This captures the
                # `tensor_list` as well.
                appended_tensor_list = list_ops.tensor_list_push_back(
                    tensor_list, intermediate_tensor)
                # Add this modified tensor list to the list of outputs.
                body_graph.outputs.append(appended_tensor_list)

        # Make sure that the shapes of the loop outputs are compatible with the
        # shape invariants, or the shapes of the loop vars if the invariants are not
        # specified.
        _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                             flattened_shapes[1:1 + num_outputs],
                             flattened_loop_vars[1:1 + num_outputs])
        outputs = gen_functional_ops._while(
            flattened_loop_vars,
            util.create_new_tf_function(cond_graph),
            util.create_new_tf_function(body_graph),
            output_shapes=[t.shape for t in body_graph.outputs],
            name=scope)

        _copy_handle_data(body_graph.outputs, outputs)
        _maybe_set_lowering_attr(outputs[0].op)

    # First var is loop counter.
    if num_outputs == 1:
        return outputs[1]
    else:
        return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
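A hedged usage sketch for the while_loop defined above: sum the integers 0 through 9. The loop variables must be tensors, the cond must return a boolean scalar, and the body must return one output per loop variable; the TF 1.x session boilerplate is illustrative only.

# Sketch: summing 0..9 with the single-While-op while_loop.
import tensorflow as tf

i = tf.constant(0)
total = tf.constant(0)

result_i, result_total = while_loop(
    lambda i, total: tf.less(i, 10),       # continue while i < 10
    lambda i, total: (i + 1, total + i),   # advance counter, accumulate sum
    [i, total])

with tf.Session() as sess:
    print(sess.run([result_i, result_total]))  # [10, 45]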
Example #10
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            # Find the outermost graph for uniquing function names.
            # TODO(jpienaar): Make this work in eager mode.
            graph = ops.get_default_graph()
            while isinstance(graph, function.FuncGraph):
                graph = graph.outer_graph

            true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
            false_name = graph.unique_name(
                ("%sfalse" % scope).replace("/", "_"))

        true_graph = function.func_graph_from_py_func(
            true_name,
            true_fn, [], {},
            func_graph=util.CondBranchFuncGraph(true_name))
        false_graph = function.func_graph_from_py_func(
            false_name,
            false_fn, [], {},
            func_graph=util.CondBranchFuncGraph(false_name))
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.external_captures,
                                         false_graph.external_captures)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(  # pylint: disable=protected-access
            pred,
            cond_inputs, [t.dtype for t in true_graph.outputs],
            util.create_new_tf_function(true_graph),
            util.create_new_tf_function(false_graph),
            output_shapes=_get_output_shapes(true_graph.outputs,
                                             false_graph.outputs),
            name=scope)

        # Set the flag to enable lowering on the `if` op if necessary
        # Lowering allows cond_v2 to avoid some of the limitations of Functions,
        # allowing users to specify devices & colocation inside of cond_v2 branches,
        # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
        # This brings cond_v2 closer to feature parity with tf.cond.
        #
        # However, we do not lower `If` in the XLA context because it is easier for
        # XLA to apply its own optimizations when dealing with un-lowered `If`
        # operators than with lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        if_op = tensors[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            # pylint: disable=protected-access
            if_op._set_attr("_lower_using_switch_merge",
                            attr_value_pb2.AttrValue(b=True))
            # pylint: enable=protected-access

        result = tuple(tensors[:num_cond_outputs])
        if len(result) == 1:
            return result[0]
        else:
            return result
Example #11
def _create_grad_func(func_graph, grads, name):
    """Returns the FuncGraph representation of _grad_fn."""
    return function.func_graph_from_py_func(
        name,
        lambda: _grad_fn(func_graph, grads), [], {},
        func_graph=util.CondBranchFuncGraph(name))
Example #12
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
  """Like tf.while_loop, except emits a single While op."""
  flattened_loop_vars = nest.flatten(loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(loop_vars, shape_invariants)
    flattened_shapes = nest.flatten(shape_invariants)
  else:
    flattened_shapes = [t.shape for t in flattened_loop_vars]

  del shape_invariants

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
      body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))

    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
                          ] + flattened_loop_vars

    flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    cond_graph = function.func_graph_from_py_func(
        cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
    flattened_shapes = flattened_shapes + [
        t.shape for t in cond_graph.external_captures
    ]

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list of tensors the same length as args.
      """
      outputs = body(*args[:num_outputs])
      if not isinstance(outputs, collections.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    body_graph = function.func_graph_from_py_func(
        body_name, wrapped_body, flattened_loop_vars, {}, signature=signature)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list,
            intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                         flattened_shapes[1:1 + num_outputs],
                         flattened_loop_vars[1:1 + num_outputs])
    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        cond_v2._create_new_tf_function(cond_graph),
        cond_v2._create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
Example #13
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = util.in_defun()

        true_graph = function.func_graph_from_py_func(
            true_name,
            true_fn, [], {},
            func_graph=util.CondBranchFuncGraph(true_name),
            add_control_dependencies=add_control_dependencies)
        false_graph = function.func_graph_from_py_func(
            false_name,
            false_fn, [], {},
            func_graph=util.CondBranchFuncGraph(false_name),
            add_control_dependencies=add_control_dependencies)
        _check_same_outputs(true_graph, false_graph)

        # Add inputs to true_graph and false_graph to make them match. Note that
        # this modifies true_graph and false_graph.
        cond_inputs = _make_inputs_match(true_graph, false_graph,
                                         true_graph.external_captures,
                                         false_graph.external_captures)

        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Save the original number of outputs to return to the caller.
        num_cond_outputs = len(true_graph.outputs)

        # Make the number/type of new intermediate outputs match.
        extra_true_outputs, extra_false_outputs = _pad_params(
            true_graph, false_graph, true_intermediates, false_intermediates)

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)

        # Create the If op.
        tensors = gen_functional_ops._if(  # pylint: disable=protected-access
            pred,
            cond_inputs, [t.dtype for t in true_graph.outputs],
            util.create_new_tf_function(true_graph),
            util.create_new_tf_function(false_graph),
            output_shapes=_get_output_shapes(true_graph.outputs,
                                             false_graph.outputs),
            name=scope)

        # Set the flag to enable lowering on the `if` op if necessary
        # Lowering allows cond_v2 to avoid some of the limitations of Functions,
        # allowing users to specify devices & colocation inside of cond_v2 branches,
        # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
        # This brings cond_v2 closer to feature parity with tf.cond.
        #
        # However, we do not lower `If` in the XLA context because it is easier for
        # XLA to apply its own optimizations when dealing with un-lowered `If`
        # operators than with lowered switch/merge control flow.
        #
        # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
        if_op = tensors[0].op
        if not control_flow_util.IsInXLAContext(if_op):
            # pylint: disable=protected-access
            if_op._set_attr("_lower_using_switch_merge",
                            attr_value_pb2.AttrValue(b=True))
            # pylint: enable=protected-access

        # Return identities for each output of the If op, rather than the output of
        # the If op directly. This makes pruning work if the output of cond() is
        # fetched: the lowering pass converts the If outputs into IdentityN outputs,
        # which if fetched will cause all ops in the taken branch to be run (since
        # it takes all merge ops as input). After lowering, each output identity op
        # will end up with only the appropriate merge op as input.
        # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
        # correct output structure
        tensors = tuple(array_ops.identity(t) for t in tensors)

        result = tuple(tensors[:num_cond_outputs])
        if len(result) == 1:
            return result[0]
        else:
            return result