Ejemplo n.º 1
0
    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Runs the user loop body and prepends the loop bookkeeping values.

      Args:
        loop_counter: Current loop counter; the returned value is this plus one.
        maximum_iterations_arg: Maximum iterations of the loop, forwarded as is.
        *args: Flat list of loop variables as seen inside the body graph.

      Returns:
        A list `[loop_counter + 1, maximum_iterations_arg]` followed by the
        top-level elements of the body's outputs after `_tensor_array_to_flow`
        has been applied to them.
      """
      # Re-capture, in the same order, every tensor cond_graph already captured
      # so that body_graph.external_captures lines up with cond_graph's.
      for captured in cond_graph.external_captures:
        ops.get_default_graph().capture(captured)

      # `args` is flat; `_pack_sequence_as` flattens `orig_loop_vars` and
      # `args`, turns flow variables in `args` back into TensorArrays and
      # rebuilds the `orig_loop_vars` structure (there is no nest.zip for this).
      structured_args = _pack_sequence_as(orig_loop_vars, args)
      body_result = body(*structured_args)
      if not nest.is_sequence_or_composite(body_result):
        body_result = [body_result]

      # Legacy while_loop treats top-level tuples and lists alike, so compare
      # the structures after turning both top levels into lists.
      nest.assert_same_structure(
          list(body_result), list(orig_loop_vars), expand_composites=True)

      flow_outputs = list(_tensor_array_to_flow(body_result))

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      bookkeeping = [loop_counter + 1, maximum_iterations_arg]
      return bookkeeping + flow_outputs
Ejemplo n.º 2
0
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
    """Match dtype of IndexedSlices.indices in outputs of branch_graphs.

    For every output that is an `IndexedSlices` in all branches, the
    `indices` tensor must be int32 or int64. When the branches disagree on
    that dtype, each int32 `indices` tensor is replaced in
    `branch_graphs[i].outputs` by an int64 cast. Every graph's
    `structured_outputs` is then re-packed from its updated `outputs`.

    None outputs are supported. They appear in `structured_outputs` but have
    no entry in `FuncGraph.outputs`, so they do not advance the output
    position. They are re-inserted before re-packing.

    Args:
      op_type: Compared with `_COND` to pick "cond" or "switch_case" for error
        messages.
      branch_graphs: Non-empty list of branch graphs whose flattened
        `structured_outputs` all have the same length.

    Raises:
      TypeError: If an output is an `IndexedSlices` in some branches but not in
        others, or if an `IndexedSlices.indices` dtype is neither int32 nor
        int64.
      ValueError: If the number of tensors implied by the structured outputs
        does not match `len(branch_graphs[0].outputs)`.
    """

    def _restore_nones(structured_outputs, flat_outputs):
        """Re-inserts None into `flat_outputs` wherever `structured_outputs` has one."""
        remaining = iter(flat_outputs)
        restored = []
        for leaf in nest.flatten(structured_outputs, expand_composites=True):
            if leaf is None:
                restored.append(None)
                continue
            tensor = next(remaining, None)
            if tensor is None:
                # Fewer outputs than non-None leaves: keep the mismatch visible
                # to pack_sequence_as instead of inventing values.
                break
            restored.append(tensor)
        # Keep any surplus outputs as well, rather than silently truncating.
        restored.extend(remaining)
        return restored

    assert branch_graphs
    # Positions of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
    indexed_slice_indices = []
    current_index = 0
    # Nones are kept here so error messages report the output index as the user
    # returned it; they are skipped when advancing `current_index` below.
    branch_outputs_flat_with_composites = [
        nest.flatten(branch_graph.structured_outputs, expand_composites=False)
        for branch_graph in branch_graphs
    ]
    outs_per_branch = [
        len(outs) for outs in branch_outputs_flat_with_composites
    ]
    assert len(set(outs_per_branch)) == 1, outs_per_branch
    # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
    for output_idx, branch_outs in enumerate(
            zip(*branch_outputs_flat_with_composites)):
        if len(set(isinstance(out, ops.IndexedSlices)
                   for out in branch_outs)) != 1:
            raise TypeError(
                "Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                "  branches returned: {outputs}".format(
                    op_name="cond" if op_type == _COND else "switch_case",
                    output_idx=output_idx,
                    outputs=branch_outs))
        if isinstance(branch_outs[0], ops.IndexedSlices):
            # indices is the second component of the composite tensor.
            indexed_slice_indices.append(current_index + 1)
        if nest.is_sequence_or_composite(branch_outs[0]):
            current_index += len(
                nest.flatten(branch_outs[0], expand_composites=True))
        elif branch_outs[0] is not None:
            # `FuncGraph.outputs` has no entry for a None output, so a None
            # must not advance the position counter.
            current_index += 1

    if not indexed_slice_indices:
        return

    # `FuncGraph.outputs` is the flattened `structured_outputs` minus the Nones.
    if current_index != len(branch_graphs[0].outputs):
        raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                         "Expected: %i\n"
                         "Actual: %i" %
                         (current_index, len(branch_graphs[0].outputs)))

    # Cast indices with mismatching types to int64.
    for index in indexed_slice_indices:
        if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
               for bg in branch_graphs):
            raise TypeError(
                "Type of IndexedSlices.indices must be int32 or int64. "
                "Found: %s" %
                str([bg.outputs[index].dtype for bg in branch_graphs]))
        if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
            for branch_graph in branch_graphs:
                if branch_graph.outputs[index].dtype == dtypes.int32:
                    with branch_graph.as_default():
                        branch_graph.outputs[index] = math_ops.cast(
                            branch_graph.outputs[index], dtypes.int64)

    # `outputs` lacks the Nones present in `structured_outputs`; put them back
    # so the flat sequence lines up with the structure being re-packed.
    for branch_graph in branch_graphs:
        branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
            branch_graph.structured_outputs,
            _restore_nones(branch_graph.structured_outputs,
                           branch_graph.outputs))
Ejemplo n.º 3
0
    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Invokes `body` on the loop variables and advances the loop counter.

      Args:
        loop_counter: Iteration counter; it is incremented by one in the result.
        maximum_iterations_arg: Iteration limit, passed through unchanged.
        *args: Flat loop variables, with TensorArrays represented by flows.

      Returns:
        `[loop_counter + 1, maximum_iterations_arg]` extended with the
        top-level elements of the body's outputs after `_tensor_array_to_flow`.
      """
      # Capture, in order, every tensor cond_graph captured so that the two
      # graphs' external_captures agree.
      for external_tensor in cond_graph.external_captures:
        ops.get_default_graph().capture(external_tensor)

      # Rebuild the `orig_loop_vars` structure (flows become TensorArrays again)
      # before calling the user body; lacking nest.zip, `_pack_sequence_as`
      # does the flatten-and-repack.
      result = body(*_pack_sequence_as(orig_loop_vars, args))
      result = result if nest.is_sequence_or_composite(result) else [result]

      # Matching legacy while_loop, top-level tuples and lists are equivalent:
      # compare list() of both sides.
      nest.assert_same_structure(list(result), list(orig_loop_vars),
                                 expand_composites=True)

      flows = _tensor_array_to_flow(result)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      next_loop_vars = [loop_counter + 1, maximum_iterations_arg]
      next_loop_vars.extend(flows)
      return next_loop_vars
Ejemplo n.º 4
0
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
  """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs.

  If an output is an `IndexedSlices` in both branches but its `indices` tensor
  is int32 in one branch and int64 in the other, the int32 one is cast to
  int64. The affected graph's `outputs` list is updated in place. Both graphs'
  `structured_outputs` are then re-packed from their `outputs`.

  None outputs are supported. They appear in `structured_outputs` but have no
  entry in `FuncGraph.outputs`, so they do not advance the output position.
  They are re-inserted before re-packing.

  Args:
    true_graph: Graph built for the true branch of the cond.
    false_graph: Graph built for the false branch of the cond.

  Raises:
    TypeError: If an output is an `IndexedSlices` in only one branch, or if an
      `IndexedSlices.indices` dtype is neither int32 nor int64.
    ValueError: If the number of tensors implied by
      `true_graph.structured_outputs` does not match `len(true_graph.outputs)`.
  """

  def _restore_nones(structured_outputs, flat_outputs):
    """Re-inserts None into `flat_outputs` wherever `structured_outputs` has one."""
    remaining = iter(flat_outputs)
    restored = []
    for leaf in nest.flatten(structured_outputs, expand_composites=True):
      if leaf is None:
        restored.append(None)
        continue
      tensor = next(remaining, None)
      if tensor is None:
        # Fewer outputs than non-None leaves: keep the mismatch visible to
        # pack_sequence_as instead of inventing values.
        break
      restored.append(tensor)
    # Keep any surplus outputs as well, rather than silently truncating.
    restored.extend(remaining)
    return restored

  # Positions of `IndexedSlices.indices` tensors in `{true|false}_graph.outputs`.
  indexed_slice_indices = []
  current_index = 0
  # Nones are kept here so error messages report the output index as the user
  # returned it; they are skipped when advancing `current_index` below.
  true_outputs_flat_with_composites = nest.flatten(
      true_graph.structured_outputs, expand_composites=False)
  false_outputs_flat_with_composites = nest.flatten(
      false_graph.structured_outputs, expand_composites=False)
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for idx, (true_out, false_out) in enumerate(
      zip(true_outputs_flat_with_composites,
          false_outputs_flat_with_composites)):
    if isinstance(true_out, ops.IndexedSlices) != isinstance(
        false_out, ops.IndexedSlices):
      raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                      "  true_fn returned:  %s\n"
                      "  false_fn returned: %s" % (idx, true_out, false_out))
    if isinstance(true_out, ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(true_out):
      current_index += len(nest.flatten(true_out, expand_composites=True))
    elif true_out is not None:
      # `FuncGraph.outputs` has no entry for a None output, so a None must not
      # advance the position counter.
      current_index += 1

  if not indexed_slice_indices:
    return

  # `FuncGraph.outputs` is the flattened `structured_outputs` minus the Nones.
  if current_index != len(true_graph.outputs):
    raise ValueError("Insufficient elements in true_graph.outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (current_index, len(true_graph.outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    # Validate the true branch first, then the false branch.
    for graph in (true_graph, false_graph):
      if graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
        raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                        "Found: %s" % str(graph.outputs[index].dtype))
    if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
      if false_graph.outputs[index].dtype == dtypes.int32:
        with false_graph.as_default():
          false_graph.outputs[index] = math_ops.cast(false_graph.outputs[index],
                                                     dtypes.int64)
      else:
        with true_graph.as_default():
          true_graph.outputs[index] = math_ops.cast(true_graph.outputs[index],
                                                    dtypes.int64)

  # `outputs` lacks the Nones present in `structured_outputs`; put them back so
  # the flat sequence lines up with the structure being re-packed.
  for graph in (true_graph, false_graph):
    graph.structured_outputs = func_graph_module.pack_sequence_as(
        graph.structured_outputs,
        _restore_nones(graph.structured_outputs, graph.outputs))
Ejemplo n.º 5
0
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Casts mismatched IndexedSlices.indices outputs of branch_graphs to int64.

  For every output that is an `IndexedSlices` in all branches, the `indices`
  tensor must be int32 or int64. When the branches disagree on that dtype,
  each int32 `indices` entry of `branch_graphs[i].outputs` is replaced by an
  int64 cast. Every graph's `structured_outputs` is then re-packed with
  `_pack_sequence_as`.

  Args:
    op_type: Compared with `_COND` to pick "cond" or "switch_case" for error
      messages.
    branch_graphs: Non-empty list of branch graphs whose flattened
      `structured_outputs` all have the same length.

  Raises:
    TypeError: If an output is an `IndexedSlices` in some branches but not in
      others, or if an `IndexedSlices.indices` dtype is neither int32 nor int64.
    ValueError: If the number of non-None tensors implied by the structured
      outputs does not match `len(branch_graphs[0].outputs)`.
  """
  assert branch_graphs
  # Nones stay in these flattened lists so that the output index reported in
  # error messages matches what the user returned; they are skipped when the
  # running position is advanced.
  flat_per_branch = [
      nest.flatten(graph.structured_outputs, expand_composites=False)
      for graph in branch_graphs
  ]
  branch_lengths = [len(flat) for flat in flat_per_branch]
  assert len(set(branch_lengths)) == 1, branch_lengths

  # Positions of `IndexedSlices.indices` tensors within `graph.outputs`.
  indices_positions = []
  position = 0
  for output_idx, outs in enumerate(zip(*flat_per_branch)):
    slice_flags = {isinstance(out, ops.IndexedSlices) for out in outs}
    if len(slice_flags) != 1:
      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                      "  branches returned: {outputs}".format(
                          op_name="cond" if op_type == _COND else "switch_case",
                          output_idx=output_idx,
                          outputs=outs))
    first = outs[0]
    if isinstance(first, ops.IndexedSlices):
      # The composite's second component is its `indices` tensor.
      indices_positions.append(position + 1)
    if nest.is_sequence_or_composite(first):
      position += len(nest.flatten(first, expand_composites=True))
    elif first is not None:
      # A None output has no slot in `FuncGraph.outputs`.
      position += 1

  if not indices_positions:
    return

  # `FuncGraph.outputs` is the flattened `structured_outputs` without Nones.
  reference_outputs = branch_graphs[0].outputs
  if position != len(reference_outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (position, len(reference_outputs)))

  allowed_dtypes = (dtypes.int32, dtypes.int64)
  for pos in indices_positions:
    found = [graph.outputs[pos].dtype for graph in branch_graphs]
    if any(dtype not in allowed_dtypes for dtype in found):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(found))
    if len(set(found)) == 1:
      # All branches already agree; nothing to cast.
      continue
    for graph in branch_graphs:
      if graph.outputs[pos].dtype == dtypes.int32:
        with graph.as_default():
          graph.outputs[pos] = math_ops.cast(graph.outputs[pos], dtypes.int64)

  for graph in branch_graphs:
    graph.structured_outputs = _pack_sequence_as(graph.structured_outputs,
                                                 graph.outputs)
Ejemplo n.º 6
0
def _make_indexed_slices_indices_types_match(true_graph, false_graph):
    """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs.

    If an output is an `IndexedSlices` in both branches but its `indices`
    tensor is int32 in one branch and int64 in the other, the int32 one is
    cast to int64. The affected graph's `outputs` list is updated in place.
    Both graphs' `structured_outputs` are then re-packed from their `outputs`.

    None outputs are supported. They appear in `structured_outputs` but have
    no entry in `FuncGraph.outputs`, so they do not advance the output
    position. They are re-inserted before re-packing.

    Args:
      true_graph: Graph built for the true branch of the cond.
      false_graph: Graph built for the false branch of the cond.

    Raises:
      TypeError: If an output is an `IndexedSlices` in only one branch, or if
        an `IndexedSlices.indices` dtype is neither int32 nor int64.
      ValueError: If the number of tensors implied by
        `true_graph.structured_outputs` does not match
        `len(true_graph.outputs)`.
    """

    def _restore_nones(structured_outputs, flat_outputs):
        """Re-inserts None into `flat_outputs` wherever `structured_outputs` has one."""
        remaining = iter(flat_outputs)
        restored = []
        for leaf in nest.flatten(structured_outputs, expand_composites=True):
            if leaf is None:
                restored.append(None)
                continue
            tensor = next(remaining, None)
            if tensor is None:
                # Fewer outputs than non-None leaves: keep the mismatch visible
                # to pack_sequence_as instead of inventing values.
                break
            restored.append(tensor)
        # Keep any surplus outputs as well, rather than silently truncating.
        restored.extend(remaining)
        return restored

    # Positions of `IndexedSlices.indices` tensors in `{true|false}_graph.outputs`.
    indexed_slice_indices = []
    current_index = 0
    # Nones are kept here so error messages report the output index as the user
    # returned it; they are skipped when advancing `current_index` below.
    true_outputs_flat_with_composites = nest.flatten(
        true_graph.structured_outputs, expand_composites=False)
    false_outputs_flat_with_composites = nest.flatten(
        false_graph.structured_outputs, expand_composites=False)
    # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
    for idx, (true_out, false_out) in enumerate(
            zip(true_outputs_flat_with_composites,
                false_outputs_flat_with_composites)):
        if isinstance(true_out, ops.IndexedSlices) != isinstance(
                false_out, ops.IndexedSlices):
            raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                            "  true_fn returned:  %s\n"
                            "  false_fn returned: %s" %
                            (idx, true_out, false_out))
        if isinstance(true_out, ops.IndexedSlices):
            # indices is the second component of the composite tensor.
            indexed_slice_indices.append(current_index + 1)
        if nest.is_sequence_or_composite(true_out):
            current_index += len(nest.flatten(true_out,
                                              expand_composites=True))
        elif true_out is not None:
            # `FuncGraph.outputs` has no entry for a None output, so a None
            # must not advance the position counter.
            current_index += 1

    if not indexed_slice_indices:
        return

    # `FuncGraph.outputs` is the flattened `structured_outputs` minus the Nones.
    if current_index != len(true_graph.outputs):
        raise ValueError("Insufficient elements in true_graph.outputs.\n"
                         "Expected: %i\n"
                         "Actual: %i" %
                         (current_index, len(true_graph.outputs)))

    # Cast indices with mismatching types to int64.
    for index in indexed_slice_indices:
        # Validate the true branch first, then the false branch.
        for graph in (true_graph, false_graph):
            if graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
                raise TypeError(
                    "Type of IndexedSlices.indices must be int32 or int64. "
                    "Found: %s" % str(graph.outputs[index].dtype))
        if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
            if false_graph.outputs[index].dtype == dtypes.int32:
                with false_graph.as_default():
                    false_graph.outputs[index] = math_ops.cast(
                        false_graph.outputs[index], dtypes.int64)
            else:
                with true_graph.as_default():
                    true_graph.outputs[index] = math_ops.cast(
                        true_graph.outputs[index], dtypes.int64)

    # `outputs` lacks the Nones present in `structured_outputs`; put them back
    # so the flat sequence lines up with the structure being re-packed.
    for graph in (true_graph, false_graph):
        graph.structured_outputs = func_graph_module.pack_sequence_as(
            graph.structured_outputs,
            _restore_nones(graph.structured_outputs, graph.outputs))