Example #1
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
    _check_same_outputs(_COND, [true_graph, false_graph])

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match([true_graph, false_graph],
                                     [true_inputs, false_inputs])
    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)
    # We do not output intermediates of the gradient If op since this is just
    # for backwards compatibility with existing code.
    if not building_gradient and util.output_all_intermediates():
        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation. Since the outputs of the two functions must
        # match, we wrap all the intermediates in optionals. Each intermediate
        # output will have a value iff its corresponding branch is taken.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Wrap intermediates in optionals.
        wrapped_true_intermediates = _wrap_intermediates(
            true_graph, true_intermediates)
        wrapped_false_intermediates = _wrap_intermediates(
            false_graph, false_intermediates)

        # Make outputs match by adding none (i.e. valueless) optionals.
        extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
            [true_graph, false_graph],
            [wrapped_true_intermediates, wrapped_false_intermediates])

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)
        _check_same_outputs(_COND, [true_graph, false_graph])

    # Create the If op.
    with ops.control_dependencies(
            list(true_graph.control_captures) +
            list(false_graph.control_captures)):
        true_stateful_ops = [
            op for op in true_graph.get_operations() if op._is_stateful
        ]
        false_stateful_ops = [
            op for op in false_graph.get_operations() if op._is_stateful
        ]
        if (true_stateful_ops or false_stateful_ops):
            op_fn = gen_functional_ops._if
        else:
            op_fn = gen_functional_ops.stateless_if

        tensors = op_fn(pred,
                        cond_inputs, [t.dtype for t in true_graph.outputs],
                        util.create_new_tf_function(true_graph),
                        util.create_new_tf_function(false_graph),
                        output_shapes=_get_output_shapes(
                            true_graph.outputs, false_graph.outputs),
                        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if_op._true_graph = true_graph
    if_op._false_graph = false_graph
    util.maybe_set_lowering_attr(if_op)
    util.maybe_propagate_compile_time_consts_in_xla(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
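As context for Example #1: _build_cond is the internal builder behind tf.cond in TF 2.x graph mode (cond_v2). The following is a minimal, hedged usage sketch of the public entry point that ultimately reaches this builder; the name clipped_double and the constants are illustrative, not part of the original source.

# Illustrative sketch only: public tf.cond usage that, when traced into a
# graph, is lowered to the If/StatelessIf op assembled by _build_cond.
import tensorflow as tf

@tf.function
def clipped_double(x):
  # Each branch lambda is traced into a FuncGraph (true_graph / false_graph);
  # tensors the branches capture become true_inputs / false_inputs.
  return tf.cond(x > 0.0, lambda: 2.0 * x, lambda: tf.zeros_like(x))

print(clipped_double(tf.constant(3.0)))   # tf.Tensor(6.0, ...)
print(clipped_double(tf.constant(-1.0)))  # tf.Tensor(0.0, ...)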
Example #2
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it in multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tensor(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structures of the body's inputs and outputs, converting the
      # top-level tuples to lists for compatibility with the legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      cond_graph_captures = object_identity.ObjectIdentitySet(
          cond_graph.external_captures)
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph_captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      cond_stateful_ops = [
          op for op in cond_graph.get_operations() if op._is_stateful
      ]
      body_stateful_ops = [
          op for op in body_graph.get_operations() if op._is_stateful
      ]
      if (cond_stateful_ops or body_stateful_ops):
        op_fn = gen_functional_ops._while
      else:
        op_fn = gen_functional_ops.stateless_while

      outputs = op_fn(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope)
      # This is needed so we do not compute derivatives w.r.t. these extra outputs.
      outputs[0].op._set_attr("_num_original_outputs",
                              attr_value_pb2.AttrValue(i=num_original_outputs))

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

  outputs = _pack_sequence_as(
      orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                              num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs
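For orientation, here is a minimal, hedged sketch of the public tf.while_loop call that this while_v2 implementation serves in TF 2.x graph mode; sum_first_n and the constants are illustrative, not part of the original source.

# Illustrative sketch only: cond and body are traced into the cond_graph /
# body_graph FuncGraphs above, and maximum_iterations feeds the extra
# loop-counter and max-iterations loop vars prepended in while_loop().
import tensorflow as tf

@tf.function
def sum_first_n(n):
  i = tf.constant(0)
  total = tf.constant(0)
  i, total = tf.while_loop(
      cond=lambda i, total: i < n,
      body=lambda i, total: (i + 1, total + i),
      loop_vars=(i, total),
      maximum_iterations=1000)
  return total

print(sum_first_n(tf.constant(5)))  # tf.Tensor(10, ...) == 0 + 1 + 2 + 3 + 4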
Example #3
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  _check_same_outputs(_COND, [true_graph, false_graph])

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match([true_graph, false_graph],
                                   [true_inputs, false_inputs])
  # We do not output intermediates of the gradient If op since this is just
  # for backwards compatibility with existing code.
  if not building_gradient and util.output_all_intermediates():
    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation. Since the outputs of the two functions must
    # match, we wrap all the intermediates in optionals. Each intermediate
    # output will have a value iff its corresponding branch is taken.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding none (i.e. valueless) optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
        [true_graph, false_graph],
        [wrapped_true_intermediates, wrapped_false_intermediates])

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    _check_same_outputs(_COND, [true_graph, false_graph])

  # Create the If op.
  with ops.control_dependencies(
      list(true_graph.control_captures) + list(false_graph.control_captures)):
    true_stateful_ops = [
        op for op in true_graph.get_operations() if op._is_stateful
    ]
    false_stateful_ops = [
        op for op in false_graph.get_operations() if op._is_stateful
    ]
    if (true_stateful_ops or false_stateful_ops):
      op_fn = gen_functional_ops._if
    else:
      op_fn = gen_functional_ops.stateless_if

    def _make_op(inputs):
      if_op, tensors = util.get_op_and_outputs(op_fn(
          pred,
          inputs, [t.dtype for t in true_graph.outputs],
          util.create_new_tf_function(true_graph),
          util.create_new_tf_function(false_graph),
          output_shapes=_get_output_shapes(true_graph.outputs,
                                           false_graph.outputs),
          name=name))
      _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
      # `if_op` is None if this is a `StatelessIf` op with no outputs.
      if if_op is not None:
        # The true and false graphs have already been created, and we need that
        # to happen before we know which tensors will be captured and so whether
        # to wrap the cond in a tf.function. Post-hoc mutation of the branch
        # `outer_graph` properties seems like the only option if we want to
        # conditionally wrap in a function.
        true_graph.outer_graph = ops.get_default_graph()
        false_graph.outer_graph = ops.get_default_graph()
        if_op._true_graph = true_graph
        if_op._false_graph = false_graph
        util.maybe_set_lowering_attr(if_op)
        util.maybe_propagate_compile_time_consts_in_xla(if_op)
        _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
        # Prevent fetching since the variant outputs can't be fetched directly.
        if_op.graph.prevent_fetching(if_op)
      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(true_graph.structured_outputs, tensors)
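Relative to Example #1, this revision routes op construction through util.run_as_function_for_tape_gradients so that a GradientTape can differentiate through the conditional. Below is a small, hedged sketch of that usage pattern, assuming TF 2.x where cond_v2 backs tf.cond inside tf.function graphs; branchy and the values are illustrative only.

# Illustrative sketch only: gradient flow through a traced conditional, the
# case the tape-gradient wrapping in _make_op above is meant to support.
import tensorflow as tf

x = tf.Variable(2.0)

@tf.function
def branchy(v):
  # Lowered to an If/StatelessIf op by _build_cond during tracing.
  return tf.cond(v > 0.0, lambda: v * v, lambda: -v)

with tf.GradientTape() as tape:
  y = branchy(x)
print(tape.gradient(y, x))  # tf.Tensor(4.0, ...) since d(v*v)/dv = 2v at v = 2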