Example #1
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
  """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs."""
  indexed_slice_indices = []
  current_index = 0
  true_outputs_flat_with_composites = nest.flatten(
      true_graph.structured_outputs, expand_composites=False)
  false_outputs_flat_with_composites = nest.flatten(
      false_graph.structured_outputs, expand_composites=False)
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for idx, (true_out, false_out) in enumerate(
      zip(true_outputs_flat_with_composites,
          false_outputs_flat_with_composites)):
    if isinstance(true_out, ops.IndexedSlices) != isinstance(
        false_out, ops.IndexedSlices):
      raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                      "  true_fn returned:  %s\n"
                      "  false_fn returned: %s" % (idx, true_out, false_out))
    if isinstance(true_out, ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(true_out):
      current_index += len(nest.flatten(true_out, expand_composites=True))
    else:
      current_index += 1

  if not indexed_slice_indices:
    return

  if current_index != len(true_graph.outputs):
    raise ValueError("Insufficient elements in true_graph.outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (current_index, len(true_graph.outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(true_graph.outputs[index].dtype))
    if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(false_graph.outputs[index].dtype))
    if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
      if false_graph.outputs[index].dtype == dtypes.int32:
        with false_graph.as_default():
          false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index],
                                                     dtypes.int64)
      else:
        with true_graph.as_default():
          true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index],
                                                    dtypes.int64)

  true_graph.structured_outputs = func_graph_module.pack_sequence_as(
      true_graph.structured_outputs, true_graph.outputs)
  false_graph.structured_outputs = func_graph_module.pack_sequence_as(
      false_graph.structured_outputs, false_graph.outputs)
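
For context, a minimal sketch (using the public tf.IndexedSlices and tf.cast APIs rather than the internal helper above, with made-up values) of the mismatch this function reconciles: one branch yields int32 indices, the other int64, and casting the int32 side to int64 makes the two output signatures agree.

import tensorflow as tf

# One branch produces int32 indices, the other int64.
slices_true = tf.IndexedSlices(values=tf.ones([2, 3]),
                               indices=tf.constant([0, 1], dtype=tf.int32))
slices_false = tf.IndexedSlices(values=tf.ones([2, 3]),
                                indices=tf.constant([0, 1], dtype=tf.int64))

# Casting the int32 indices to int64 gives both branches the same signature.
matched_indices = tf.cast(slices_true.indices, tf.int64)
assert matched_indices.dtype == slices_false.indices.dtype
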
Example #2
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
    """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
    assert branch_graphs
    indexed_slice_indices = []
    current_index = 0
    branch_outputs_flat_with_composites = [
        nest.flatten(branch_graph.structured_outputs, expand_composites=False)
        for branch_graph in branch_graphs
    ]
    outs_per_branch = [
        len(outs) for outs in branch_outputs_flat_with_composites
    ]
    assert len(set(outs_per_branch)) == 1, outs_per_branch
    # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
    for output_idx, branch_outs in enumerate(
            zip(*branch_outputs_flat_with_composites)):
        if len(set(isinstance(out, ops.IndexedSlices)
                   for out in branch_outs)) != 1:
            raise TypeError(
                "Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                "  branches returned: {outputs}".format(
                    op_name="cond" if op_type == _COND else "switch_case",
                    output_idx=output_idx,
                    outputs=branch_outs))
        if isinstance(branch_outs[0], ops.IndexedSlices):
            # indices is the second component of the composite tensor.
            indexed_slice_indices.append(current_index + 1)
        if nest.is_sequence_or_composite(branch_outs[0]):
            current_index += len(
                nest.flatten(branch_outs[0], expand_composites=True))
        else:
            current_index += 1

    if not indexed_slice_indices:
        return

    if current_index != len(branch_graphs[0].outputs):
        raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                         "Expected: %i\n"
                         "Actual: %i" %
                         (current_index, len(branch_graphs[0].outputs)))

    # Cast indices with mismatching types to int64.
    for index in indexed_slice_indices:
        if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
               for bg in branch_graphs):
            raise TypeError(
                "Type of IndexedSlices.indices must be int32 or int64. "
                "Found: %s" %
                str([bg.outputs[index].dtype for bg in branch_graphs]))
        if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
            for branch_graph in branch_graphs:
                if branch_graph.outputs[index].dtype == dtypes.int32:
                    with branch_graph.as_default():
                        branch_graph.outputs[index] = math_ops.cast(
                            branch_graph.outputs[index], dtypes.int64)

    for branch_graph in branch_graphs:
        branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
            branch_graph.structured_outputs, branch_graph.outputs)
Example #3
def _pack_sequence_as(structured_outputs, op_outputs):
  """Packs the outputs of the gradient If/Case op.

  The branch functions may contain `None`s in the list of `structured_outputs`.
  `op_outputs` has those outputs missing. So we need to add those Nones to the
  list of `op_outputs` and then pack it in the same structure as
  `structured_outputs`.

  Args:
    structured_outputs: structured_outputs from one of the branch functions.
    op_outputs: List of output tensors of the op.

  Returns:
    `op_outputs` packed like `structured_outputs`.
  """
  outputs_with_nones = []
  counter = 0
  for output in nest.flatten(structured_outputs, expand_composites=True):
    if output is None:
      outputs_with_nones.append(None)
    else:
      outputs_with_nones.append(op_outputs[counter])
      counter += 1
  return func_graph_module.pack_sequence_as(structured_outputs,
                                            outputs_with_nones)
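
A small illustration of the same `None` re-insertion idea using the public tf.nest API (the internal func_graph_module.pack_sequence_as serves the same packing role above); the dict structure and values here are made up for the example.

import tensorflow as tf

structured_outputs = {"dx": tf.constant(0.0), "dy": None}  # branch structure
op_outputs = [tf.constant(1.0)]  # the op's output list drops the None entry

outputs_with_nones, counter = [], 0
for output in tf.nest.flatten(structured_outputs):
    if output is None:
        outputs_with_nones.append(None)
    else:
        outputs_with_nones.append(op_outputs[counter])
        counter += 1

packed = tf.nest.pack_sequence_as(structured_outputs, outputs_with_nones)
# packed is {"dx": <tf.Tensor 1.0>, "dy": None}, matching the original structure.
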
Example #4
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
    """Creates an `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
    _check_same_outputs(_CASE, branch_graphs)

    # Add inputs to branch_graphs to make them match. Note that this modifies the
    # graphs in `branch_graphs`.
    case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

    # Create the Case op.
    with ops.control_dependencies(
            sum((list(bg.control_captures) for bg in branch_graphs), [])):
        tensors = gen_functional_ops.case(
            branch_index,
            case_inputs, [t.dtype for t in branch_graphs[0].outputs],
            [util.create_new_tf_function(g) for g in branch_graphs],
            output_shapes=_get_output_shapes(
                *[g.outputs for g in branch_graphs]),
            name=name)

    # TODO(b/110167197): this requires Case to have at least 1 output
    case_op = tensors[0].op
    util.maybe_set_lowering_attr(case_op)
    util.maybe_propagate_compile_time_consts_in_xla(case_op)

    # Return identities for each output of the Case op, rather than the output of
    # the Case op directly. This makes pruning work if the output of switch_case()
    # is fetched: the lowering pass converts the Case outputs into IdentityN
    # outputs, which if fetched will cause all ops in the taken branch to be run
    # (since it takes all merge ops as input). After lowering, each output
    # identity op will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    case_op.graph.prevent_fetching(case_op)
    return func_graph_module.pack_sequence_as(
        branch_graphs[0].structured_outputs, tensors)
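
For reference, a hedged usage sketch of the public entry point that ends up building a Case op like the one assembled above (tf.switch_case; the branch values here are arbitrary).

import tensorflow as tf

branch_index = tf.constant(1)
result = tf.switch_case(
    branch_index,
    branch_fns={0: lambda: tf.constant(10),
                1: lambda: tf.constant(20),
                2: lambda: tf.constant(30)})
# result evaluates to 20 because branch 1 is selected.
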
Example #5
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph,
                                   true_inputs, false_inputs)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors)
Example #6
def cond_v2(pred, true_fn, false_fn, name="cond"):
    """Like tf.cond, except emits a single If op."""
    if isinstance(pred, bool):
        raise TypeError("pred must not be a Python bool", pred)

    if not name:
        name = "cond"

    with ops.name_scope(name) as scope:
        true_name = util.unique_fn_name(scope, "true")
        false_name = util.unique_fn_name(scope, "false")

        # Automatic control dependencies are added in defuns, but not in v1
        # graphs. Propagate that behavior here.
        add_control_dependencies = ops.get_default_graph(
        )._add_control_dependencies
        pred = ops.convert_to_tensor(pred)

        true_graph = func_graph_module.func_graph_from_py_func(
            true_name,
            true_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)
        false_graph = func_graph_module.func_graph_from_py_func(
            false_name,
            false_fn,
            [],
            {},
            func_graph=util.CondBranchFuncGraph(
                false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
            add_control_dependencies=add_control_dependencies,
            op_return_value=pred)

        outputs = _build_cond(pred,
                              true_graph,
                              false_graph,
                              true_graph.external_captures,
                              false_graph.external_captures,
                              name=scope)

        return func_graph_module.pack_sequence_as(
            true_graph.structured_outputs, outputs)
Example #7
  def func_wrapper(*args):
    args = _convert_to_list(args)
    with ops.name_scope(name) as scope:
      func_graph, captured_args = _compile_function(
          func, args, scope, [], allow_external_captures=True)

      with ops.control_dependencies(list(func_graph.control_captures)):
        outputs = gen_functional_ops.function(
            captured_args,
            to_apply=util.create_new_tf_function(func_graph),
            Tout=func_graph.output_types,
            output_shapes=func_graph.output_shapes)

        # pack_sequence_as requires a list of Tensors, but the gen_ operation
        # returns an Operation under some circumstances (probably when that
        # list would be empty)
        if isinstance(outputs, ops.Operation):
          outputs = outputs.outputs

      return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                                outputs)
Example #8
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    outputs = _build_cond(pred, true_graph, false_graph,
                          true_graph.external_captures,
                          false_graph.external_captures,
                          name=scope)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              outputs)
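
As a usage note, a minimal sketch of calling the public tf.cond API, which in v2-style graphs dispatches to cond_v2 and therefore emits a single If op (the values here are arbitrary).

import tensorflow as tf

x = tf.constant(3.0)
result = tf.cond(x > 0.0,
                 true_fn=lambda: x * 2.0,
                 false_fn=lambda: -x)
# result evaluates to 6.0 because the predicate is True.
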
Example #9
    def multi_conv_wrapper(*args):
      inner_options = options if options else {}

      if not isinstance(inner_options, dict):
        raise TypeError(
            "Expected the multi_conv `options` to be a `dict`, but got %s "
            "instead." % (str(inner_options)))

      option_proto = option_flag_pb2.PoplarOptionFlags()
      for key, value in inner_options.items():
        flag = option_proto.flags.add()
        flag.option = key
        flag.value = value

      def func_wrapper(*args):
        with op_util.gradient_override_scope(training=False):
          return inner_func(*args)

      args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
      with ops.name_scope("multi_conv") as scope:
        func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
            func_wrapper,
            args,
            scope, [],
            allow_external_captures=True)

        with ops.control_dependencies(list(func_graph.control_captures)):
          outputs = gen_functional_ops.multi_conv(
              captured_args,
              to_apply=util.create_new_tf_function(func_graph),
              Tout=func_graph.output_types,
              output_shapes=func_graph.output_shapes,
              option_flags=json_format.MessageToJson(option_proto))

      return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                                outputs)
Example #10
def _pipeline_stage(func,
                    stage_id,
                    device_id,
                    args,
                    training,
                    infeed_queue=None,
                    outfeed_queue=None,
                    name=None):
  """Internal function for compiling a pipeline stage. This should not be called
  directly and doing so will result in undefined behaviour.

  Creates a pipeline stage.

  Args:
    func: function which will be executed as a stage.
    stage_id: Stage number.
    device_id: IPU the stage will be mapped to.
    args: arguments to the function.
    training: whether this stage is used for training; passed to the gradient
      override scope when compiling the stage function.
    infeed_queue: optional IPUInfeedQueue, if passed, it is dequeued as part of
      this function.
    outfeed_queue: optional IPUOutfeedQueue, if passed, it is enqueued as part
      of this function.
    name: name of this pipeline stage.

  Returns:
    The values after executing func(args), or the control dependency if
    outfeed_queue is not None.

  """
  name = name if name else "pipeline_stage"
  args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access

  func_to_compile = func
  control_outputs = []
  # If we have an infeed, then we wrap the function in another function which
  # dequeues the infeed.
  if infeed_queue:

    def infeed_func_wrapper(*args):
      args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
      dequeue_ops = functional_ops._convert_to_list(infeed_queue._dequeue())  # pylint: disable=protected-access
      # Deal with the dequeue depending on whether it's a list or dict.
      if len(dequeue_ops) == 1 and isinstance(dequeue_ops[0], dict):
        kwargs = dequeue_ops[0]
        return func(*(args), **kwargs)
      return func(*(args + dequeue_ops))

    func_to_compile = infeed_func_wrapper

  # If we have an outfeed, then we wrap the function in another function which
  # enqueues the outfeed.
  if outfeed_queue:
    func = func_to_compile

    def outfeed_func_wrapper(*args, **kwargs):
      outputs = func(*args, **kwargs)
      # Check if there are output tensors - if there are then enqueue them.
      if not isinstance(outputs, ops.Operation):
        if not isinstance(outputs, dict):
          outputs = functional_ops._convert_to_list(outputs)  # pylint: disable=protected-access
        outputs = outfeed_queue.enqueue(outputs)
      control_outputs.append(outputs)

    func_to_compile = outfeed_func_wrapper

  def gradient_override_wrapper(*args, **kwargs):
    with op_util.gradient_override_scope(training):
      return func_to_compile(*args, **kwargs)

  with ops.name_scope(name) as scope:
    # pylint: disable=protected-access
    try:
      func_graph, captured_args = functional_ops._compile_function(
          gradient_override_wrapper, args, scope, control_outputs)
    except functional_ops._InvalidCaptureException as e:
      raise ValueError(
          "Trying to capture the tensor %s which is not a resource. This tensor"
          " needs to be passed as either part of the `input` or `infeed_queue`"
          " of the pipeline." % (str(e)))
    # pylint: enable=protected-access

    # Create the pipeline stage and lower the function into XLA.
    with ops.control_dependencies(list(func_graph.control_captures)):
      with scopes.ipu_shard(device_id):
        outputs = gen_functional_ops.pipeline_stage(
            captured_args,
            to_apply=util.create_new_tf_function(func_graph),
            Tout=func_graph.output_types,
            output_shapes=func_graph.output_shapes,
            stage_id=stage_id)
    if isinstance(outputs, ops.Operation):
      return outputs
    return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                              outputs)
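
A simplified, framework-free sketch of the wrapping pattern used above; `infeed` and `outfeed` here are hypothetical stand-ins with dequeue/enqueue methods, not the IPU queue classes.

def wrap_stage(func, infeed=None, outfeed=None):
  """Wraps `func` so infeed values become extra args and outputs are enqueued."""
  def stage(*args):
    if infeed is not None:
      # Dequeued values are appended to the positional arguments.
      args = tuple(args) + tuple(infeed.dequeue())
    outputs = func(*args)
    if outfeed is not None:
      # Enqueue the stage outputs instead of returning them directly.
      return outfeed.enqueue(outputs)
    return outputs
  return stage
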
Example #11
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    util.maybe_set_lowering_attr(tensors[0].op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
Example #12
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
    _check_same_outputs(_COND, [true_graph, false_graph])

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match([true_graph, false_graph],
                                     [true_inputs, false_inputs])
    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)
    # We do not output intermediates of the gradient If op since this is just
    # for backwards compatibility with existing code.
    if not building_gradient and util.output_all_intermediates():
        # Add all intermediate tensors as function outputs so they're available for
        # the gradient computation. Since the outputs of the two functions must
        # match, we wrap all the intermediates in optionals. Each intermediate
        # output will have a value iff its corresponding branch is taken.

        true_intermediates = _get_intermediates(true_graph)
        false_intermediates = _get_intermediates(false_graph)

        # Wrap intermediates in optionals.
        wrapped_true_intermediates = _wrap_intermediates(
            true_graph, true_intermediates)
        wrapped_false_intermediates = _wrap_intermediates(
            false_graph, false_intermediates)

        # Make outputs match by adding none optionals.
        extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
            [true_graph, false_graph],
            [wrapped_true_intermediates, wrapped_false_intermediates])

        true_graph.outputs.extend(extra_true_outputs)
        false_graph.outputs.extend(extra_false_outputs)
        _check_same_outputs(_COND, [true_graph, false_graph])

    # Create the If op.
    with ops.control_dependencies(
            list(true_graph.control_captures) +
            list(false_graph.control_captures)):
        true_stateful_ops = [
            op for op in true_graph.get_operations() if op._is_stateful
        ]
        false_stateful_ops = [
            op for op in false_graph.get_operations() if op._is_stateful
        ]
        if (true_stateful_ops or false_stateful_ops):
            op_fn = gen_functional_ops._if
        else:
            op_fn = gen_functional_ops.stateless_if

        tensors = op_fn(pred,
                        cond_inputs, [t.dtype for t in true_graph.outputs],
                        util.create_new_tf_function(true_graph),
                        util.create_new_tf_function(false_graph),
                        output_shapes=_get_output_shapes(
                            true_graph.outputs, false_graph.outputs),
                        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    if_op._true_graph = true_graph
    if_op._false_graph = false_graph
    util.maybe_set_lowering_attr(if_op)
    util.maybe_propagate_compile_time_consts_in_xla(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])
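
A hedged sketch of the optional-wrapping idea using the public tf.experimental.Optional API (the internal _wrap_intermediates/_make_intermediates_match helpers are not shown here): the taken branch returns a value-carrying optional and the other branch an empty optional with the same element spec, so both branch signatures still match.

import tensorflow as tf

intermediate = tf.constant([1.0, 2.0])

# The branch that computes the intermediate wraps it in an optional...
taken = tf.experimental.Optional.from_value(intermediate)
# ...while the other branch emits an empty optional of the same spec.
not_taken = tf.experimental.Optional.empty(taken.element_spec)

# Both are variant tensors with identical element specs, so the If op's
# output signature is the same regardless of which branch runs.
assert taken.element_spec == not_taken.element_spec
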
Example #13
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
    _check_same_outputs(_COND, [true_graph, false_graph])

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match([true_graph, false_graph],
                                     [true_inputs, false_inputs])

    # Create the If op.
    with ops.control_dependencies(
            list(true_graph.control_captures) +
            list(false_graph.control_captures)):
        true_stateful_ops = [
            op for op in true_graph.get_operations() if op._is_stateful
        ]
        false_stateful_ops = [
            op for op in false_graph.get_operations() if op._is_stateful
        ]
        # TODO(srbs): Remove this after July 22, 2019. This is required to abide by
        # 3-week forward compat window of new TF python op generating code with
        # stale runtime binaries.
        if (true_stateful_ops or false_stateful_ops
                or not compat.forward_compatible(2019, 7, 22)):
            op_fn = gen_functional_ops._if
        else:
            op_fn = gen_functional_ops.stateless_if

        tensors = op_fn(pred,
                        cond_inputs, [t.dtype for t in true_graph.outputs],
                        util.create_new_tf_function(true_graph),
                        util.create_new_tf_function(false_graph),
                        output_shapes=_get_output_shapes(
                            true_graph.outputs, false_graph.outputs),
                        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)
    util.maybe_propagate_compile_time_consts_in_xla(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors)
Example #14
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
    """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs."""
    indexed_slice_indices = []
    current_index = 0
    true_outputs_flat_with_composites = nest.flatten(
        true_graph.structured_outputs, expand_composites=False)
    false_outputs_flat_with_composites = nest.flatten(
        false_graph.structured_outputs, expand_composites=False)
    # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
    for idx, (true_out, false_out) in enumerate(
            zip(true_outputs_flat_with_composites,
                false_outputs_flat_with_composites)):
        if isinstance(true_out, ops.IndexedSlices) != isinstance(
                false_out, ops.IndexedSlices):
            raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                            "  true_fn returned:  %s\n"
                            "  false_fn returned: %s" %
                            (idx, true_out, false_out))
        if isinstance(true_out, ops.IndexedSlices):
            # indices is the second component of the composite tensor.
            indexed_slice_indices.append(current_index + 1)
        if nest.is_sequence_or_composite(true_out):
            current_index += len(nest.flatten(true_out,
                                              expand_composites=True))
        else:
            current_index += 1

    if not indexed_slice_indices:
        return

    if current_index != len(true_graph.outputs):
        raise ValueError("Insufficient elements in true_graph.outputs.\n"
                         "Expected: %i\n"
                         "Actual: %i" %
                         (current_index, len(true_graph.outputs)))

    # Cast indices with mismatching types to int64.
    for index in indexed_slice_indices:
        if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
            raise TypeError(
                "Type of IndexedSlices.indices must be int32 or int64. "
                "Found: %s" % str(true_graph.outputs[index].dtype))
        if false_graph.outputs[index].dtype not in (dtypes.int32,
                                                    dtypes.int64):
            raise TypeError(
                "Type of IndexedSlices.indices must be int32 or int64. "
                "Found: %s" % str(false_graph.outputs[index].dtype))
        if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
            if false_graph.outputs[index].dtype == dtypes.int32:
                with false_graph.as_default():
                    false_graph.outputs[index] = math_ops.cast(
                        false_graph.outputs[index], dtypes.int64)
            else:
                with true_graph.as_default():
                    true_graph.outputs[index] = math_ops.cast(
                        true_graph.outputs[index], dtypes.int64)

    true_graph.structured_outputs = func_graph_module.pack_sequence_as(
        true_graph.structured_outputs, true_graph.outputs)
    false_graph.structured_outputs = func_graph_module.pack_sequence_as(
        false_graph.structured_outputs, false_graph.outputs)
Example #15
def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                name=None):
    """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs,
                                     false_inputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
    if_op = tensors[0].op
    util.maybe_set_lowering_attr(if_op)

    # Return identities for each output of the If op, rather than the output of
    # the If op directly. This makes pruning work if the output of cond() is
    # fetched: the lowering pass converts the If outputs into IdentityN outputs,
    # which if fetched will cause all ops in the taken branch to be run (since
    # it takes all merge ops as input). After lowering, each output identity op
    # will end up with only the appropriate merge op as input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
    # correct output structure
    tensors = [array_ops.identity(t) for t in tensors]

    # Prevent fetching since the variant outputs can't be fetched directly.
    if_op.graph.prevent_fetching(if_op)
    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors)