Example #1
 def _make_op(inputs):
   if_op, tensors = util.get_op_and_outputs(op_fn(
       pred,
       inputs, [t.dtype for t in true_graph.outputs],
       util.create_new_tf_function(true_graph),
       util.create_new_tf_function(false_graph),
       output_shapes=_get_output_shapes(true_graph.outputs,
                                        false_graph.outputs),
       name=name))
   _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
   # `if_op` is None if this is a `StatelessIf` op with no outputs.
   if if_op is not None:
     # The true and false graphs have already been created, and we need that
     # to happen before we know which tensors will be captured and so whether
     # to wrap the cond in a tf.function. Post-hoc mutation of the branch
     # `outer_graph` properties seems like the only option if we want to
     # conditionally wrap in a function.
     true_graph.outer_graph = ops.get_default_graph()
     false_graph.outer_graph = ops.get_default_graph()
     if_op._true_graph = true_graph
     if_op._false_graph = false_graph
     util.maybe_set_lowering_attr(if_op)
     util.maybe_propagate_compile_time_consts_in_xla(if_op)
     _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
     # Prevent fetching since the variant outputs can't be fetched directly.
     if_op.graph.prevent_fetching(if_op)
   return tensors
Example #2
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
    """Creates an `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediates values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same outpute types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
    _check_same_outputs(_CASE, branch_graphs)

    # Add inputs to branch_graphs to make them match. Note that this modifies the
    # graphs in `branch_graphs`.
    case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

    # Create the Case op.
    with ops.control_dependencies(
            sum((list(bg.control_captures) for bg in branch_graphs), [])):
        tensors = gen_functional_ops.case(
            branch_index,
            case_inputs, [t.dtype for t in branch_graphs[0].outputs],
            [util.create_new_tf_function(g) for g in branch_graphs],
            output_shapes=_get_output_shapes(
                *[g.outputs for g in branch_graphs]),
            name=name)

    case_op, tensors = _get_op_and_outputs(tensors)

    if case_op is not None:
        util.maybe_set_lowering_attr(case_op)
        util.maybe_propagate_compile_time_consts_in_xla(case_op)
        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
        # Prevent fetching since the variant outputs can't be fetched directly.
        case_op.graph.prevent_fetching(case_op)

    # Return identities for each output of the Case op, rather than the output of
    # the Case op directly. This makes pruning work if the output of switch_case()
    # is fetched: the lowering pass converts the Case outputs into IdentityN
    # outputs, which if fetched will cause all ops in the taken branch to be run
    # (since it takes all merge ops as input). After lowering, each output
    # identity op will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)
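
A hedged aside, not part of the original source: `_build_case` is an internal helper, but the pruning behaviour its comments describe can be observed through the public `tf.switch_case` API, which is expected to route into a helper like this when v2 control flow is enabled. A minimal sketch:

import tensorflow as tf

with tf.Graph().as_default():
  branch_index = tf.constant(1)
  # With v2 control flow this emits a single Case op; only the selected branch
  # runs, and fetching the output fetches identities rather than the Case op.
  out = tf.switch_case(
      branch_index,
      branch_fns={0: lambda: tf.constant(10.0),
                  1: lambda: tf.constant(20.0)})
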
Example #3
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph,
                                   true_inputs, false_inputs)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors)
Example #4
 def _make_op(inputs):
   case_op, tensors = util.get_op_and_outputs(op_fn(
       branch_index,
       inputs, [t.dtype for t in branch_graphs[0].outputs],
       [util.create_new_tf_function(g) for g in branch_graphs],
       output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
       name=name))
   _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
   if case_op is not None:
     util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
     util.maybe_propagate_compile_time_consts_in_xla(case_op)
     _set_read_only_resource_inputs_attr(case_op, branch_graphs)
     # Prevent fetching since the variant outputs can't be fetched directly.
     case_op.graph.prevent_fetching(case_op)
   return tensors
Example #5
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    # Note that op is not always the same as while_op because the gradient tape,
    # for eager mode compatibility, forgets information about the proper op. Since
    # the loop cannot run in eager mode, however, we can safely introspect into
    # the graph here.
    while_op = op.outputs[0].op
    cond_graph = _get_graph(while_op, "cond")
    body_graph = _get_graph(while_op, "body")
    orig_num_params = len(body_graph.outputs)

    maximum_iterations = op.get_attr(
        "_maximum_iterations") if _is_in_xla_context() else None
    assert not _is_in_xla_context() or maximum_iterations is not None
    maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

    # Set the incoming gradient of non-trainable inputs to None. It is possible
    # that we receive non-None gradients for non-trainable types in nested while
    # loops because we accumulate outputs of the inner while as variant tensors
    # which are trainable and hence receive zeros_like tensors in the gradient
    # pass. The non-trainable tensors then receive the popped zeros tensor from
    # this zeros variant. The gradient for the loop vars corresponding to these
    # tensors is None or zeros (this happens only if the loop var is accumulated
    # as well) in _grad_fn so we reset these.
    # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
    # output grads in _grad_fn.
    grads = [
        None if not _is_trainable(output) else grad
        for grad, output in zip(grads, body_graph.outputs)
    ]

    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the None's to the list of
    # outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, cond_graph, body_graph,
        util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

    if body_grad_graph.while_op_needs_rewrite:
        # Modify 'op' to output the intermediate accumulators needed by the grad
        # function.
        # NOTE(skyewm): if there are any active sessions, this modification to `op`
        # may make them unrunnable!

        cond_graph.name += "_rewritten"
        body_graph.name += "_rewritten"

        new_inputs = body_grad_graph.empty_tensor_lists
        new_outputs = body_graph.outputs[orig_num_params:]

        while_op._set_func_attr("cond",
                                util.create_new_tf_function(cond_graph))
        while_op._set_func_attr("body",
                                util.create_new_tf_function(body_graph))
        while_op._set_type_list_attr("T", body_graph.output_types)
        while_op._set_shape_list_attr("output_shapes",
                                      body_graph.output_shapes)
        while_op._add_while_inputs(new_inputs)
        while_op._add_outputs([t.dtype for t in new_outputs],
                              [t.shape for t in new_outputs])
        _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

    captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                             while_op)
    loop_vars = args + captured_inputs

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % while_op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # See comment in while_loop.
    outputs = [array_ops.identity(t) for t in outputs]

    # Set None as the output gradient for tensors with None input gradient.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
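
A hedged sketch of what exercises `_WhileGrad`: differentiating through a loop built by this module's `while_loop` dispatches to the registered While gradient. The import path `tensorflow.python.ops.while_v2` is an assumption; illustration only, not part of the original source.

import tensorflow as tf
from tensorflow.python.ops import while_v2

with tf.Graph().as_default():
  x = tf.constant(2.0)
  # Doubles x three times inside a single While op, so y == x * 2**3.
  _, y = while_v2.while_loop(
      cond=lambda i, v: i < 3,
      body=lambda i, v: (i + 1, v * 2.0),
      loop_vars=[tf.constant(0), x])
  # Taking the gradient through the While op invokes _WhileGrad above.
  dy_dx, = tf.gradients(y, [x])
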
Example #6
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    body_graph = _get_body_graph(op)

    # Set the incoming gradient of TensorArray handles to None. The gradient
    # implementation currently assumes all resource tensors correspond to float32
    # ResourceVariables, which can lead to runtime shape errors when used with a
    # TensorArray. This is a workaround until TensorArrays are reimplemented with
    # TensorLists instead of resources.
    # Also set the incoming gradient of non-trainable inputs to None. It is
    # possible that we receive non-None gradients for non-trainable types in
    # nested while loops because we accumulate outputs of the inner while as
    # variant tensors which are trainable and hence receive zeros_like tensors in
    # the gradient pass. The non-trainable tensors then receive the popped zeros
    # tensor from this zeros variant. The gradient for the loop vars corresponding
    # to these tensors is None or zeros (this happens only if the loop var is
    # accumulated as well) in _grad_fn so we reset these.
    # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
    # output grads in _grad_fn.
    grads = [
        None if _is_tensor_array_handle(output)
        or not gradients_impl.IsTrainable(output) else grad
        for grad, output in zip(grads, op.outputs)
    ]

    # Ensure that all non-resource trainable outputs have incoming gradients.
    assert all(g is not None or o.dtype == dtypes.resource
               or not gradients_impl.IsTrainable(o)
               for o, g in zip(op.outputs, grads)
               ), "All trainable loop vars must receive incoming gradients."
    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the None's to the list of
    # outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, body_graph,
        util.unique_grad_fn_name(body_graph.name), op)

    intermediate_tensors = _get_intermediates(body_grad_graph)

    maximum_iterations = op.get_attr(
        "_maximum_iterations") if _is_in_xla_context() else None
    assert not _is_in_xla_context() or maximum_iterations is not None
    for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)

        with body_grad_graph.as_default():
            tensor_list_ph = body_grad_graph.capture(tensor_list,
                                                     whitelisted=True)
            # Push the intermediate tensor to the tensor list.
            appended_tensor_list = list_ops.tensor_list_push_back(
                tensor_list_ph, intermediate_tensor)
            # Add this modified tensor list to the list of outputs.
            body_grad_graph.outputs.append(appended_tensor_list)

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    loop_vars = args + body_grad_graph.external_captures
    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # See comment in while_loop.
    outputs = [array_ops.identity(t) for t in outputs]

    # Set None as the output gradient for tensors with None input gradient
    # e.g. TensorArray handles.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
Example #7
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    util.maybe_set_lowering_attr(tensors[0].op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
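
A minimal usage sketch for the `cond_v2` entry point above; the import path `tensorflow.python.ops.cond_v2` is an assumption and this is an illustration only, not part of the original source.

import tensorflow as tf
from tensorflow.python.ops import cond_v2

with tf.Graph().as_default():
  x = tf.constant(3.0)
  pred = tf.constant(True)
  # Unlike v1 tf.cond, which builds Switch/Merge ops directly, this emits a
  # single If op (lowering may still happen later via the attr set above).
  out = cond_v2.cond_v2(pred, lambda: x * 2.0, lambda: x + 1.0)
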
Example #8
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2."""
  true_graph, false_graph = _get_func_graphs(op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == op.graph
  assert false_graph.outer_graph == op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
  false_grad_graph = _create_grad_func(
      false_graph, grads, util.unique_grad_fn_name(false_graph.name))

  assert ([t.dtype for t in true_grad_graph.outputs] ==
          [t.dtype for t in false_grad_graph.outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # Make the inputs to true_grad_graph and false_grad_graph match. Note that
  # this modifies true_grad_graph and false_grad_graph.
  grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
                                   true_grad_inputs, false_grad_inputs)

  # Add all intermediate tensors as function outputs so they're available for
  # higher-order gradient computations.

  true_grad_intermediates = _get_intermediates(true_grad_graph)
  false_grad_intermediates = _get_intermediates(false_grad_graph)

  # Save the original number of gradient outputs to return.
  num_grad_outputs = len(true_grad_graph.outputs)

  # Make the number/type of new intermediate outputs match.
  extra_true_grad_outputs, extra_false_grad_outputs = _pad_params(
      true_grad_graph, false_grad_graph,
      true_grad_intermediates, false_grad_intermediates)

  true_grad_graph.outputs.extend(extra_true_grad_outputs)
  false_grad_graph.outputs.extend(extra_false_grad_outputs)

  # Create the gradient If op.
  tensors = gen_functional_ops._if(
      op.inputs[0],
      grad_inputs, [t.dtype for t in true_grad_graph.outputs],
      util.create_new_tf_function(true_grad_graph),
      util.create_new_tf_function(false_grad_graph),
      output_shapes=_get_output_shapes(true_grad_graph.outputs,
                                       false_grad_graph.outputs))

  util.maybe_set_lowering_attr(tensors[0].op)

  # See comment in cond_v2.
  tensors = [array_ops.identity(t) for t in tensors]

  # The predicate has no gradient.
  return [None] + tensors[:num_grad_outputs]
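
A hedged sketch of what triggers `_IfGrad`: taking gradients through a `cond_v2` output (same assumed import path as above). Note that the predicate receives no gradient, matching the `[None] + ...` return value. Illustration only.

import tensorflow as tf
from tensorflow.python.ops import cond_v2

with tf.Graph().as_default():
  x = tf.constant(3.0)
  pred = tf.constant(True)
  y = cond_v2.cond_v2(pred, lambda: x * x, lambda: x + 1.0)
  # Differentiating through the If op builds gradient If ops for both branches.
  dy_dx, = tf.gradients(y, [x])
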
Example #9
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
    """The gradient of a While op produced by while_loop."""
    cond_graph = _get_graph(op, "cond")
    body_graph = _get_graph(op, "body")
    orig_num_params = len(body_graph.outputs)

    maximum_iterations = op.get_attr(
        "_maximum_iterations") if _is_in_xla_context() else None
    assert not _is_in_xla_context() or maximum_iterations is not None

    # Set the incoming gradient of TensorArray handles to None. The gradient
    # implementation currently assumes all resource tensors correspond to float32
    # ResourceVariables, which can lead to runtime shape errors when used with a
    # TensorArray. This is a workaround until TensorArrays are reimplemented with
    # TensorLists instead of resources.
    # Also set the incoming gradient of non-trainable inputs to None. It is
    # possible that we receive non-None gradients for non-trainable types in
    # nested while loops because we accumulate outputs of the inner while as
    # variant tensors which are trainable and hence receive zeros_like tensors in
    # the gradient pass. The non-trainable tensors then receive the popped zeros
    # tensor from this zeros variant. The gradient for the loop vars corresponding
    # to these tensors is None or zeros (this happens only if the loop var is
    # accumulated as well) in _grad_fn so we reset these.
    # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
    # output grads in _grad_fn.
    grads = [
        None if _is_tensor_array_handle(output) or not _is_trainable(output)
        else grad for grad, output in zip(grads, body_graph.outputs)
    ]

    # Ensure that all non-resource trainable outputs have incoming gradients.
    assert all(
        g is not None or o.dtype == dtypes.resource or not _is_trainable(o)
        for o, g in zip(body_graph.outputs, grads)
    ), "All trainable loop vars must receive incoming gradients."
    # We compute the gradient for the sub-graph between trainable ys and xs
    # with non-None incoming gradients. We later pad the None's to the list of
    # outputs.
    ys, xs, non_none_grads = zip(
        *[(y, x, grad)
          for (y, x, grad) in zip(body_graph.outputs, body_graph.inputs, grads)
          if grad is not None])

    body_grad_graph, args = _create_grad_func(
        ys, xs, non_none_grads, cond_graph, body_graph,
        util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

    if body_grad_graph.while_op_needs_rewrite:
        # Modify 'op' to output the intermediate accumulators needed by the grad
        # function.
        # NOTE(skyewm): if there are any active sessions, this modification to `op`
        # may make them unrunnable!

        cond_graph.name += "_rewritten"
        body_graph.name += "_rewritten"

        new_inputs = body_grad_graph.empty_tensor_lists
        new_outputs = body_graph.outputs[orig_num_params:]

        op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
        op._set_func_attr("body", util.create_new_tf_function(body_graph))
        op._set_type_list_attr("T", body_graph.output_types)
        op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
        op._add_while_inputs(new_inputs)
        op._add_outputs([t.dtype for t in new_outputs],
                        [t.shape for t in new_outputs])
        _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

    captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph, op)
    loop_vars = args + captured_inputs

    def grad_cond(counter, max_iters, *unused_args):
        return counter < max_iters

    grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
    cond_grad_graph = func_graph_module.func_graph_from_py_func(
        grad_cond_name,
        grad_cond,
        loop_vars, {},
        func_graph=util.WhileCondFuncGraph(grad_cond_name))

    _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

    outputs = gen_functional_ops._while(
        loop_vars,
        util.create_new_tf_function(cond_grad_graph),
        util.create_new_tf_function(body_grad_graph),
        output_shapes=[t.shape for t in body_grad_graph.outputs],
        name="%s_grad" % op.name)

    _copy_handle_data(body_grad_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # See comment in while_loop.
    outputs = [array_ops.identity(t) for t in outputs]

    # Set None as the output gradient for tensors with None input gradient
    # e.g. TensorArray handles.
    # outputs[0] is the loop counter.
    # outputs[1] is the total number of loop iterations.
    index = 2
    none_padded_outputs = []
    for g in grads:
        if g is None:
            none_padded_outputs.append(None)
        else:
            none_padded_outputs.append(outputs[index])
            index += 1
    return none_padded_outputs
Example #10
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivative wrt these extra outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
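
A hedged sketch of the shape-invariant path checked by `_check_shapes_compat` above, written against the public `tf.while_loop` (which is expected to route to this implementation when v2 control flow is enabled); illustration only, not part of the original source.

import tensorflow as tf

with tf.Graph().as_default():
  i0 = tf.constant(0)
  acc0 = tf.zeros([0])
  # acc grows each iteration, so its invariant must leave the first dim unknown.
  i, acc = tf.while_loop(
      cond=lambda i, acc: i < 5,
      body=lambda i, acc: (i + 1, tf.concat([acc, tf.ones([1])], axis=0)),
      loop_vars=[i0, acc0],
      shape_invariants=[tf.TensorShape([]), tf.TensorShape([None])])
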
Example #11
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op."""
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    shape_invariants = type(shape_invariants)([tensor_shape.scalar()
                                              ]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph.captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        parallel_iterations=parallel_iterations,
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  outputs = _pack_sequence_as(orig_loop_vars,
                              outputs[1:1 + num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
Example #12
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op."""
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")

    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter] + loop_vars

    shape_invariants = type(shape_invariants)([tensor_shape.scalar()
                                              ]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(cond_name),
        add_control_dependencies=add_control_dependencies)

    # Add external_captures of cond to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + cond_graph.external_captures
    shape_invariants = shape_invariants + type(shape_invariants)(
        [t.shape for t in cond_graph.external_captures])

    def wrapped_body(loop_counter, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(
          *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # Return the external_captures of cond_graph as is, i.e., treat them as
      # loop invariants.
      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1] + list(outputs) + list(
          args[len_orig_loop_vars:])

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        loop_vars, {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(body_name),
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture `external_captures` of `body_graph` in `cond_graph` so that it
    # expects to receive those as arguments.
    # TODO(b/118457764): Dedup tensors that are captured in both the cond and
    # body. This logic already exists in cond_v2.
    with cond_graph.as_default():
      for external_capture in body_graph.external_captures:
        assert external_capture not in cond_graph.captures, (
            "Looks like both cond and body are capturing the same tensor %s. "
            "This is not supported yet. For now consider passing,"
            " this as a loop variable." % str(external_capture))
        cond_graph.capture(external_capture)

    # Export all tensors in the loop body that may be needed for gradient
    # computation. We do this by accumulating the intermediate values in
    # TensorLists.
    intermediate_tensors = _get_intermediates(body_graph)

    for intermediate_tensor in intermediate_tensors:
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=intermediate_tensor.dtype,
          element_shape=intermediate_tensor.shape,
          max_num_elements=maximum_iterations)
      loop_vars.append(tensor_list)
      with cond_graph.as_default():
        # Add a placeholder to cond_graph's inputs corresponding to the
        # tensor_list.
        cond_graph.capture(tensor_list)
      with body_graph.as_default():
        # Push the intermediate tensor to the tensor list. This captures the
        # `tensor_list` as well.
        appended_tensor_list = list_ops.tensor_list_push_back(
            tensor_list,
            intermediate_tensor)
        # Add this modified tensor list to the list of outputs.
        body_graph.outputs.append(appended_tensor_list)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    _check_shapes_compat(
        body_graph.outputs[1:1 + num_flattened_outputs],
        nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
        nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    outputs = gen_functional_ops._while(
        flattened_loop_vars,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=[t.shape for t in body_graph.outputs],
        name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  # First var is loop counter.
  outputs = _pack_sequence_as(orig_loop_vars,
                              outputs[1:1 + num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
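
A minimal usage sketch, not part of the original snippet, of how the builder above is normally reached through the public API. It assumes a TF 1.x-style graph session and that `tf.compat.v1.enable_control_flow_v2()` is available in the build to route `tf.while_loop` to this implementation.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption: this toggle exists in the TF build

i0 = tf.constant(0)
acc0 = tf.constant(0.0)

# cond/body receive the flattened loop vars, mirroring wrapped_cond/wrapped_body.
i_out, acc_out = tf.while_loop(
    cond=lambda i, acc: i < 10,
    body=lambda i, acc: (i + 1, acc + tf.cast(i, tf.float32)),
    loop_vars=(i0, acc0))

with tf.compat.v1.Session() as sess:
  print(sess.run([i_out, acc_out]))  # expected: [10, 45.0]
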
Example #13
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  body_graph = _get_body_graph(op)

  # Set the incoming gradient of TensorArray handles to None. The gradient
  # implementation currently assumes all resource tensors correspond to float32
  # ResourceVariables, which can lead to runtime shape errors when used with a
  # TensorArray. This is a workaround until TensorArrays are reimplemented with
  # TensorLists instead of resources.
  # Also set the incoming gradient of non-trainable inputs to None. It is
  # possible that we receive non-None gradients for non-trainable types in
  # nested while loops because we accumulate outputs of the inner while as
  # variant tensors which are trainable and hence receive zeros_like tensors in
  # the gradient pass. The non-trainable tensors then receive the popped zeros
  # tensor from this zeros variant. The gradient for the loop vars corresponding
  # to these tensors is None or zeros (this happens only if the loop var is
  # accumulated as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
  # output grads in _grad_fn.
  grads = [
      None if _is_tensor_array_handle(output) or
      not gradients_impl.IsTrainable(output) else grad
      for grad, output in zip(grads, op.outputs)
  ]

  # Ensure that all non-resource trainable outputs have incoming gradients.
  assert all(g is not None or o.dtype == dtypes.resource or
             not gradients_impl.IsTrainable(o)
             for o, g in zip(op.outputs, grads)
            ), "All trainable loop vars must receive incoming gradients."
  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, body_graph,
      util.unique_grad_fn_name(body_graph.name), op)

  intermediate_tensors = _get_intermediates(body_grad_graph)

  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  assert not _is_in_xla_context() or maximum_iterations is not None
  for intermediate_tensor in intermediate_tensors:
    tensor_list = list_ops.empty_tensor_list(
        element_dtype=intermediate_tensor.dtype,
        element_shape=intermediate_tensor.shape,
        max_num_elements=maximum_iterations)

    with body_grad_graph.as_default():
      tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
      # Push the intermediate tensor to the tensor list.
      appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
                                                            intermediate_tensor)
      # Add this modified tensor list to the list of outputs.
      body_grad_graph.outputs.append(appended_tensor_list)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  loop_vars = args + body_grad_graph.external_captures
  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      name="%s_grad" % op.name)

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(outputs[0].op)
  _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]

  # Set None as the output gradient for tensors with None input gradient
  # e.g. TensorArray handles.
  # outputs[0] is the loop counter.
  # outputs[1] is the total number of loop iterations.
  index = 2
  none_padded_outputs = []
  for g in grads:
    if g is None:
      none_padded_outputs.append(None)
    else:
      none_padded_outputs.append(outputs[index])
      index += 1
  return none_padded_outputs
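
A minimal sketch, not taken from the source, of the situation that invokes `_WhileGrad` above: calling `tf.gradients` on a `while_loop` output whose body captures an external tensor (treated as a loop invariant). It assumes a TF 1.x graph session with control flow v2 enabled.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

x = tf.constant(2.0)
_, y = tf.while_loop(
    cond=lambda i, v: i < 3,
    body=lambda i, v: (i + 1, v * x),  # `x` is an external capture of the body
    loop_vars=(tf.constant(0), tf.constant(1.0)))  # y == x**3

dy_dx, = tf.gradients(y, x)  # triggers the registered While gradient

with tf.compat.v1.Session() as sess:
  print(sess.run(dy_dx))  # expected: 3 * x**2 == 12.0
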
Example #14
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph,
                                   true_inputs, false_inputs)

  # Add all intermediate tensors as function outputs so they're available for
  # the gradient computation. Since the outputs of the two functions must match,
  # we wrap all the intermediates in optionals. Each intermediate output will
  # have a value iff its corresponding branch is taken.

  true_intermediates = _get_intermediates(true_graph)
  false_intermediates = _get_intermediates(false_graph)

  # Save the original number of outputs to return to the caller.
  num_cond_outputs = len(true_graph.outputs)

  if control_flow_util.InXlaContext(ops.get_default_graph()):
    # XLA does not yet support optionals, so output intermediates directly and
    # make them match via FakeParams, which can be converted to zeros in XLA.
    # TODO(skyewm,jpienaar): can XLA support optionals?
    extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
        true_graph, false_graph, true_intermediates, false_intermediates)
  else:
    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding none optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(
        true_graph, false_graph,
        wrapped_true_intermediates, wrapped_false_intermediates)

  true_graph.outputs.extend(extra_true_outputs)
  false_graph.outputs.extend(extra_false_outputs)
  # TODO(skyewm): somehow indicate it's a bug if this fails.
  _check_same_outputs(true_graph, false_graph)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)

  return tensors[:num_cond_outputs]
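
A minimal sketch, assuming cond_v2 is the active implementation in a TF 1.x graph session, of the public call that lands in `_build_cond`; the lambdas become true_graph/false_graph and the captured `x` is threaded through the matched cond_inputs.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

pred = tf.constant(True)
x = tf.constant(3.0)

out = tf.cond(pred, lambda: x * 2.0, lambda: x - 1.0)

with tf.compat.v1.Session() as sess:
  print(sess.run(out))  # expected: 6.0
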
Example #15
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
    """Like tf.while_loop, except emits a single While op."""
    maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)
    # Keep the original loop_vars around to know which args were TensorArrays.
    orig_loop_vars = loop_vars
    # Cache its length since we use it at multiple places below.
    len_orig_loop_vars = len(orig_loop_vars)

    # Convert TensorArrays to their flow variables. These get converted back to
    # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
    # `wrapped_body` below.
    loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
    loop_vars = nest.map_structure(
        ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
    if shape_invariants is not None:
        nest.assert_same_structure(orig_loop_vars, shape_invariants)
    else:
        shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

    if not name:
        name = "while"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            cond_name = util.unique_fn_name(scope, "cond")
            body_name = util.unique_fn_name(scope, "body")

        loop_counter = constant_op.constant(
            0,
            dtype=maximum_iterations.dtype
            if maximum_iterations is not None else None,
            name="loop_counter")
        # Add loop counter needed for computing gradients.
        loop_vars = [loop_counter] + loop_vars

        shape_invariants = type(shape_invariants)(
            [tensor_shape.scalar()]) + shape_invariants

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = (
            ops.get_default_graph()._add_control_dependencies)

        # Build a `cond` wrapper that can handle the extra counter loop_var.
        def wrapped_cond(loop_counter, *args):
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            if maximum_iterations is None:
                return cond(*_pack_sequence_as(orig_loop_vars, args))
            else:
                return math_ops.logical_and(
                    loop_counter < maximum_iterations,
                    cond(*_pack_sequence_as(orig_loop_vars, args)))

        cond_graph = func_graph_module.func_graph_from_py_func(
            cond_name,
            wrapped_cond,
            loop_vars, {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileCondFuncGraph(cond_name),
            add_control_dependencies=add_control_dependencies)

        # Add external_captures of cond to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + cond_graph.external_captures
        shape_invariants = shape_invariants + type(shape_invariants)(
            [t.shape for t in cond_graph.external_captures])

        def wrapped_body(loop_counter, *args):
            """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        *args: List of args
          args[:len_orig_loop_vars] - Args for the original loop body.
          args[len_orig_loop_vars:] - External captures of cond. These get
            passed through as is.

      Returns:
        A list of tensors the same length as args.
      """
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            outputs = body(
                *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
            if not nest.is_sequence(outputs):
                outputs = [outputs]
            # Compare the structure of input and output of body converting the
            # top-level tuples to list to be compatible with legacy while_loop.
            nest.assert_same_structure(list(outputs), list(orig_loop_vars))

            outputs = _tensor_array_to_flow(outputs)

            # Return the external_captures of cond_graph as is, i.e., treat them as
            # loop invariants.
            # TODO(srbs): Update lowering code to create _Enter nodes with
            # is_constant=True for inputs that are directly passed to outputs.
            return [loop_counter + 1] + list(outputs) + list(
                args[len_orig_loop_vars:])

        body_graph = func_graph_module.func_graph_from_py_func(
            body_name,
            wrapped_body,
            loop_vars, {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileBodyFuncGraph(body_name),
            add_control_dependencies=add_control_dependencies)
        # Add external captures of body to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + body_graph.external_captures
        # TODO(srbs): Update lowering code to create _Enter nodes with
        # is_constant=True for inputs that are directly passed to outputs.
        body_graph.outputs.extend(body_graph.internal_captures)

        # Capture `external_captures` of `body_graph` in `cond_graph` so that it
        # expects to receive those as arguments.
        # TODO(b/118457764): Dedup tensors that are captured in both the cond and
        # body. This logic already exists in cond_v2.
        with cond_graph.as_default():
            for external_capture in body_graph.external_captures:
                assert external_capture not in cond_graph.captures, (
                    "Looks like both cond and body are capturing the same tensor %s. "
                    "This is not supported yet. For now consider passing,"
                    " this as a loop variable." % str(external_capture))
                cond_graph.capture(external_capture)

        # Make sure that the shapes of the loop outputs are compatible with the
        # shape invariants, or the shapes of the loop vars if the invariants are not
        # specified.
        num_flattened_outputs = len(nest.flatten(orig_loop_vars))
        _check_shapes_compat(
            body_graph.outputs[1:1 + num_flattened_outputs],
            nest.flatten(shape_invariants[1:1 + len_orig_loop_vars]),
            nest.flatten(loop_vars[1:1 + len_orig_loop_vars]))
        flattened_loop_vars = nest.flatten(loop_vars)
        _check_num_inputs_outputs(cond_graph, body_graph,
                                  len(flattened_loop_vars))

        outputs = gen_functional_ops._while(
            flattened_loop_vars,
            util.create_new_tf_function(cond_graph),
            util.create_new_tf_function(body_graph),
            output_shapes=[t.shape for t in body_graph.outputs],
            name=scope)

        _copy_handle_data(body_graph.outputs, outputs)
        util.maybe_set_lowering_attr(outputs[0].op)
        _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)

        # Return identities for each output of the While op, rather than the output
        # of the While op directly. This makes pruning work if the output of
        # while_loop() is fetched: the lowering pass converts the While outputs into
        # IdentityN outputs, which if fetched will cause all ops in the body to be
        # run (since it takes all exit ops as input). After lowering, each output
        # identity op will end up with only the appropriate exit op as input.
        outputs = tuple(array_ops.identity(t) for t in outputs)

    # First var is loop counter.
    outputs = _pack_sequence_as(orig_loop_vars,
                                outputs[1:1 + num_flattened_outputs])

    if return_same_structure:
        return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
        return flattened_outputs[0]
    else:
        return outputs
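
A minimal sketch, under the same assumptions as the earlier sketches, of the shape-invariant path checked by `_check_shapes_compat` above: a loop var that grows each iteration needs an explicit invariant with an unknown first dimension.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

i0 = tf.constant(0)
vals0 = tf.constant([0.0])

_, vals = tf.while_loop(
    cond=lambda i, vals: i < 4,
    body=lambda i, vals: (
        i + 1,
        tf.concat([vals, tf.reshape(tf.cast(i, tf.float32), [1])], axis=0)),
    loop_vars=(i0, vals0),
    shape_invariants=(tf.TensorShape([]), tf.TensorShape([None])))

with tf.compat.v1.Session() as sess:
  print(sess.run(vals))  # expected: [0., 0., 1., 2., 3.]
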
Example #16
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
    _check_same_outputs(_COND, [true_graph, false_graph])

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match([true_graph, false_graph],
                                     [true_inputs, false_inputs])

    # Create the If op.
    with ops.control_dependencies(
            list(true_graph.control_captures) +
            list(false_graph.control_captures)):
        true_stateful_ops = [
            op for op in true_graph.get_operations() if op._is_stateful
        ]
        false_stateful_ops = [
            op for op in false_graph.get_operations() if op._is_stateful
        ]
        # TODO(srbs): Remove this after July 22, 2019. This is required to abide by
        # 3-week forward compat window of new TF python op generating code with
        # stale runtime binaries.
        if (true_stateful_ops or false_stateful_ops
                or not compat.forward_compatible(2019, 7, 22)):
            op_fn = gen_functional_ops._if
        else:
            op_fn = gen_functional_ops.stateless_if

        tensors = op_fn(pred,
                        cond_inputs, [t.dtype for t in true_graph.outputs],
                        util.create_new_tf_function(true_graph),
                        util.create_new_tf_function(false_graph),
                        output_shapes=_get_output_shapes(
                            true_graph.outputs, false_graph.outputs),
                        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)
    util.maybe_propagate_compile_time_consts_in_xla(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors)
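
A minimal sketch, an assumption rather than part of the snippet, of the stateful/stateless selection above: because true_graph contains a stateful kernel (`tf.random.uniform`), `op_fn` resolves to `gen_functional_ops._if` even though the false branch is the one taken.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

pred = tf.constant(False)
x = tf.constant(1.0)

out = tf.cond(pred,
              lambda: x + tf.random.uniform([]),  # stateful op forces `If`
              lambda: x * 2.0)                    # pure branch

with tf.compat.v1.Session() as sess:
  print(sess.run(out))  # expected: 2.0
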
Example #17
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.get_attr(
      "_maximum_iterations") if _is_in_xla_context() else None
  parallel_iterations = op.get_attr("parallel_iterations")
  assert not _is_in_xla_context() or maximum_iterations is not None
  maximum_iterations = _validate_and_convert_to_tensor(maximum_iterations)

  grads = [_preprocess_grad(grad, body_out, while_out)
           for grad, body_out, while_out
           in zip(grads, body_graph.outputs, while_op.outputs)]

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, max_iters, *unused_args):
    return counter < max_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  _maybe_set_maximum_iterations_attr(grad_op, maximum_iterations)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
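
A minimal sketch (assumed environment as in the earlier sketches) of the `parallel_iterations` forwarding above: the gradient While emitted by `_WhileGrad` reuses the attribute read from the forward op.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

x = tf.constant(1.5)
_, y = tf.while_loop(
    cond=lambda i, v: i < 4,
    body=lambda i, v: (i + 1, v + tf.square(x)),  # y == 4 * x**2
    loop_vars=(tf.constant(0), tf.constant(0.0)),
    parallel_iterations=2)

dy_dx, = tf.gradients(y, x)  # backward While is built with parallel_iterations=2

with tf.compat.v1.Session() as sess:
  print(sess.run(dy_dx))  # expected: 8 * x == 12.0
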
Example #18
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
    """Like tf.while_loop, except emits a single While op."""
    # Keep the original loop_vars around to know which args were TensorArrays.
    orig_loop_vars = loop_vars
    # Cache its length since we use it at multiple places below.
    len_orig_loop_vars = len(orig_loop_vars)

    # Convert TensorArrays to their flow variables. These get converted back to
    # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
    # `wrapped_body` below.
    loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
    loop_vars = nest.map_structure(
        ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
    if shape_invariants is not None:
        nest.assert_same_structure(orig_loop_vars, shape_invariants)
    else:
        shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

    if not name:
        name = "while"

    with ops.name_scope(name) as scope:
        with ops.name_scope(None):
            cond_name = util.unique_fn_name(scope, "cond")
            body_name = util.unique_fn_name(scope, "body")
        maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
            maximum_iterations)
        loop_counter = constant_op.constant(
            0,
            dtype=maximum_iterations_loop_var.dtype
            if maximum_iterations is not None else None,
            name="loop_counter")
        # Add loop counter needed for computing gradients.
        loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

        shape_invariants = type(shape_invariants)(
            [tensor_shape.scalar(),
             tensor_shape.scalar()]) + shape_invariants

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = (
            ops.get_default_graph()._add_control_dependencies)

        # Build a `cond` wrapper that can handle the extra counter loop_var.
        def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            if maximum_iterations is None:
                return cond(*_pack_sequence_as(orig_loop_vars, args))
            else:
                return math_ops.logical_and(
                    loop_counter < maximum_iterations_arg,
                    cond(*_pack_sequence_as(orig_loop_vars, args)))

        # NOTE(skyewm): we set collections to the outer graph's collections for
        # compatibility with TPUEstimator.
        cond_graph = func_graph_module.func_graph_from_py_func(
            cond_name,
            wrapped_cond,
            [],  # We provide signature instead of args.
            {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileCondFuncGraph(
                cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies)

        def wrapped_body(loop_counter, maximum_iterations_arg, *args):
            """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
            # Capture the tensors already captured in cond_graph so that they appear
            # in the same order in body_graph.external_captures.
            for t in cond_graph.external_captures:
                ops.get_default_graph().capture(t)

            # Convert the flow variables in `args` to TensorArrays. `args` should
            # already have the same structure as `orig_loop_vars` but currently there
            # is no nest.zip so we call `_pack_sequence_as` which flattens both
            # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
            # and packs it into the structure of `orig_loop_vars`.
            outputs = body(*_pack_sequence_as(orig_loop_vars, args))
            if not nest.is_sequence(outputs):
                outputs = [outputs]
            # Compare the structure of input and output of body converting the
            # top-level tuples to list to be compatible with legacy while_loop.
            nest.assert_same_structure(list(outputs), list(orig_loop_vars))

            outputs = _tensor_array_to_flow(outputs)

            # TODO(srbs): Update lowering code to create _Enter nodes with
            # is_constant=True for inputs that are directly passed to outputs.
            return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

        body_graph = func_graph_module.func_graph_from_py_func(
            body_name,
            wrapped_body,
            [],  # We provide signature instead of args.
            {},
            signature=_build_signature(loop_vars, shape_invariants),
            func_graph=util.WhileBodyFuncGraph(
                body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies)
        # Add external captures of body to the list of loop vars.
        # Note that external tensors will be treated as loop invariants, i.e.,
        # the value of that tensor in each iteration is the same as it was at the
        # beginning of the loop execution.
        loop_vars = loop_vars + body_graph.external_captures
        # TODO(srbs): Update lowering code to create _Enter nodes with
        # is_constant=True for inputs that are directly passed to outputs.
        body_graph.outputs.extend(body_graph.internal_captures)

        # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
        # that it expects to receive those as arguments.
        with cond_graph.as_default():
            num_cond_captures = len(cond_graph.external_captures)
            assert (cond_graph.external_captures ==
                    body_graph.external_captures[:num_cond_captures])
            for body_capture in body_graph.external_captures[
                    num_cond_captures:]:
                assert body_capture not in cond_graph.captures
                cond_graph.capture(body_capture)

        # Make sure that the shapes of the loop outputs are compatible with the
        # shape invariants, or the shapes of the loop vars if the invariants are not
        # specified.
        num_flattened_outputs = len(nest.flatten(orig_loop_vars))
        # First var is loop counter and second var is maximum_iterations.
        first_loop_var_index = 2
        _check_shapes_compat(
            body_graph.outputs[first_loop_var_index:first_loop_var_index +
                               num_flattened_outputs],
            nest.flatten(
                shape_invariants[first_loop_var_index:first_loop_var_index +
                                 len_orig_loop_vars]),
            nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                                   len_orig_loop_vars]))
        flattened_loop_vars = nest.flatten(loop_vars)
        _check_num_inputs_outputs(cond_graph, body_graph,
                                  len(flattened_loop_vars))

        with ops.control_dependencies(
                list(cond_graph.control_captures) +
                list(body_graph.control_captures)):
            outputs = gen_functional_ops._while(
                flattened_loop_vars,
                util.create_new_tf_function(cond_graph),
                util.create_new_tf_function(body_graph),
                output_shapes=[t.shape for t in body_graph.outputs],
                parallel_iterations=parallel_iterations,
                name=scope)

        _copy_handle_data(body_graph.outputs, outputs)
        util.maybe_set_lowering_attr(outputs[0].op)
        util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

        # Return identities for each output of the While op, rather than the output
        # of the While op directly. This makes pruning work if the output of
        # while_loop() is fetched: the lowering pass converts the While outputs into
        # IdentityN outputs, which if fetched will cause all ops in the body to be
        # run (since it takes all exit ops as input). After lowering, each output
        # identity op will end up with only the appropriate exit op as input.
        outputs = tuple(array_ops.identity(t) for t in outputs)

    outputs = _pack_sequence_as(
        orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                                num_flattened_outputs])

    if return_same_structure:
        return outputs

    flattened_outputs = nest.flatten(outputs)
    if len(flattened_outputs) == 1:
        return flattened_outputs[0]
    else:
        return outputs
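
A minimal sketch, assuming the same environment, of the `maximum_iterations` loop var introduced above: `wrapped_cond` AND-s the user condition with the counter check, so the loop stops at the cap even though the user cond would keep it running.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

out = tf.while_loop(
    cond=lambda i: i < 100,   # would run 100 iterations on its own
    body=lambda i: i + 1,
    loop_vars=[tf.constant(0)],
    maximum_iterations=5)

with tf.compat.v1.Session() as sess:
  print(sess.run(out))  # expected: 5
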
Example #19
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except:  # pylint: disable=bare-except
    num_original_outputs = len(while_op.outputs)

  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  # Do not ignore grads wrt extra outputs when computing higher order
  # derivatives.
  while_op._set_attr("_num_original_outputs",
                     attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)
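
A minimal sketch, assuming the TF build supports double differentiation of While (which is what the `_num_original_outputs` attribute above is preserving): taking the gradient of a `while_loop` gradient.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

x = tf.constant(3.0)
_, y = tf.while_loop(
    cond=lambda i, v: i < 2,
    body=lambda i, v: (i + 1, v * x),  # y == x**2
    loop_vars=(tf.constant(0), tf.constant(1.0)))

dy_dx, = tf.gradients(y, x)        # first pass through _WhileGrad
d2y_dx2, = tf.gradients(dy_dx, x)  # differentiates the gradient While itself

with tf.compat.v1.Session() as sess:
  print(sess.run([dy_dx, d2y_dx2]))  # expected: [6.0, 2.0]
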
Example #20
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
    _check_same_outputs(_COND, [true_graph, false_graph])

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match([true_graph, false_graph],
                                     [true_inputs, false_inputs])
    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)
    # We do not output intermediates of the gradient If op since this is just
    # for backwards compatibility with existing code.
    if not building_gradient and util.output_all_intermediates():
        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation. Since the outputs of the two functions must
        # match, we wrap all the intermediates in optionals. Each intermediate
        # output will have a value iff its corresponding branch is taken.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Wrap intermediates in optionals.
        wrapped_true_intermediates = _wrap_intermediates(
            true_graph, true_intermediates)
        wrapped_false_intermediates = _wrap_intermediates(
            false_graph, false_intermediates)

        # Make outputs match by adding none optionals.
        extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
            [true_graph, false_graph],
            [wrapped_true_intermediates, wrapped_false_intermediates])

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)
        _check_same_outputs(_COND, [true_graph, false_graph])

    # Create the If op.
    with ops.control_dependencies(
            list(true_graph.control_captures) +
            list(false_graph.control_captures)):
        true_stateful_ops = [
            op for op in true_graph.get_operations() if op._is_stateful
        ]
        false_stateful_ops = [
            op for op in false_graph.get_operations() if op._is_stateful
        ]
        if (true_stateful_ops or false_stateful_ops):
            op_fn = gen_functional_ops._if
        else:
            op_fn = gen_functional_ops.stateless_if

        tensors = op_fn(pred,
                        cond_inputs, [t.dtype for t in true_graph.outputs],
                        util.create_new_tf_function(true_graph),
                        util.create_new_tf_function(false_graph),
                        output_shapes=_get_output_shapes(
                            true_graph.outputs, false_graph.outputs),
                        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if_op._true_graph = true_graph
    if_op._false_graph = false_graph
    util.maybe_set_lowering_attr(if_op)
    util.maybe_propagate_compile_time_consts_in_xla(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
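
A minimal sketch, under the usual assumptions, of the forward/backward pairing that the `building_gradient` flag above distinguishes: the forward If may export wrapped intermediates for reuse, while the If built for its gradient does not.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

x = tf.constant(2.0)
pred = tf.constant(True)

y = tf.cond(pred, lambda: tf.square(x), lambda: x + 1.0)
dy_dx, = tf.gradients(y, x)  # the gradient cond is the building_gradient=True case

with tf.compat.v1.Session() as sess:
  print(sess.run([y, dy_dx]))  # expected: [4.0, 4.0]
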
Example #21
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs,
                                     false_inputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return tensors
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs,
                                     false_inputs)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation. Since the outputs of the two functions must match,
    # we wrap all the intermediates in optionals. Each intermediate output will
    # have a value iff its corresponding branch is taken.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    if control_flow_util.InXlaContext(ops.get_default_graph()):
        # XLA does not yet support optionals, so output intermediates directly and
        # make them match via FakeParams, which can be converted to zeros in XLA.
        # TODO(skyewm,jpienaar): can XLA support optionals?
        extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
            true_graph, false_graph, true_intermediates, false_intermediates)
    else:
        # Wrap intermediates in optionals.
        wrapped_true_intermediates = _wrap_intermediates(
            true_graph, true_intermediates)
        wrapped_false_intermediates = _wrap_intermediates(
            false_graph, false_intermediates)

        # Make outputs match by adding none optionals.
        extra_true_outputs, extra_false_outputs = _make_intermediates_match(
            true_graph, false_graph, wrapped_true_intermediates,
            wrapped_false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): somehow indicate it's a bug if this fails.
    _check_same_outputs(true_graph, false_graph)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)

    return tensors[:num_cond_outputs]
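
A closing sketch, again an assumption rather than source code, of the input-matching step performed by `_make_inputs_match` above: the two branches capture different tensors, and the builder pads both signatures so a single If op can feed the union of inputs.

import tensorflow as tf

tf.compat.v1.enable_control_flow_v2()  # assumption

a = tf.constant(10.0)
b = tf.constant(3.0)
pred = tf.constant(False)

# true_fn captures only `a`, false_fn captures only `b`.
out = tf.cond(pred, lambda: a * 2.0, lambda: b + 1.0)

with tf.compat.v1.Session() as sess:
  print(sess.run(out))  # expected: 4.0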