Example #1
def wrap_cached_variables(concrete_function):
    """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  Args:
    concrete_function: A `ConcreteFunction` that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
    outer_graph = func_graph_module.FuncGraph("{}_no_cache".format(
        concrete_function.graph.name))
    captures = concrete_function.graph._captures  # pylint: disable=protected-access
    mapped_captures = False
    remapped_captures = {}

    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
        for capture, placeholder in concrete_function.graph.captures:
            cached_variable = getattr(capture, "_cached_variable", None)
            if cached_variable is None:
                continue
            cached_variable = cached_variable()
            new_cached_value = cached_variable.read_value()
            remapped_captures[id(capture)] = captures[id(capture)]
            captures[id(capture)] = (new_cached_value, placeholder)
            mapped_captures = True

    if not mapped_captures:
        return concrete_function

    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
        return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(None,
                                              wrap_function,
                                              args=tuple(args),
                                              kwargs={},
                                              func_graph=outer_graph)
    fn = defun.ConcreteFunction(outer_graph,
                                function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access

    # Return the captures to their original values
    for key, capture in remapped_captures.items():
        captures[key] = capture
    return fn
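
A minimal usage sketch for the helper above, assuming the module-level imports it relies on (func_graph_module, defun, nest) are available and that tracing goes through the public tf.function API. Whether any cached read tensors are actually captured depends on TensorFlow internals, so the call may simply return the original object.

import tensorflow as tf

v = tf.Variable(3.0)

@tf.function
def scale(x):
    return x * v  # reads `v`; the trace may or may not capture a cached read

concrete = scale.get_concrete_function(tf.TensorSpec([], tf.float32))
# If no cached read tensors were captured, the original object comes back.
rewrapped = wrap_cached_variables(concrete)
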
Example #2
    def testMaybeSetStaticShapeScalarShape(self):
        def reshape():
            v = array_ops.placeholder(dtypes.float32)
            t = array_ops.reshape(v, [-1])
            return t

        with self.disableSetStaticShape():
            graph_without_shape_propagation = func_graph.func_graph_from_py_func(
                "without_shape_propagation", reshape, [], {})
        graph_with_shape_propagation = func_graph.func_graph_from_py_func(
            "with_shape_propagation", reshape, [], {})
        self.assertCountEqual([
            op.type for op in graph_without_shape_propagation.get_operations()
        ], [op.type for op in graph_with_shape_propagation.get_operations()])
Example #3
    def testMaybeSetStaticShape(self):
        shape = constant_op.constant([2, 5], dtype=dtypes.int32)

        def reshape():
            v = array_ops.zeros([10])
            return array_ops.reshape(v, shape)

        with self.disableSetStaticShape():
            graph_without_shape_propagation = func_graph.func_graph_from_py_func(
                "without_shape_propagation", reshape, [], {})
        graph_with_shape_propagation = func_graph.func_graph_from_py_func(
            "with_shape_propagation", reshape, [], {})
        self.assertCountEqual([
            op.type for op in graph_without_shape_propagation.get_operations()
        ], [op.type for op in graph_with_shape_propagation.get_operations()])
Example #4
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = ops.get_default_graph(
        )._add_control_dependencies
        pred = ops.convert_to_tensor(pred)
        if (tensor_util.is_tensor(pred)
                and (pred.shape.dims is None or pred.shape.dims)):
            pred = array_ops.squeeze_v2(pred)

        true_graph = func_graph_module.func_graph_from_py_func(
            true_name,
            true_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        false_graph = func_graph_module.func_graph_from_py_func(
            false_name,
            false_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)

        verify_captures(_COND, [true_graph, false_graph])
        return _build_cond(pred,
                           true_graph,
                           false_graph,
                           true_graph.external_captures,
                           false_graph.external_captures,
                           building_gradient=False,
                           name=scope)
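
A brief usage sketch, assuming the function above is the one in TensorFlow's internal cond_v2 module (the import path below is an assumption) and that it runs inside a v1-style graph:

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import cond_v2  # assumed internal module path

with ops.Graph().as_default():
    pred = tf.compat.v1.placeholder(tf.bool, shape=[])
    # Emits a single If op instead of the v1 Switch/Merge lowering.
    out = cond_v2.cond_v2(pred,
                          lambda: tf.constant(2.0) * 3.0,
                          lambda: tf.constant(-1.0))
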
Example #5
def _compile_function(func,
                      args,
                      scope,
                      control_outputs,
                      allow_external_captures=False):
  parent_graph = ops.get_default_graph()
  # Automatic control dependencies are added in defuns, but not in v1
  # graphs. Propagate that behavior here.
  add_control_dependencies = parent_graph._add_control_dependencies  # pylint: disable=protected-access

  # Functions inherit frontend attributes and the gradient override map from the
  # parent graph.
  proto = xla_data_pb2.FrontendAttributes()
  value = parent_graph._attr_scope_map.get(scopes.FRONTEND_ATTRIBUTES_NAME)  # pylint: disable=protected-access
  if value:
    proto.ParseFromString(value.s)
  attribute = attr_value_pb2.AttrValue(s=proto.SerializeToString())
  gradient_override_map = dict(parent_graph._gradient_override_map)  # pylint: disable=protected-access

  def func_wrapper(*args, **kwargs):
    # Add the frontend attributes to the current attributes.
    g = ops.get_default_graph()
    attributes = dict(g._attr_scope_map)  # pylint: disable=protected-access
    attributes[scopes.FRONTEND_ATTRIBUTES_NAME] = attribute

    with g._attr_scope(attributes):  # pylint: disable=protected-access
      with g.gradient_override_map(gradient_override_map):
        return func(*args, **kwargs)

  func_name = util.unique_fn_name(scope, "func")
  captured_args = ops.convert_n_to_tensor(args)

  # Compile the function to a graph.
  func_graph = func_graph_module.func_graph_from_py_func(
      func_name,
      func_wrapper,
      captured_args, {},
      add_control_dependencies=add_control_dependencies)

  # Add the external captures (resources) to arguments.
  for t in func_graph.external_captures:
    if not allow_external_captures and t.dtype != dtypes.resource:
      raise _InvalidCaptureException(t.name)
  captured_args += func_graph.external_captures

  # Add any control outputs.  Autograph will add control outputs to the graph
  # automatically, so only add ones which are not already present.
  for o in control_outputs:
    if o not in func_graph.control_outputs:
      func_graph.control_outputs.extend([o])

  # Fix shape inference for the gradients and extract_outside_compilation_pass.
  for op in func_graph.get_operations():
    output_shapes = [out.get_shape() for out in op.outputs]
    # pylint: disable=protected-access
    op._set_shape_list_attr("_output_shapes", output_shapes)
    op._set_shape_list_attr("_xla_inferred_shapes", output_shapes)
    # pylint: enable=protected-access

  return func_graph, captured_args
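
A speculative sketch of how this private IPU helper might be driven; the modules it relies on (util, scopes, xla_data_pb2, ops) are assumed to be importable exactly as in the snippet above:

from tensorflow.python.framework import ops

def body(a, b):
    # No external captures, so the default allow_external_captures=False is fine.
    return a + b

with ops.Graph().as_default():
    with ops.name_scope("example") as scope:
        fg, captured_args = _compile_function(body, [1.0, 2.0], scope,
                                              control_outputs=[])
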
Example #6
  def testMaybeSetStaticShape(self):
    shape = constant_op.constant([2, 5], dtype=dtypes.int32)

    def reshape():
      v = array_ops.zeros([10])
      return array_ops.reshape(v, shape)
    # This test needs a placeholder which means we need to construct a graph.
    with ops.Graph().as_default():
      with self.disableSetStaticShape():
        graph_without_shape_propagation = func_graph.func_graph_from_py_func(
            "without_shape_propagation", reshape, [], {})
      graph_with_shape_propagation = func_graph.func_graph_from_py_func(
          "with_shape_propagation", reshape, [], {})
      self.assertCountEqual(
          [op.type for op in graph_without_shape_propagation.get_operations()],
          [op.type for op in graph_with_shape_propagation.get_operations()])
Example #7
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Replace None gradients with zeros. This is needed because `grads` could have
  # None incoming gradients for the TensorLists. If we pass None's through, the
  # custom gradient of TensorListPopBack will create an EmptyTensorList inside
  # the FuncGraph which is undesirable.
  # TODO(b/80444525): There might be an issue with treating no gradient as zero
  # gradient in certain cases. Consider replacing None gradients with Zeros
  # for accumulators only.
  grads = [
      g if g is not None else array_ops.zeros_like(output)
      for g, output in zip(grads, op.outputs)
  ]

  body_grad_graph, args = _create_grad_func(
      body_graph, grads,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  assert len(loop_vars) == len(body_grad_graph.inputs)
  assert len(loop_vars) == len(body_grad_graph.outputs)
  assert len(loop_vars) == len(cond_grad_graph.inputs)

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  _maybe_set_lowering_attr(outputs[0].op)

  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  return outputs[2:2 + len(op.inputs)]
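
This function implements the gradient of the While op, so a sketch exercises it indirectly through tf.gradients over a while_v2 loop (the while_v2 import path is an assumption; its signature matches Example #24 later in this listing):

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import while_v2  # assumed internal module path

with ops.Graph().as_default():
    x = tf.constant(2.0)
    outputs = while_v2.while_loop(lambda i, v: i < 3,
                                  lambda i, v: (i + 1, v * v),
                                  [tf.constant(0), x])
    # Building the symbolic gradient dispatches to the registered While grad.
    dx = tf.gradients(outputs[1], x)
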
Example #8
def indexed_case(branch_index, branch_fns, name="indexed_case"):
  """Like conv_v2, except emits a Case op instead of an If."""
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)

  with ops.name_scope(name) as scope:
    branch_names = [
        util.unique_fn_name(scope, "branch{}".format(b))
        for b in range(len(branch_fns))
    ]

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")

    branch_graphs = []
    for branch_name, branch_fn in zip(branch_names, branch_fns):
      branch_graphs.append(
          func_graph_module.func_graph_from_py_func(
              branch_name,
              branch_fn,
              [],
              {},
              func_graph=util.CondBranchFuncGraph(
                  branch_name,
                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
              add_control_dependencies=add_control_dependencies,
              op_return_value=branch_index))

    verify_captures(_CASE, branch_graphs)
    return _build_case(
        branch_index,
        branch_graphs, [g.external_captures for g in branch_graphs],
        name=scope)
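
A short usage sketch, assuming indexed_case lives alongside cond_v2 in the same internal module (an assumption) and is called inside a graph:

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import cond_v2  # assumed location of indexed_case

with ops.Graph().as_default():
    branch_index = tf.compat.v1.placeholder(tf.int32, shape=[])
    branch_fns = [lambda: tf.constant(0.0),
                  lambda: tf.constant(10.0),
                  lambda: tf.constant(20.0)]
    # Emits a single Case op that selects one branch at runtime.
    out = cond_v2.indexed_case(branch_index, branch_fns)
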
Example #9
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = ops.get_default_graph(
        )._add_control_dependencies
        pred = ops.convert_to_tensor(pred)

        true_graph = func_graph_module.func_graph_from_py_func(
            true_name,
            true_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        false_graph = func_graph_module.func_graph_from_py_func(
            false_name,
            false_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)

        outputs = _build_cond(pred,
                              true_graph,
                              false_graph,
                              true_graph.external_captures,
                              false_graph.external_captures,
                              name=scope)

        return func_graph_module.pack_sequence_as(
            true_graph.structured_outputs, outputs)
Example #10
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the wrapped
      function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None,
          kwargs=None,
          signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)
Example #11
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations):
    """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
    assert len(ys) == len(grads)

    total_iters = while_op.outputs[0]
    counter = constant_op.constant(0,
                                   dtype=total_iters.dtype,
                                   name="grad_counter")

    args = [counter, maximum_iterations, total_iters] + list(grads)
    # Note: The returned function does not have `args` in the list of
    # `external_captures`.
    grad_func_graph = func_graph_module.func_graph_from_py_func(
        name,
        lambda *args: _grad_fn(ys, xs, args, body_graph),
        args, {},
        func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                           maximum_iterations, while_op))

    # Add the popped accumulators to the list of outputs.
    for internal_capture in grad_func_graph.internal_captures:
        if internal_capture in grad_func_graph.popped_tensor_lists:
            new_output = grad_func_graph.popped_tensor_lists[internal_capture]
        elif internal_capture.dtype == dtypes.resource:
            new_output = internal_capture
        else:
            raise ValueError(
                "Tensor %s is in list of internal_captures but is"
                " neither a resource nor is in popped_tensor_lists." %
                str(internal_capture))
        grad_func_graph.outputs.append(new_output)
        grad_func_graph.structured_outputs.append(new_output)

    return grad_func_graph, args
Example #12
    def wrap_function(self, fn, signature, name=None):
        """Wrap a TF 1.X function and save to functions dictionary."""
        func_graph.func_graph_from_py_func(
            None,  # Name is unused.
            self._variable_holder.call_with_variable_creator_scope(fn),
            args=None,
            kwargs=None,
            signature=signature,
            add_control_dependencies=False,
            func_graph=self.graph)

        # This code relies on questionable behavior from `func_graph_from_py_func`.
        # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
        # and structured outputs are overwritten. This is almost certainly a bug,
        # because the structured outputs don't match up with the outputs...
        fn_inputs = self.graph.inputs[:-len(self.graph.captures)]
        fn_outputs = self.graph.structured_outputs

        wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs)
        name = name or fn.__name__
        self._functions[name] = wrapped_function
        return wrapped_function
Example #13
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    outputs = _build_cond(pred, true_graph, false_graph,
                          true_graph.external_captures,
                          false_graph.external_captures,
                          name=scope)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              outputs)
Example #14
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Args:
    fn: python function to be wrapped
    signature: the placeholder and python arguments to be passed to the
      wrapped function
    name: Optional. The name of the function.

  Returns:
    the wrapped graph function.
  """
  holder = VariableHolder(fn)
  fn = function.Function(
      func_graph.func_graph_from_py_func(
          name,
          holder,
          args=None, kwargs=None, signature=signature,
          add_control_dependencies=False),
      signature=signature)
  fn._variable_holder = holder
  return fn
Example #15
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      max_iters):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    max_iters: the maximum number of iterations, or None if no limit.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  args = [counter, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         max_iters))

  # Add the popped accumulators to the list of outputs.
  for internal_capture in grad_func_graph.internal_captures:
    if internal_capture in grad_func_graph.popped_tensor_lists:
      new_output = grad_func_graph.popped_tensor_lists[internal_capture]
    elif internal_capture.dtype == dtypes.resource:
      new_output = internal_capture
    else:
      raise ValueError("Tensor %s is in list of internal_captures but is"
                       " neither a resource nor is in popped_tensor_lists." %
                       str(internal_capture))
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args
Example #16
    def _wrap_function(self,
                       fn,
                       args=None,
                       kwargs=None,
                       signature=None,
                       name=None):
        """Internal wrap function method with extended func_graph arguments."""
        fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
            self._variable_holder.call_with_variable_creator_scope(fn))

        func_graph.func_graph_from_py_func(
            None,  # Name is unused.
            fn_with_filter_and_scope,
            args=args,
            kwargs=kwargs,
            signature=signature,
            add_control_dependencies=False,
            func_graph=self.graph)

        # This code relies on questionable behavior from `func_graph_from_py_func`.
        # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
        # and structured outputs are overwritten. This is almost certainly a bug,
        # because the structured outputs don't match up with the outputs...
        fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

        # Return filtered ops to the flattened outputs.
        flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
        for index, op in returned_ops.items():
            flat_fn_outputs[index] = op
        fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                           flat_fn_outputs)

        name = name or fn.__name__
        wrapped_function = self._wrapped_function.prune(
            fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
        self._functions[name] = wrapped_function
        return wrapped_function
Example #17
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    verify_captures(true_graph, false_graph)
    return _build_cond(pred, true_graph, false_graph,
                       true_graph.external_captures,
                       false_graph.external_captures,
                       name=scope)
Example #18
  def _wrap_function(self,
                     fn,
                     args=None,
                     kwargs=None,
                     signature=None,
                     name=None):
    """Internal wrap function method with extended func_graph arguments."""
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=args,
        kwargs=kwargs,
        signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. This is almost certainly a bug,
    # because the structured outputs don't match up with the outputs...
    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(
        fn_inputs, fn_outputs, name, self.graph.structured_input_signature)
    self._functions[name] = wrapped_function
    return wrapped_function
Example #19
def _create_grad_func(ys, xs, grads, func_graph, name, while_op):
    """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    func_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
    assert len(ys) == len(grads)

    counter = constant_op.constant(0.)
    # TODO(srbs): For nested while loops will need to lookup this value from
    # the accumulator of the enclosing while loop. For now use as is assuming
    # there is no nesting.
    total_iters = while_op.outputs[0]

    args = [counter, total_iters] + list(grads)
    # Note: The returned function does not have `args` in the list of
    # `external_captures`.
    grad_func_graph = func_graph_module.func_graph_from_py_func(
        name,
        lambda *args: _grad_fn(ys, xs, args, func_graph),
        args, {},
        func_graph=_WhileBodyGradFuncGraph(name, func_graph, while_op))

    # Add the popped accumulators to the list of outputs.
    for internal_capture in grad_func_graph.internal_captures:
        if internal_capture in grad_func_graph.popped_tensor_lists:
            grad_func_graph.outputs.append(
                grad_func_graph.popped_tensor_lists[internal_capture])
        elif internal_capture.dtype == dtypes.resource:
            grad_func_graph.outputs.append(internal_capture)
        else:
            raise ValueError(
                "Tensor %s is in list of internal_captures but is"
                " neither a resource nor is in popped_tensor_lists." %
                str(internal_capture))

    return grad_func_graph, args
Example #20
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
    """Create a `ConcreteFunction` from `args` and `kwargs`."""

    self.tracing_count += 1
    if self.input_signature is None:
        arglen = len(args)
    else:
        arglen = len(self.input_signature)
    base_arg_names = self._function_spec.arg_names[:arglen]
    num_missing_args = arglen - len(self._function_spec.arg_names)
    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
    # where arg is based on the self._function_spec.vararg_name.
    missing_arg_names = [
        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
    ]
    arg_names = base_arg_names + missing_arg_names

    graph_function = _function.ConcreteFunction(
        func_graph_module.func_graph_from_py_func(
            self._name,
            self._python_function,
            args,
            kwargs,
            self.input_signature,
            autograph=self._autograph,
            autograph_options=self._autograph_options,
            arg_names=arg_names,
            override_flat_arg_shapes=override_flat_arg_shapes,
            capture_by_value=self._capture_by_value,
            add_control_dependencies=False,
        ),
        self._function_attributes,
        # Tell the ConcreteFunction to clean up its graph once it goes out of
        # scope. This is not the default behavior since it gets used in some
        # places (like Keras) where the FuncGraph lives longer than the
        # ConcreteFunction.
        shared_func_graph=False,
    )
    return graph_function
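
This method is normally reached through tracing rather than called by hand; a minimal sketch of the public path that ends up building a ConcreteFunction this way (assuming eager TF 2.x):

import tensorflow as tf

@tf.function
def add(a, b):
    return a + b

# get_concrete_function forces a trace: internally a FuncGraph is built with
# func_graph_from_py_func and wrapped in a ConcreteFunction, as above.
concrete = add.get_concrete_function(
    tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32))
print(concrete(tf.constant(1.0), tf.constant(2.0)))
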
Example #21
def _create_grad_func(func_graph, grads, name, while_op):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    func_graph: FuncGraph for the forward body function.
    grads: The incoming grads for `func_graph`'s outputs.
    name: Name of the returned gradient function.
    while_op: The forward While op.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(func_graph.outputs) == len(grads)

  loop_counter = constant_op.constant(0.)
  # TODO(srbs): For nested while loops will need to lookup this value from
  # the accumulator of the enclosing while loop. For now use as is assuming
  # there is no nesting.
  num_iters_t = while_op.outputs[0]

  args = [loop_counter, num_iters_t] + grads

  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(func_graph, args),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, func_graph))

  # Add the popped accumulators to the list of outputs.
  for internal_capture in grad_func_graph.internal_captures:
    grad_func_graph.outputs.append(
        grad_func_graph.popped_tensor_lists[internal_capture])

  return grad_func_graph, args
Example #22
    def wrap_input_receiver_fn(self, input_receiver_fn):
        """Converts an input receiver function to one or more concrete functions.

    Input receiver functions are python functions with no arguments.
    Placeholders are created within the function and used to receive inputs to
    the model.

    The function (or multiple functions) generated depends on the InputReceiver
    object returned by `input_receiver_fn`.

    Generally, the returned function will have inputs and outputs:
      input_receiver(**receiver_tensors) --> features

    or (if the InputReceiver returns labels):
      input_receiver(**receiver_tensors) --> features, labels

    __Alternate Receiver Tensors__

    The InputReceiver may have alternate receiver tensors, in which case
    additional concrete functions are generated. Example:
      InputReceiver.receiver_tensors_alternatives = {
        'alt_input_1': Tensor,
        'alt_input_2': {
          'tensor_1': Tensor,
          'tensor_2': Tensor
        }
      }

    This will generate concrete functions:
      input_receiver_alt_input_1(input) --> features
      input_receiver_alt_input_2(tensor_1, tensor_2) --> features

    Args:
      input_receiver_fn: a no-argument function that returns an `InputReceiver`
        object.

    Returns:
      A list of tuples of (concrete function, receiver name). The name of the
      default input receiver is `None`.
    """
        ret = [None]

        def fn():
            ret[0] = input_receiver = input_receiver_fn()
            features = input_receiver.features
            labels = getattr(input_receiver, 'labels', None)

            if labels is None:
                return features
            return features, labels

        func_graph.func_graph_from_py_func(
            None,  # Name is unused.
            self._variable_holder.call_with_variable_creator_scope(fn),
            args=None,
            kwargs=None,
            signature=[],
            add_control_dependencies=False,
            func_graph=self.graph)

        functions = []
        input_receiver = ret[0]

        wrapped_input_receiver_fn = _prune_receiver_tensors(
            self._wrapped_function,
            receiver_tensors=input_receiver.receiver_tensors,
            outputs=self.graph.structured_outputs,
            name=_input_receiver_fn_name(None))
        functions.append((wrapped_input_receiver_fn, None))

        receiver_tensors_alternatives = getattr(
            input_receiver, 'receiver_tensors_alternatives', None)

        if receiver_tensors_alternatives:
            for receiver_name, receiver_tensors_alt in (
                    six.iteritems(receiver_tensors_alternatives)):
                receiver_tensors_alt = _canonicalize_receiver_tensors(
                    receiver_tensors_alt)
                wrapped_input_receiver_fn = _prune_receiver_tensors(
                    self._wrapped_function,
                    receiver_tensors=receiver_tensors_alt,
                    outputs=self.graph.structured_outputs,
                    name=_input_receiver_fn_name(receiver_name))
                functions.append((wrapped_input_receiver_fn, receiver_name))
        return functions
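
A hedged sketch of an input receiver function this method could wrap; `wrapper` stands for a hypothetical instance of the enclosing class, and tf.estimator.export.ServingInputReceiver is assumed to satisfy the InputReceiver contract described in the docstring:

import tensorflow as tf

def serving_input_receiver_fn():
    # No arguments: placeholders are created inside the traced function, as the
    # docstring above requires.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 4], name="x")
    return tf.estimator.export.ServingInputReceiver({"x": x}, {"x": x})

# Hypothetical enclosing-object usage:
# functions = wrapper.wrap_input_receiver_fn(serving_input_receiver_fn)
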
Example #23
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  parallel_iterations = op.get_attr("parallel_iterations")
  assert not _is_in_xla_context() or maximum_iterations is not None
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

  grads = [_preprocess_grad(grad, body_out, while_out)
           for grad, body_out, while_out
           in zip(grads, body_graph.outputs, while_op.outputs)]

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  _maybe_set_maximum_iterations_attr(grad_op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
Example #24
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
    """Like tf.while_loop, except emits a single While op."""
    maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
    # Keep the original loop_vars around to know which args were TensorArrays.
    orig_loop_vars = loop_vars
    # Cache its length since we use it at multiple places below.
    len_orig_loop_vars = len(orig_loop_vars)

    # Convert TensorArrays to their flow variables. These get converted back to
    # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
    # `wrapped_body` below.
    loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
    loop_vars = nest.map_structure(
        ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
    if shape_invariants is not None:
        nest.assert_same_structure(orig_loop_vars, shape_invariants)
    else:
        shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

    if not name:
        name = "while"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            cond_name = util.unique_fn_name(scope, "cond")
            body_name = util.unique_fn_name(scope, "body")

        loop_counter = constant_op.constant(
            0,
            dtype=maximum_iterations.dtype
            if maximum_iterations is not None else None,
            name="loop_counter")
        # Add loop counter needed for computing gradients.
        loop_vars = [loop_counter] + loop_vars

        shape_invariants = type(shape_invariants)(
            [tensor_shape.scalar()]) + shape_invariants

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = ops.get_default_graph(
        )._add_control_dependencies

        # Build a `cond` wrapper that can handle the extra counter loop_var.
        def wrapped_cond(loop_counter, *args):
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            if maximum_iterations is None:
                return cond(*_pack_sequence_as(orig_loop_vars, args))
            else:
                return math_ops.logical_and(
                    loop_counter < maximum_iterations,
                    cond(*_pack_sequence_as(orig_loop_vars, args)))

        cond_graph = func_graph_module.func_graph_from_py_func(
            cond_name,
            wrapped_cond,
            loop_vars, {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileCondFuncGraph(cond_name),
            add_control_dependencies=add_control_dependencies)

        # Add external_captures of cond to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + cond_graph.external_captures
        shape_invariants = shape_invariants + type(shape_invariants)(
            [t.shape for t in cond_graph.external_captures])

        def wrapped_body(loop_counter, *args):
            """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            outputs = body(
                *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
            if not nest.is_sequence(outputs):
                outputs = [outputs]
            # Compare the structure of input and output of body converting the
            # top-level tuples to list to be compatible with legacy while_loop.
            nest.assert_same_structure(list(outputs), list(orig_loop_vars))

            outputs = _tensor_array_to_flow(outputs)

            # Return the external_captures of cond_graph as is, i.e., treat them as
            # loop invariants.
            # TODO(srbs): Update lowering code to create _Enter nodes with
            # is_constant=True for inputs that are directly passed to outputs.
            return [loop_counter + 1] + list(outputs) + list(
                args[len_orig_loop_vars:])

        body_graph = func_graph_module.func_graph_from_py_func(
            body_name,
            wrapped_body,
            loop_vars, {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileBodyFuncGraph(body_name),
            add_control_dependencies=add_control_dependencies)
        # Add external captures of body to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + body_graph.external_captures
        # TODO(srbs): Update lowering code to create _Enter nodes with
        # is_constant=True for inputs that are directly passed to outputs.
        body_graph.outputs.extend(body_graph.internal_captures)

        # Capture `external_captures` of `body_graph` in `cond_graph` so that it
        # expects to receive those as arguments.
        # TODO(b/118457764): Dedup tensors that are captured in both the cond and
        # body. This logic already exists in cond_v2.
        with cond_graph.as_default():
            for external_capture in body_graph.external_captures:
                assert external_capture not in cond_graph.captures, (
                    "Looks like both cond and body are capturing the same tensor %s. "
                    "This is not supported yet. For now consider passing,"
                    " this as a loop variable." % str(external_capture))
                cond_graph.capture(external_capture)

        # Make sure that the shapes of the loop outputs are compatible with the
        # shape invariants, or the shapes of the loop vars if the invariants are not
        # specified.
        num_flattened_outputs = len(nest.flatten(orig_loop_vars))
        _check_shapes_compat(
            body_graph.outputs[1:1 + num_flattened_outputs],
            nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
            nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
        flattened_loop_vars = nest.flatten(loop_vars)
        _check_num_inputs_outputs(cond_graph, body_graph,
                                  len(flattened_loop_vars))

        outputs = gen_functional_ops._while(
            flattened_loop_vars,
            util.create_new_tf_function(cond_graph),
            util.create_new_tf_function(body_graph),
            output_shapes=[t.shape for t in body_graph.outputs],
            name=scope)

        _copy_handle_data(body_graph.outputs, outputs)
        util.maybe_set_lowering_attr(outputs[0].op)
        _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

        # Return identities for each output of the While op, rather than the output
        # of the While op directly. This makes pruning work if the output of
        # while_loop() is fetched: the lowering pass converts the While outputs into
        # IdentityN outputs, which if fetched will cause all ops in the body to be
        # run (since it takes all exit ops as input). After lowering, each output
        # identity op will end up with only the appropriate exit op as input.
        outputs = tuple(array_ops.identity(t) for t in outputs)

    # First var is loop counter.
    outputs = _pack_sequence_as(orig_loop_vars,
                                outputs[1:1 + num_flattened_outputs])

    if return_same_structure:
        return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
        return flattened_outputs[0]
    else:
        return outputs
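
A short forward-pass sketch for this single-While-op loop, again assuming the internal while_v2 module path; with return_same_structure=True the outputs unpack like the loop variables:

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import while_v2  # assumed internal module path

with ops.Graph().as_default():
    # Doubles the accumulator five times under an explicit iteration cap.
    i, acc = while_v2.while_loop(lambda i, acc: i < 5,
                                 lambda i, acc: (i + 1, acc * 2.0),
                                 [tf.constant(0), tf.constant(1.0)],
                                 maximum_iterations=10)
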
Example #25
  def wrap_function(self, fn, signature, name=None):
    """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.get_default_graph` to get the graph object within a
    function, or `WrappedGraph.graph` to get the graph outside a function).
    Variables created within the function will be added to the `variables` list.

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

      * The 1.X function may return tensors, variables, and ops. The wrapped
        eager-compatible function will always return tensors in the same nested
        structure.
      * Variables are replaced with a tensor containing the latest read values.
      * Returned ops are executed, and replaced with None.
      * The order of op execution and variable reads in the return is
        nondeterministic. For example:

        ```
        def update_var(x):
          v = tf.Variable(0)
          op = tf.compat.v1.assign(v, x).op
          return v, op

        g = WrappedGraph()
        fn = g.wrap_function(update_var)
        read_value, _ = fn(tf.constant(3))
        print(read_value.numpy())  # could be 0 or 3
        print(g.variables[0].numpy()) # always 3
        ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

    Args:
      fn: a 1.X tensorflow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be saved
        with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
    fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
        self._variable_holder.call_with_variable_creator_scope(fn))

    func_graph.func_graph_from_py_func(
        None,  # Name is unused.
        fn_with_filter_and_scope,
        args=None, kwargs=None, signature=signature,
        add_control_dependencies=False,
        func_graph=self.graph)

    # This code relies on questionable behavior from `func_graph_from_py_func`.
    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
    # and structured outputs are overwritten. This is almost certainly a bug,
    # because the structured outputs don't match up with the outputs...
    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

    # Return filtered ops to the flattened outputs.
    flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
    for index, op in returned_ops.items():
      flat_fn_outputs[index] = op
    fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                       flat_fn_outputs)

    name = name or fn.__name__
    wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs, name)
    self._functions[name] = wrapped_function
    return wrapped_function
Example #26
def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn."""
  return func_graph_module.func_graph_from_py_func(
      name,
      lambda: _grad_fn(func_graph, grads), [], {},
      func_graph=_CondGradFuncGraph(name, func_graph))
Example #27
def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Build frozen sets so that we do not have linear time lookups in
  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
  # may get updated during gradient computation because we add accumulators to
  # the forward op. However, those are not loop invariants so wouldn't affect
  # the output of `_is_loop_invariant`. Also we would never attempt to capture
  # those accumulators so `_is_loop_invariant` should never receive those new
  # tensors as args.
  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)

  args = [counter, maximum_iterations, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         maximum_iterations, while_op,
                                         body_graph_inputs, body_graph_outputs))

  # Update the list of outputs with tensors corresponding to the captured
  # tensors. We capture 3 types of tensors when building the grad fn:
  # 1. Accumulators for forward graph intermediates which are not loop
  #    invariants. The outputs corresponding to these are populated in
  #    `popped_tensor_lists` by `_WhileBodyGradFuncGraph`.
  # 2. Resources, which are output as is.
  # 3. Forward graph loop invariants, which are output as is.
  for external_capture, internal_capture in grad_func_graph.captures:
    if ops.tensor_id(internal_capture) in grad_func_graph.popped_tensor_lists:
      new_output = grad_func_graph.popped_tensor_lists[ops.tensor_id(
          internal_capture)]
    elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant(
        external_capture, body_graph_inputs, body_graph_outputs)):
      new_output = internal_capture
    else:
      raise ValueError("Tensor %s which captures %s is in list of "
                       "internal_captures but is not a resource, is not in "
                       "popped_tensor_lists and does not capture a loop "
                       "invariant." %
                       (str(internal_capture), str(external_capture)))
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args
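
`_is_loop_invariant` itself is not shown in this example; given that it is handed the identity sets built from `body_graph.inputs` and `body_graph.outputs`, a plausible sketch (an assumption, not the actual implementation) is a membership test on both:

def _is_loop_invariant(tensor, body_graph_inputs, body_graph_outputs):
  # Hypothetical sketch: a tensor is treated as a loop invariant when the
  # forward body passes it straight through, i.e. it appears both as an input
  # and as an output of body_graph. The ObjectIdentitySets built above make
  # these membership checks O(1) and identity-based rather than __eq__-based.
  return tensor in body_graph_inputs and tensor in body_graph_outputs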
Exemplo n.º 28
0
def _create_grad_func(func_graph, grads, name):
    """Returns the FuncGraph representation of _grad_fn."""
    return func_graph_module.func_graph_from_py_func(
        name,
        lambda: _grad_fn(func_graph, grads), [], {},
        func_graph=_CondGradFuncGraph(name, func_graph))
Exemplo n.º 29
0
    def wrap_function(self, fn, signature, name=None):
        """Wraps a TF 1.X function and returns an eager-compatible function.

    All functions wrapped in the same `WrappedGraph` will have access to the
    same graph (`tf.get_default_graph` to get the graph object within a
    function, or `WrappedGraph.graph` to get the graph outside a function).
    Variables created within the function will be added to the `variables` list.

    Function inputs: All inputs to the function must be tensors (nested ok),
    with their shapes and dtypes defined in the `signature` argument.

    Function outputs:

      * The 1.X function may return tensors, variables, and ops. The wrapped
        eager-compatible function will always return tensors in the same nested
        structure.
      * Variables are replaced with a tensor containing the latest read values.
      * Returned ops are executed, and replaced with None.
      * The order of op execution and variable reads in the return is
        nondeterministic. For example:

        ```
        def update_var(x):
          v = tf.Variable(0)
          op = tf.compat.v1.assign(v, x).op
          return v, op

        g = WrappedGraph()
        fn = g.wrap_function(update_var)
        read_value, _ = fn(tf.constant(3))
        print(read_value.numpy())  # could be 0 or 3
        print(g.variables[0].numpy()) # always 3
        ```

    To ensure that ops in the function are executed (e.g. ops added to the
    `tf.GraphKeys.UPDATE_OPS` collection), include them in the function returns.

    Args:
      fn: a 1.X tensorflow function.
      signature: a possibly nested sequence of `TensorSpecs` specifying the
        shapes and dtypes of the arguments.
      name: an optional string name for the function. The function will be saved
        with key `name` in the `functions` dictionary.

    Returns:
      An eager-compatible function.
    """
        fn_with_filter_and_scope, returned_ops = _filter_returned_ops(
            self._variable_holder.call_with_variable_creator_scope(fn))

        func_graph.func_graph_from_py_func(
            None,  # Name is unused.
            fn_with_filter_and_scope,
            args=None,
            kwargs=None,
            signature=signature,
            add_control_dependencies=False,
            func_graph=self.graph)

        # This code relies on questionable behavior from `func_graph_from_py_func`.
        # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
        # and structured outputs are overwritten. This is likely a bug, because the
        # structured outputs no longer match up with the graph's outputs.
        fn_inputs = self.graph.inputs[:-len(self.graph.captures)]

        # Return filtered ops to the flattened outputs.
        flat_fn_outputs = nest.flatten(self.graph.structured_outputs)
        for index, op in returned_ops.items():
            flat_fn_outputs[index] = op
        fn_outputs = nest.pack_sequence_as(self.graph.structured_outputs,
                                           flat_fn_outputs)

        name = name or fn.__name__
        wrapped_function = self._wrapped_function.prune(
            fn_inputs, fn_outputs, name)
        self._functions[name] = wrapped_function
        return wrapped_function
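
The flatten/splice/re-pack step above is the generic `tf.nest` idiom; a standalone sketch of that pattern (the dict, the filtered index, and the `returned_ops` contents here are illustrative):

import tensorflow as tf

structured = {"loss": tf.constant(1.0), "update": tf.constant(0)}  # stand-in outputs
flat = tf.nest.flatten(structured)      # deterministic flattening order (sorted keys)
returned_ops = {1: None}                # pretend flat index 1 held a filtered-out op
for index, op in returned_ops.items():
  flat[index] = op
repacked = tf.nest.pack_sequence_as(structured, flat)
print(repacked)                         # {'loss': <tf.Tensor ...>, 'update': None}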
Exemplo n.º 30
0
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handles to None. The gradient
  # implementation currently assumes all resource tensors correspond to float32
  # ResourceVariables, which can lead to runtime shape errors when used with a
  # TensorArray. This is a workaround until TensorArrays are reimplemented with
  # TensorLists instead of resources.
  # Also set the incoming gradient of non-trainable inputs to None. It is
  # possible that we receive non-None gradients for non-trainable types in
  # nested while loops because we accumulate outputs of the inner while as
  # variant tensors which are trainable and hence receive zeros_like tensors in
  # the gradient pass. The non-trainable tensors then receive the popped zeros
  # tensor from this zeros variant. The gradient for the loop vars corresponding
  # to these tensors is None or zeros (this happens only if the loop var is
  # accumulated as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
  # output grads in _grad_fn.
  grads = [
      None if _is_tensor_array_handle(output) or
      not gradients_impl.IsTrainable(output) else grad
      for grad, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
            ), "All trainable loop vars must receive incoming gradients."
  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  assert not _is_in_xla_context() or maximum_iterations is not None
  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=intermediate_tensor.shape,
        max_num_elements=maximum_iterations)

    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(outputs[0].op)
  _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]

  # Set None as the output gradient for tensors with None input gradient
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
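
The trailing loop above re-pads `None` gradients into their original slots; the same idiom in isolation (the helper name and the example values are illustrative):

def pad_none_gradients(incoming_grads, while_outputs, first_output_index=2):
  # Positions whose incoming gradient was None get None back; every other
  # position consumes the next While output in order. `first_output_index`
  # skips the loop counter and total-iterations outputs, as in the code above.
  index = first_output_index
  padded = []
  for g in incoming_grads:
    if g is None:
      padded.append(None)
    else:
      padded.append(while_outputs[index])
      index += 1
  return padded

# pad_none_gradients([None, "g1", None, "g2"], ["cnt", "iters", "o1", "o2"])
# -> [None, "o1", None, "o2"]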
Exemplo n.º 31
0
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op."""
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    shape_invariants = type(shape_invariants)(
        [tensor_shape.scalar()]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + cond_graph.external_captures
    shape_invariants = shape_invariants + type(shape_invariants)(
        [t.shape for t in cond_graph.external_captures])

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(
          *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[len_orig_loop_vars:])

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(b/118457764): Dedup tensors that are captured in both the cond and
    # body. This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        assert external_capture not in cond_graph.captures, (
            "Looks like both cond and body are capturing the same tensor %s. "
            "This is not supported yet. For now consider passing,"
            " this as a loop variable." % str(external_capture))
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=intermediate_tensor.shape,
          max_num_elements=maximum_iterations)
      loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list,
            intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  outputs = _pack_sequence_as(orig_loop_vars,
                              outputs[1:1 + num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
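
A hedged usage sketch for the `while_loop` above (assuming this snapshot is importable as `while_v2`; module paths and behavior may differ by TF version):

import tensorflow as tf
from tensorflow.python.ops import while_v2  # assumption: module containing the code above

# Sum the integers 0..9. The loop counter added internally is stripped from the
# result, which mirrors the structure of `loop_vars`.
final_i, total = while_v2.while_loop(
    cond=lambda i, acc: i < 10,
    body=lambda i, acc: (i + 1, acc + i),
    loop_vars=[tf.constant(0), tf.constant(0)])
# total == 45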
Exemplo n.º 32
0
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # Set the flag to enable lowering on the `if` op if necessary
    # Lowering allows cond_v2 to avoid some of the limitations of Functions,
    # allowing users to specify devices & colocation inside of cond_v2 branches,
    # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
    # This brings cond_v2 closer to feature parity with tf.cond.
    #
    # However, we do not lower `If` in the XLA context because it is easier for
    # XLA to apply its own optimizations when dealing with un-lowered `If`
    # operators than with lowered switch/merge control flow.
    #
    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if not control_flow_util.IsInXLAContext(if_op):
      # pylint: disable=protected-access
      if_op._set_attr("_lower_using_switch_merge",
                      attr_value_pb2.AttrValue(b=True))
      # pylint: enable=protected-access

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    result = tuple(tensors[:num_cond_outputs])
    if len(result) == 1:
      return result[0]
    else:
      return result
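
A hedged usage sketch for `cond_v2` above (again assuming the module is importable as `cond_v2`); both branch functions are traced into FuncGraphs and a single If op is emitted:

import tensorflow as tf
from tensorflow.python.ops import cond_v2  # assumption: module containing the code above

x = tf.constant(3.0)
y = tf.constant(4.0)
pred = tf.constant(True)           # must be a tensor, not a Python bool (see check above)

result = cond_v2.cond_v2(pred,
                         true_fn=lambda: x * y,
                         false_fn=lambda: x + y)
# result == 12.0 when pred is True, 7.0 otherwise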
Exemplo n.º 33
0
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
    """Like tf.while_loop, except emits a single While op."""
    # Keep the original loop_vars around to know which args were TensorArrays.
    orig_loop_vars = loop_vars
    # Cache its length since we use it at multiple places below.
    len_orig_loop_vars = len(orig_loop_vars)

    # Convert TensorArrays to their flow variables. These get converted back to
    # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
    # `wrapped_body` below.
    loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
    loop_vars = nest.map_structure(
        ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
    if shape_invariants is not None:
        nest.assert_same_structure(orig_loop_vars, shape_invariants)
    else:
        shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

    if not name:
        name = "while"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            cond_name = util.unique_fn_name(scope, "cond")
            body_name = util.unique_fn_name(scope, "body")
        maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
            maximum_iterations)
        loop_counter = constant_op.constant(
            0,
            dtype=maximum_iterations_loop_var.dtype
            if maximum_iterations is not None else None,
            name="loop_counter")
        # Add loop counter needed for computing gradients.
        loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

        shape_invariants = type(shape_invariants)(
            [tensor_shape.scalar(),
             tensor_shape.scalar()]) + shape_invariants

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = (
            ops.get_default_graph()._add_control_dependencies)

        # Build a `cond` wrapper that can handle the extra counter loop_var.
        def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            if maximum_iterations is None:
                return cond(*_pack_sequence_as(orig_loop_vars, args))
            else:
                return math_ops.logical_and(
                    loop_counter < maximum_iterations_arg,
                    cond(*_pack_sequence_as(orig_loop_vars, args)))

        # NOTE(skyewm): we set collections to the outer graph's collections for
        # compatibility with TPUEstimator.
        cond_graph = func_graph_module.func_graph_from_py_func(
            cond_name,
            wrapped_cond,
            [],  # We provide signature instead of args.
            {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileCondFuncGraph(
                cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies)

        def wrapped_body(loop_counter, maximum_iterations_arg, *args):
            """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
            # Capture the tensors already captured in cond_graph so that they appear
            # in the same order in body_graph.external_captures.
            for t in cond_graph.external_captures:
                ops.get_default_graph().capture(t)

            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            outputs = body(*_pack_sequence_as(orig_loop_vars, args))
            if not nest.is_sequence(outputs):
                outputs = [outputs]
            # Compare the structure of input and output of body converting the
            # top-level tuples to list to be compatible with legacy while_loop.
            nest.assert_same_structure(list(outputs), list(orig_loop_vars))

            outputs = _tensor_array_to_flow(outputs)

            # TODO(srbs): Update lowering code to create _Enter nodes with
            # is_constant=True for inputs that are directly passed to outputs.
            return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

        body_graph = func_graph_module.func_graph_from_py_func(
            body_name,
            wrapped_body,
            [],  # We provide signature instead of args.
            {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileBodyFuncGraph(
                body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies)
        # Add external captures of body to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + body_graph.external_captures
        # TODO(srbs): Update lowering code to create _Enter nodes with
        # is_constant=True for inputs that are directly passed to outputs.
        body_graph.outputs.extend(body_graph.internal_captures)

        # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
        # that it expects to receive those as arguments.
        with cond_graph.as_default():
            num_cond_captures = len(cond_graph.external_captures)
            assert (cond_graph.external_captures ==
                    body_graph.external_captures[:num_cond_captures])
            for body_capture in body_graph.external_captures[
                    num_cond_captures:]:
                assert body_capture not in cond_graph.captures
                cond_graph.capture(body_capture)

        # Make sure that the shapes of the loop outputs are compatible with the
        # shape invariants, or the shapes of the loop vars if the invariants are not
        # specified.
        num_flattened_outputs = len(nest.flatten(orig_loop_vars))
        # First var is loop counter and second var is maximum_iterations.
        first_loop_var_index = 2
        _check_shapes_compat(
            body_graph.outputs[first_loop_var_index:first_loop_var_index +
                               num_flattened_outputs],
            nest.flatten(
                shape_invariants[first_loop_var_index:first_loop_var_index +
                                 len_orig_loop_vars]),
            nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                                   len_orig_loop_vars]))
        flattened_loop_vars = nest.flatten(loop_vars)
        _check_num_inputs_outputs(cond_graph, body_graph,
                                  len(flattened_loop_vars))

        with ops.control_dependencies(
                list(cond_graph.control_captures) +
                list(body_graph.control_captures)):
            outputs = gen_functional_ops._while(
                flattened_loop_vars,
                util.create_new_tf_function(cond_graph),
                util.create_new_tf_function(body_graph),
                output_shapes=[t.shape for t in body_graph.outputs],
                parallel_iterations=parallel_iterations,
                name=scope)

        _copy_handle_data(body_graph.outputs, outputs)
        util.maybe_set_lowering_attr(outputs[0].op)
        util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

        # Return identities for each output of the While op, rather than the output
        # of the While op directly. This makes pruning work if the output of
        # while_loop() is fetched: the lowering pass converts the While outputs into
        # IdentityN outputs, which if fetched will cause all ops in the body to be
        # run (since it takes all exit ops as input). After lowering, each output
        # identity op will end up with only the appropriate exit op as input.
        outputs = tuple(array_ops.identity(t) for t in outputs)

    outputs = _pack_sequence_as(
        orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                                num_flattened_outputs])

    if return_same_structure:
        return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
        return flattened_outputs[0]
    else:
        return outputs
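
A hedged sketch of the `maximum_iterations` behavior implemented by `wrapped_cond` above: the loop stops after 5 iterations even though the user cond is still true (the import path is an assumption, as before):

import tensorflow as tf
from tensorflow.python.ops import while_v2  # assumption: module containing the code above

result = while_v2.while_loop(
    cond=lambda i: i < 100,
    body=lambda i: (i + 1,),
    loop_vars=[tf.constant(0)],
    maximum_iterations=5)
# result[0] == 5: the internal counter/maximum_iterations loop vars are stripped
# from the returned structure.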
Exemplo n.º 34
0
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op."""
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    shape_invariants = type(shape_invariants)(
        [tensor_shape.scalar()]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph.captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        parallel_iterations=parallel_iterations,
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  outputs = _pack_sequence_as(orig_loop_vars,
                              outputs[1:1 + num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
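
A hedged sketch of why `shape_invariants` exists: when a loop var's shape changes across iterations, the invariant must relax that dimension, otherwise the `_check_shapes_compat` call above is expected to fail (the import path is an assumption):

import tensorflow as tf
from tensorflow.python.ops import while_v2  # assumption: module containing the code above

x0 = tf.constant([[1.0]])                   # shape [1, 1]

# The second dimension doubles each iteration, so it is left unspecified.
result = while_v2.while_loop(
    cond=lambda x: tf.shape(x)[1] < 4,
    body=lambda x: (tf.concat([x, x], axis=1),),
    loop_vars=[x0],
    shape_invariants=[tf.TensorShape([1, None])])
# result[0] has shape [1, 4]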
Exemplo n.º 35
0
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except:  # pylint: disable=bare-except
    num_original_outputs = len(while_op.outputs)

  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  # Do not ignore grads wrt extra outputs when computing higher order
  # derivatives.
  while_op._set_attr("_num_original_outputs",
                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
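
An end-to-end hedged sketch of what the gradient function above enables from the user side: differentiating through a v2 while loop builds a second While op that runs the body gradient once per forward iteration (public TF 2.x API shown; that the tape dispatches to this `_WhileGrad` is an assumption about how the snapshot is wired up):

import tensorflow as tf

x = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  # v2 while loop computing y = x ** 3 via repeated multiplication.
  _, y = tf.while_loop(lambda i, v: i < 3,
                       lambda i, v: (i + 1, v * x),
                       [tf.constant(0), tf.constant(1.0)])
dy_dx = tape.gradient(y, x)   # 3 * x ** 2 == 12.0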
Exemplo n.º 36
0
def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
  """Like tf.while_loop, except emits a single While op."""
  flattened_loop_vars = nest.flatten(loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(loop_vars, shape_invariants)
    flattened_shapes = nest.flatten(shape_invariants)
  else:
    flattened_shapes = [t.shape for t in flattened_loop_vars]

  del shape_invariants

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    num_outputs = len(flattened_loop_vars)

    # Add loop counter needed for computing gradients.
    flattened_loop_vars = ([constant_op.constant(0., name="loop_counter")] +
                           flattened_loop_vars)

    flattened_shapes = [tensor_shape.scalar()] + flattened_shapes

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(unused_loop_counter, *loop_vars):
      return cond(*loop_vars)

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature,
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
    flattened_shapes = flattened_shapes + [
        t.shape for t in cond_graph.external_captures
    ]

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:num_outputs] - Args for the original loop body.
          args[num_outputs:] - External captures of cond. These get passed
            through as is.

      Returns:
        A list of tensors the same length as args.
      """
      outputs = body(*args[:num_outputs])
      if not isinstance(outputs, collections.Sequence):
        outputs = [outputs]

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])

    signature = [
        tensor_spec.TensorSpec(shape, t.dtype)
        for shape, t in zip(flattened_shapes, flattened_loop_vars)
    ]
    body_graph = func_graph_module.func_graph_from_py_func(
        body_name, wrapped_body, flattened_loop_vars, {}, signature=signature,
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(srbs): Dedup tensors that are captured in both the cond and body.
    # This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      # TODO(srbs): Cache and re-use empty tensor lists.
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=_get_tensor_convertible_shape(
              intermediate_tensor.shape))
      flattened_loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list,
            intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
                         flattened_shapes[1:1 + num_outputs],
                         flattened_loop_vars[1:1 + num_outputs])
    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  if num_outputs == 1:
    return outputs[1]
  else:
    return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
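
A standalone hedged sketch of the TensorList accumulation pattern used for the intermediates above, using the same `list_ops` module as the surrounding code (exact signatures may vary by TF version):

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops

# Push scalar intermediates onto a TensorList, as the loop body does above, and
# stack them back into a dense tensor (the gradient pass pops them instead).
handle = list_ops.empty_tensor_list(element_dtype=dtypes.float32, element_shape=[])
for v in (1.0, 2.0, 3.0):
  handle = list_ops.tensor_list_push_back(handle, tf.constant(v))
stacked = list_ops.tensor_list_stack(handle, element_dtype=dtypes.float32)
# stacked == [1.0, 2.0, 3.0]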
Exemplo n.º 37
0
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
Exemplo n.º 38
0
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    body_graph = _get_body_graph(op)

    # Set the incoming gradient of TensorArray handles to None. The gradient
    # implementation currently assumes all resource tensors correspond to float32
    # ResourceVariables, which can lead to runtime shape errors when used with a
    # TensorArray. This is a workaround until TensorArrays are reimplemented with
    # TensorLists instead of resources.
    # Also set the incoming gradient of non-trainable inputs to None. It is
    # possible that we receive non-None gradients for non-trainable types in
    # nested while loops because we accumulate outputs of the inner while as
    # variant tensors which are trainable and hence receive zeros_like tensors in
    # the gradient pass. The non-trainable tensors then receive the popped zeros
    # tensor from this zeros variant. The gradient for the loop vars corresponding
    # to these tensors is None or zeros (this happens only if the loop var is
    # accumulated as well) in _grad_fn so we reset these.
    # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
    # output grads in _grad_fn.
    grads = [
        None if _is_tensor_array_handle(output)
        or not gradients_impl.IsTrainable(output) else grad
        for grad, output in zip(grads, op.outputs)
    ]

    # Ensure that all non-resource trainable outputs have incoming gradients.
    assert all(g is not None or o.dtype == dtypes.resource
               or not gradients_impl.IsTrainable(o)
               for o, g in zip(op.outputs, grads)
               ), "All trainable loop vars must receive incoming gradients."
    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the None's to the list of
    # outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, body_graph,
        util.unique_grad_fn_name(body_graph.name), op)

    intermediate_tensors = _get_intermediates(body_grad_graph)

    maximum_iterations = op.get_attr(
        "_maximum_iterations") if _is_in_xla_context() else None
    assert not _is_in_xla_context() or maximum_iterations is not None
    for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)

        with body_grad_graph.as_default():
            tensor_list_ph = body_grad_graph.capture(tensor_list,
                                                     whitelisted=True)
            # Push the intermediate tensor to the tensor list.
            appended_tensor_list = list_ops.tensor_list_push_back(
                tensor_list_ph, intermediate_tensor)
            # Add this modified tensor list to the list of outputs.
            body_grad_graph.outputs.append(appended_tensor_list)

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    loop_vars = args + body_grad_graph.external_captures
    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # See comment in while_loop.
    outputs = [array_ops.identity(t) for t in outputs]

    # Set None as the output gradient for tensors with None input gradient
    # e.g. TensorArray handles.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
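Backpropagation is what ultimately invokes a registered While gradient such as the one above. A minimal sketch of triggering it from user code (assuming TensorFlow 2.x; which concrete gradient function runs depends on the installed version, and the function name below is illustrative):

import tensorflow as tf

@tf.function
def power_by_loop(x):
  # y = x ** 5, computed with a graph while loop so the backward pass goes
  # through the gradient registered for the While op.
  _, y = tf.while_loop(
      lambda i, y: i < 5,
      lambda i, y: (i + 1, y * x),
      loop_vars=(tf.constant(0), tf.ones_like(x)))
  return y

x = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = power_by_loop(x)
print(tape.gradient(y, x).numpy())  # 5 * x**4 = 80.0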
Exemplo n.º 39
0
def wrap_function(fn, signature, name=None):
    """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: The Python function to be wrapped.
    signature: The placeholder and Python arguments to be passed to the wrapped
      function.
    name: Optional. The name of the function.

  Returns:
    The wrapped graph function.
  """
    holder = VariableHolder(fn)
    func_graph_name = "wrapped_function"
    if name is not None:
        func_graph_name = "wrapped_function_" + name
    return WrappedFunction(func_graph.func_graph_from_py_func(
        func_graph_name,
        holder,
        args=None,
        kwargs=None,
        signature=signature,
        add_control_dependencies=False,
        collections={}),
                           variable_holder=holder,
                           signature=signature)
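A small sketch of the pruning behaviour the docstring describes: operations with no data or control path to the fetched outputs are never executed (assuming TensorFlow 2.x with `tf.compat.v1.wrap_function` available; the `state` dict and the name `f` are only illustrative, kept so the traced variable can be inspected afterwards).

import tensorflow as tf

state = {}

def f(x):
  v = tf.Variable(0.0)
  state["v"] = v  # keep a handle to the variable created during tracing
  # Nothing connects this update to the returned value, so -- unlike with
  # `tf.function` -- the wrapped graph never runs it.
  v.assign_add(x)
  return x * 2.0

f_wrapped = tf.compat.v1.wrap_function(f, [tf.TensorSpec([], tf.float32)])
print(float(f_wrapped(3.0)))  # 6.0
print(state["v"].numpy())     # 0.0: the dangling assign_add was never run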
Exemplo n.º 40
0
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    # Note that op is not always the same as while_op because the gradient tape,
    # for eager mode compatibility, forgets information about the proper op. Since
    # the loop cannot run in eager mode, however, we can safely introspect into
    # the graph here.
    while_op = op.outputs[0].op
    cond_graph = _get_graph(while_op, "cond")
    body_graph = _get_graph(while_op, "body")
    orig_num_params = len(body_graph.outputs)

    maximum_iterations = op.get_attr(
        "_maximum_iterations") if _is_in_xla_context() else None
    assert not _is_in_xla_context() or maximum_iterations is not None
    maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

    # Set the incoming gradient of non-trainable inputs to None. It is possible
    # that we receive non-None gradients for non-trainable types in nested while
    # loops because we accumulate outputs of the inner while as variant tensors
    # which are trainable and hence receive zeros_like tensors in the gradient
    # pass. The non-trainable tensors then receive the popped zeros tensor from
    # this zeros variant. The gradient for the loop vars corresponding to these
    # tensors is None or zeros (this happens only if the loop var is accumulated
    # as well) in _grad_fn so we reset these.
    # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
    # output grads in _grad_fn.
    grads = [
        None if not _is_trainable(output) else grad
        for grad, output in zip(grads, body_graph.outputs)
    ]

    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the None's to the list of
    # outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, cond_graph, body_graph,
        util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

    if body_grad_graph.while_op_needs_rewrite:
        # Modify 'op' to output the intermediate accumulators needed by the grad
        # function.
        # NOTE(skyewm): if there are any active sessions, this modification to `op`
        # may make them unrunnable!

        cond_graph.name += "_rewritten"
        body_graph.name += "_rewritten"

        new_inputs = body_grad_graph.empty_tensor_lists
        new_outputs = body_graph.outputs[orig_num_params:]

        while_op._set_func_attr("cond",
                                util.create_new_tf_function(cond_graph))
        while_op._set_func_attr("body",
                                util.create_new_tf_function(body_graph))
        while_op._set_type_list_attr("T", body_graph.output_types)
        while_op._set_shape_list_attr("output_shapes",
                                      body_graph.output_shapes)
        while_op._add_while_inputs(new_inputs)
        while_op._add_outputs([t.dtype for t in new_outputs],
                              [t.shape for t in new_outputs])
        _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

    captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                             while_op)
    loop_vars = args + captured_inputs

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % while_op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # See comment in while_loop.
    outputs = [array_ops.identity(t) for t in outputs]

    # Set None as the output gradient for tensors with None input gradient.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
Exemplo n.º 41
0
def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn."""
  return func_graph_module.func_graph_from_py_func(
      name,
      lambda: _grad_fn(func_graph, grads), [], {},
      func_graph=util.CondBranchFuncGraph(name, read_only_collections=False))
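For context, a minimal sketch of the user-level call that ends up tracing a branch gradient graph like the one above (assuming TensorFlow 2.x; the `tf.function` wrapper makes `tf.cond` lower to a functional If op instead of running eagerly, and `branchy` is an illustrative name).

import tensorflow as tf

@tf.function
def branchy(x):
  return tf.cond(x > 0.0, lambda: tf.square(x), lambda: -x)

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = branchy(x)
print(tape.gradient(y, x).numpy())  # 6.0, i.e. d(x**2)/dx for the taken branch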
Exemplo n.º 42
0
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    body_graph = _get_body_graph(op)

    # Set the incoming gradient of TensorArray handles to None.
    # TODO(b/118164915): We need a way of distinguishing between TensorArray
    # resource handles and ResourceVariables so that we can set the default
    # gradient of only the TensorArray handles to None.
    grads = [
        None if output.dtype == dtypes.resource else g
        for g, output in zip(grads, op.outputs)
    ]

    # Ensure that all non-resource trainable outputs have incoming gradients.
    assert all(g is not None or o.dtype == dtypes.resource
               or not gradients_impl.IsTrainable(o)
               for o, g in zip(op.outputs, grads)
               ), "All trainable loop vars must receive incoming gradients."
    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the None's to the list of
    # outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, body_graph,
        util.unique_grad_fn_name(body_graph.name), op)

    intermediate_tensors = _get_intermediates(body_grad_graph)

    for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=_get_tensor_convertible_shape(
                intermediate_tensor.shape))
        with body_grad_graph.as_default():
            tensor_list_ph = body_grad_graph.capture(tensor_list,
                                                     whitelisted=True)
            # Push the intermediate tensor to the tensor list.
            appended_tensor_list = list_ops.tensor_list_push_back(
                tensor_list_ph, intermediate_tensor)
            # Add this modified tensor list to the list of outputs.
            body_grad_graph.outputs.append(appended_tensor_list)

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    loop_vars = args + body_grad_graph.external_captures
    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    _maybe_set_lowering_attr(outputs[0].op)

    # Set None as the output gradient for tensors with None input gradient
    # e.g. TensorArray handles.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
Exemplo n.º 43
0
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    util.maybe_set_lowering_attr(tensors[0].op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
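A short usage sketch of the public entry point that reaches code like the above (assuming TensorFlow 2.x; `safe_ratio` is an illustrative name): inside a `tf.function`, `tf.cond` emits a single If op, and tensors closed over by the branch lambdas become the external captures that `_make_inputs_match` reconciles.

import tensorflow as tf

@tf.function
def safe_ratio(a, b):
  # Both branches capture `a` and `b` from the enclosing scope; only the
  # branch selected by the predicate runs at execution time.
  return tf.cond(tf.not_equal(b, 0.0),
                 true_fn=lambda: a / b,
                 false_fn=lambda: tf.zeros_like(a))

print(safe_ratio(tf.constant(6.0), tf.constant(2.0)).numpy())  # 3.0
print(safe_ratio(tf.constant(6.0), tf.constant(0.0)).numpy())  # 0.0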
Exemplo n.º 44
0
def wrap_function(fn, signature, name=None):
  """Wraps the TF 1.x function fn into a graph function.

  The python function `fn` will be called once with symbolic arguments specified
  in the `signature`, traced, and turned into a graph function. Any variables
  created by `fn` will be owned by the object returned by `wrap_function`. The
  resulting graph function can be called with tensors which match the
  signature.

  ```python
  def f(x, do_add):
    v = tf.Variable(5.0)
    if do_add:
      op = v.assign_add(x)
    else:
      op = v.assign_sub(x)
    with tf.control_dependencies([op]):
      return v.read_value()

  f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

  assert float(f_add(1.0)) == 6.0
  assert float(f_add(1.0)) == 7.0

  # Can call tf.compat.v1.wrap_function again to get a new trace, a new set
  # of variables, and possibly different non-template arguments.
  f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])

  assert float(f_sub(1.0)) == 4.0
  assert float(f_sub(1.0)) == 3.0
  ```

  Both `tf.compat.v1.wrap_function` and `tf.function` create a callable
  TensorFlow graph. But while `tf.function` runs all stateful operations
  (e.g. `tf.print`) and sequences operations to provide the same semantics as
  eager execution, `wrap_function` is closer to the behavior of `session.run` in
  TensorFlow 1.x. It will not run any operations unless they are required to
  compute the function's outputs, either through a data dependency or a control
  dependency. Nor will it sequence operations.

  Unlike `tf.function`, `wrap_function` will only trace the Python function
  once. As with placeholders in TF 1.x, shapes and dtypes must be provided to
  `wrap_function`'s `signature` argument.

  Since it is only traced once, variables and state may be created inside the
  function and owned by the function wrapper object.

  Args:
    fn: The Python function to be wrapped.
    signature: The placeholder and Python arguments to be passed to the
      wrapped function.
    name: Optional. The name of the function.

  Returns:
    The wrapped graph function.
  """
  holder = VariableHolder(fn)
  func_graph_name = "wrapped_function"
  if name is not None:
    func_graph_name = "wrapped_function_" + name
  return WrappedFunction(
      func_graph.func_graph_from_py_func(
          func_graph_name,
          holder,
          args=None, kwargs=None, signature=signature,
          add_control_dependencies=False,
          collections={}),
      variable_holder=holder,
      signature=signature)