Exemplo n.º 1
0
def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None, stop_at_ts=(), control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the backward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the generators of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_inputs: if True, control inputs will be used while moving
        backward.
    Returns:
      A Python list of all the tf.Operation behind seed_ops (an empty list if
      seed_ops is empty).
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # Mirror of `result` used only for O(1) membership tests; `result` itself
    # stays a list so the returned type and ordering are unchanged.
    visited = set(result)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                # Do not walk through tensors the caller asked to stop at.
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
Exemplo n.º 2
0
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         control_outputs=True):
  """Walk the graph forward from the seeds and collect every op reached.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within
      the whole graph.
    control_outputs: an object convertible to a control output dictionary
      (see function util.convert_to_control_outputs for more details).
      If the dictionary can be created, it will be used while walking the graph
      forward.
  Returns:
    A Python set of all the tf.Operation ahead of seed_ops.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return set()
  if isinstance(seed_ops[0], tf_ops.Tensor):
    # Tensors were given: start from the ops consuming them.
    seed_ops = get_consuming_ops(
        util.make_list_of_t(seed_ops, allow_graph=False))
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  control_outputs = util.convert_to_control_outputs(seed_ops, control_outputs)

  seed_ops = frozenset(seed_ops)
  if within_ops:
    within_ops = frozenset(
        util.make_list_of_op(within_ops, allow_graph=False))
    seed_ops &= within_ops

  def _allowed(candidate):
    # No restriction when within_ops is None.
    if within_ops is None:
      return True
    return candidate in within_ops

  visited = set(seed_ops)
  frontier = set(seed_ops)
  while frontier:
    discovered = set()
    for current in frontier:
      # Data successors first, then control successors (if enabled).
      successors = [consumer
                    for out_t in current.outputs
                    for consumer in out_t.consumers()]
      if control_outputs is not None and current in control_outputs:
        successors.extend(control_outputs[current])
      for successor in successors:
        if successor not in visited and _allowed(successor):
          discovered.add(successor)
    visited.update(discovered)
    frontier = discovered
  if not inclusive:
    visited.difference_update(seed_ops)
  return visited
Exemplo n.º 3
0
def get_within_boundary_ops(ops,
                            seed_ops,
                            boundary_ops=(),
                            inclusive=True,
                            control_inputs=False,
                            control_outputs=None,
                            control_ios=None):
  """Return all the `tf.Operation` within the given boundary.

  The walk starts at the seeds and repeatedly expands to every connected op
  (as reported by `get_ops_ios`), never expanding past a boundary op.

  Args:
    ops: an object convertible to a list of `tf.Operation`. Those ops define
      the set in which to perform the operation (if a `tf.Graph` is given, it
      will be converted to the list of all its operations).
    seed_ops: the operations from which to start expanding.
    boundary_ops: the ops forming the boundary.
    inclusive: if `True`, the result will also include the boundary ops.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of `util.ControlOutputs` or `None`. If not
      `None`, control outputs are enabled.
    control_ios:  An instance of `util.ControlOutputs` or `None`. If not
      `None`, both control inputs and control outputs are enabled. This is
      equivalent to set control_inputs to True and control_outputs to
      the `util.ControlOutputs` instance.
  Returns:
    A list of the `tf.Operation` from `ops` that were reached, in the order
    in which they appear in `ops`.
  Raises:
    TypeError: if `ops` or `seed_ops` cannot be converted to a list of
      `tf.Operation`.
    ValueError: if the boundary is intersecting with the seeds.
  """
  control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
                                               control_ios)
  ops = util.make_list_of_op(ops)
  seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
  boundary_ops = set(util.make_list_of_op(boundary_ops))
  reached = set(seed_ops)
  if boundary_ops & reached:
    raise ValueError("Boundary is intersecting with the seeds.")
  frontier = set(seed_ops)
  while frontier:
    next_frontier = set()
    for neighbor in get_ops_ios(frontier, control_inputs, control_outputs):
      if neighbor in reached:
        continue
      if neighbor not in boundary_ops:
        next_frontier.add(neighbor)
      elif inclusive:
        # Boundary ops are recorded but never expanded further.
        reached.add(neighbor)
    reached.update(next_frontier)
    frontier = next_frontier
  # Filter through `ops` so the result follows its ordering.
  return [op for op in ops if op in reached]
Exemplo n.º 4
0
  def __init__(self, inside_ops=(), passthrough_ts=()):
    """Create a subgraph containing the given ops and the "passthrough" tensors.

    Args:
      inside_ops: an object convertible to a list of tf.Operation. This list
        defines all the operations in the subgraph.
      passthrough_ts: an object convertible to a list of tf.Tensor. This list
        defines all the "passthrough" tensors. A passthrough tensor is a tensor
        which goes directly from the input of the subgraph to its output,
        without any intermediate operations. Tensors which are already an
        input, output or inside tensor of `inside_ops` are silently ignored.
    Raises:
      TypeError: if inside_ops cannot be converted to a list of tf.Operation or
        if passthrough_ts cannot be converted to a list of tf.Tensor.
    """
    inside_ops = util.make_list_of_op(inside_ops)
    passthrough_ts = util.make_list_of_t(passthrough_ts)
    members = inside_ops + passthrough_ts
    # An empty subgraph has no graph attached to it.
    self._graph = util.get_unique_graph(members) if members else None
    self._ops = inside_ops

    # Classify the tensors touching the inside ops.
    inputs, outputs, insides = select.compute_boundary_ts(inside_ops)

    # Keep only the genuine passthrough tensors, i.e. those which are not
    # already connected to the inside ops.
    connected_ts = frozenset(inputs + outputs + list(insides))
    passthrough = []
    for t in passthrough_ts:
      if t in connected_ts:
        continue
      passthrough.append(t)
    self._passthrough_ts = passthrough

    # Passthrough tensors are both inputs and outputs of the subgraph.
    self._input_ts = inputs + passthrough
    self._output_ts = outputs + passthrough
Exemplo n.º 5
0
def get_ops_ios(ops, control_inputs=False, control_outputs=None,
                control_ios=None):
  """Return all the tf.Operation which are connected to an op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios:  An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to set
      control_inputs to True and control_outputs to the util.ControlOutputs
      instance.
  Returns:
    A list of the tf.Operation producing an input of, or consuming an output
    of, one of the given ops (plus control neighbors when enabled).
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
                                               control_ios)
  ops = util.make_list_of_op(ops)
  connected = []
  for op in ops:
    # Producers of the op's inputs.
    producers = [in_t.op for in_t in op.inputs]
    util.concatenate_unique(connected, producers)
    # Consumers of the op's outputs.
    for out_t in op.outputs:
      util.concatenate_unique(connected, out_t.consumers())
    # Optional control edges, in both directions.
    if control_outputs is not None:
      util.concatenate_unique(connected, control_outputs.get(op))
    if control_inputs:
      util.concatenate_unique(connected, op.control_inputs)
  return connected
Exemplo n.º 6
0
def get_within_boundary_ops(ops,
                            seed_ops,
                            boundary_ops,
                            inclusive=True,
                            control_outputs=True):
  """Return all the tf.Operation within the given boundary.

  The walk starts at the seeds and repeatedly expands to every connected op
  (as reported by get_ops_ios), never expanding past a boundary op.

  Args:
    ops: an object convertible to a list of tf.Operation. Those ops define the
      set in which to perform the operation (if a tf.Graph is given, it
      will be converted to the list of all its operations).
    seed_ops: the operations from which to start expanding.
    boundary_ops: the ops forming the boundary.
    inclusive: if True, the result will also include the boundary ops.
    control_outputs: an object convertible to a control output dictionary
      (or None). If the dictionary can be created, it will be used while
      expanding.
  Returns:
    A Python set of the tf.Operation reached from the seeds.
  Raises:
    TypeError: if ops or seed_ops cannot be converted to a list of tf.Operation.
    ValueError: if the boundary is intersecting with the seeds.
  """
  ops = util.make_list_of_op(ops)
  control_outputs = util.convert_to_control_outputs(ops, control_outputs)
  seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
  boundary_ops = set(util.make_list_of_op(boundary_ops))
  reached = set(seed_ops)
  if boundary_ops & reached:
    raise ValueError("Boundary is intersecting with the seeds.")
  frontier = set(seed_ops)
  while frontier:
    next_frontier = set()
    for neighbor in get_ops_ios(frontier, control_outputs):
      if neighbor in reached:
        continue
      if neighbor not in boundary_ops:
        next_frontier.add(neighbor)
      elif inclusive:
        # Boundary ops are recorded but never expanded further.
        reached.add(neighbor)
    reached.update(next_frontier)
    frontier = next_frontier
  return reached
Exemplo n.º 7
0
def compute_boundary_ts(ops):
  """Compute the tensors at the boundary of a set of ops.

  This function looks at all the tensors connected to the given ops (in/out)
  and classifies them into three categories:
  1) input tensors: tensors whose generating operation is not in ops.
  2) output tensors: tensors whose consumer operations are not in ops
  3) inside tensors: tensors which are neither input nor output tensors.

  Note that a tensor can be both an inside tensor and an output tensor if it is
  consumed by operations both outside and inside of `ops`.

  Args:
    ops: an object convertible to a list of tf.Operation.
  Returns:
    A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
      `outside_input_ts` is a Python list of input tensors;
      `outside_output_ts` is a python list of output tensors;
      `inside_ts` is a python list of inside tensors.
    Since a tensor can be both an inside tensor and an output tensor,
    `outside_output_ts` and `inside_ts` might intersect.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  input_ts = _get_input_ts(ops)
  output_ts = _get_output_ts(ops)
  produced = frozenset(output_ts)
  members = frozenset(ops)

  # A tensor is "inside" when it is both consumed and produced by `ops`.
  inside_ts = [t for t in input_ts if t in produced]
  # It is "only inside" when, in addition, every consumer belongs to `ops`.
  only_inside_ts = [t for t in inside_ts
                    if frozenset(t.consumers()) <= members]

  inside_lookup = frozenset(inside_ts)
  only_inside_lookup = frozenset(only_inside_ts)
  outside_input_ts = [t for t in input_ts if t not in inside_lookup]
  outside_output_ts = [t for t in output_ts if t not in only_inside_lookup]
  return outside_input_ts, outside_output_ts, inside_ts
Exemplo n.º 8
0
def _get_output_ts(ops):
  """Collect the output tensors of all the op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
  Returns:
    A list with the output tensors of every op in ops, in op order. No
    de-duplication is performed here.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  return [out_t for op in ops for out_t in op.outputs]
Exemplo n.º 9
0
def get_output_ts(ops):
  """Compute the set of output tensors of all the op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
  Returns:
    A Python set holding the output tensors of every op in ops.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  return {out_t for op in ops for out_t in op.outputs}
Exemplo n.º 10
0
def filter_ops(ops, positive_filter=None):
  """Get the ops passing the given filter.

  Args:
    ops: an object convertible to a list of tf.Operation.
    positive_filter: a function deciding whether to keep an operation or not.
      If None, all the operations are returned.
  Returns:
    A list of selected tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  if positive_filter is None:
    return ops
  return [op for op in ops if positive_filter(op)]
Exemplo n.º 11
0
def filter_ops(ops, positive_filter):
  """Get the ops passing the given filter.

  Args:
    ops: an object convertible to a list of tf.Operation.
    positive_filter: a function deciding whether to keep an operation or not.
      If True (the literal object), all the operations are returned.
  Returns:
    A list of selected tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  # Identity check on purpose: only the literal True disables filtering.
  if positive_filter is True:  # pylint: disable=g-explicit-bool-comparison
    return ops
  return [op for op in ops if positive_filter(op)]
Exemplo n.º 12
0
def filter_ts_from_regex(ops, regex):
  r"""Get all the tensors linked to ops that match the given regex.

  Args:
    ops: an object convertible to a list of tf.Operation.
    regex: a regular expression matching the tensors' name.
      For example, "^foo(/.*)?:\d+$" will match all the tensors in the "foo"
      scope.
  Returns:
    The result of filter_ts restricted to the tensors whose name matches.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  pattern = make_regex(regex)

  def _name_matches(t):
    # `search` is used, so the pattern may match anywhere in the name.
    return pattern.search(t.name)

  return filter_ts(ops, positive_filter=_name_matches)
Exemplo n.º 13
0
def filter_ops_from_regex(ops, regex):
  """Get all the operations that match the given regex.

  Args:
    ops: an object convertible to a list of `tf.Operation`.
    regex: a regular expression matching the operation's name.
      For example, `"^foo(/.*)?$"` will match all the operations in the "foo"
      scope.
  Returns:
    A list of `tf.Operation`.
  Raises:
    TypeError: if ops cannot be converted to a list of `tf.Operation`.
  """
  ops = util.make_list_of_op(ops)
  pattern = make_regex(regex)

  def _name_matches(op):
    # `search` is used, so the pattern may match anywhere in the name.
    return pattern.search(op.name)

  return filter_ops(ops, _name_matches)
Exemplo n.º 14
0
def filter_ops(ops, positive_filter):
  """Select the operations accepted by a predicate.

  Args:
    ops: an object convertible to a list of tf.Operation.
    positive_filter: a function deciding whether to keep an operation or not.
      If True (the literal object), all the operations are returned.
  Returns:
    A list of selected tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  candidates = util.make_list_of_op(ops)
  # Identity check on purpose: only the literal True disables filtering.
  keep_all = positive_filter is True  # pylint: disable=g-explicit-bool-comparison
  if keep_all:
    return candidates
  return [op for op in candidates if positive_filter(op)]
Exemplo n.º 15
0
def filter_ops_from_regex(ops, regex):
  """Get all the operations that match the given regex.

  Args:
    ops: an object convertible to a list of tf.Operation.
    regex: a regular expression matching the operation's name.
      For example, "^foo(/.*)?$" will match all the operations in the "foo"
      scope.
  Returns:
    A list of tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  pattern = _make_regex(regex)

  def _name_matches(op):
    # `search` is used, so the pattern may match anywhere in the name.
    return pattern.search(op.name)

  return filter_ops(ops, _name_matches)
Exemplo n.º 16
0
def compute_boundary_ts(ops, keep_order=False, ambiguous_are_outputs=True):
    """Compute the tensors at the boundary of a set of ops.

    This function looks at all the tensors connected to the given ops (in/out)
    and classifies them into three categories:
    1) input tensors: tensors whose generating operation is not in ops.
    2) output tensors: tensors whose consumer operations are not in ops
    3) inside tensors: tensors which are neither input nor output tensors.

    Args:
      ops: an object convertible to a list of tf.Operation.
      keep_order: if True, the input and output tensors are returned as lists
        instead of sets.
        NOTE(review): the lists are built by iterating the collections
        returned by get_input_ts / get_output_ts, which are combined with
        set operators here, so the order may not follow `ops` - confirm.
      ambiguous_are_outputs: a tensor can have consumers both inside and
        outside ops. Such tensors are treated as outside tensors if
        ambiguous_are_outputs is True, otherwise they are treated as inside
        tensors.
    Returns:
      A Python set (list if keep_order is True) of input tensors.
      A Python set (list if keep_order is True) of output tensors.
      A Python set of inside tensors.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    ops = util.make_list_of_op(ops)
    # `ops` is a list: use a set for the consumer membership tests below so
    # each test is O(1) rather than a scan of the whole list.
    ops_set = frozenset(ops)
    input_tensors = get_input_ts(ops)
    output_tensors = get_output_ts(ops)
    inside_tensors = input_tensors & output_tensors
    # deal with ambiguous tensors
    if ambiguous_are_outputs:
        # Inside tensors which also have at least one consumer outside ops.
        inside_and_output_tensors = set()
        for t in inside_tensors:
            if any(op not in ops_set for op in t.consumers()):
                inside_and_output_tensors.add(t)
        output_tensors |= inside_and_output_tensors
        inside_tensors -= inside_and_output_tensors
    outside_input_tensors = input_tensors - inside_tensors
    outside_output_tensors = output_tensors - inside_tensors
    if keep_order:
        outside_input_tensors = [
            t for t in input_tensors if t in outside_input_tensors
        ]
        outside_output_tensors = [
            t for t in output_tensors if t in outside_output_tensors
        ]
    return outside_input_tensors, outside_output_tensors, inside_tensors
Exemplo n.º 17
0
def _get_output_ts(ops):
  """Compute the list of unique output tensors of all the op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
  Returns:
    The list of unique output tensors of all the op in ops, in order of first
    appearance.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  ts = []
  # Companion set for O(1) duplicate detection; `ts` keeps the ordering.
  seen = set()
  for op in ops:
    for t in op.outputs:
      if t not in seen:
        ts.append(t)
        seen.add(t)
  return ts
Exemplo n.º 18
0
def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True):
    """Compute the tensors at the boundary of a set of ops.

    This function looks at all the tensors connected to the given ops (in/out)
    and classifies them into three categories:
    1) input tensors: tensors whose generating operation is not in ops.
    2) output tensors: tensors whose consumer operations are not in ops
    3) inside tensors: tensors which are neither input nor output tensors.

    Args:
      ops: an object convertible to a list of tf.Operation.
      ambiguous_ts_are_outputs: a tensor can have consumers both inside and
        outside ops. Such tensors are treated as outside tensors if
        ambiguous_ts_are_outputs is True, otherwise they are treated as
        inside tensors.
    Returns:
      A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
        `outside_input_ts` is a Python list of input tensors;
        `outside_output_ts` is a python list of output tensors;
        `inside_ts` is a python list of inside tensors.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    ops = util.make_list_of_op(ops)
    input_ts = _get_input_ts(ops)
    output_ts = _get_output_ts(ops)
    produced = frozenset(output_ts)
    members = frozenset(ops)

    def _is_inside(t):
        # Must be produced by one of the ops as well as consumed by one.
        if t not in produced:
            return False
        # When ambiguous tensors count as outputs, any consumer outside of
        # ops disqualifies the tensor from being "inside".
        if ambiguous_ts_are_outputs and frozenset(t.consumers()) - members:
            return False
        return True

    inside_ts = [t for t in input_ts if _is_inside(t)]

    inside_lookup = frozenset(inside_ts)
    outside_input_ts = [t for t in input_ts if t not in inside_lookup]
    outside_output_ts = [t for t in output_ts if t not in inside_lookup]
    return outside_input_ts, outside_output_ts, inside_ts
Exemplo n.º 19
0
def compute_boundary_ts(ops, ambiguous_ts_are_outputs=True):
  """Classify the tensors touching `ops` into inputs, outputs and insides.

  This function looks at all the tensors connected to the given ops (in/out)
  and classifies them into three categories:
  1) input tensors: tensors whose generating operation is not in ops.
  2) output tensors: tensors whose consumer operations are not in ops
  3) inside tensors: tensors which are neither input nor output tensors.

  Args:
    ops: an object convertible to a list of tf.Operation.
    ambiguous_ts_are_outputs: a tensor can have consumers both inside and
      outside ops. Such tensors are treated as outside tensors if
      ambiguous_ts_are_outputs is True, otherwise they are treated as
      inside tensors.
  Returns:
    A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
      `outside_input_ts` is a Python list of input tensors;
      `outside_output_ts` is a python list of output tensors;
      `inside_ts` is a python list of inside tensors.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  consumed_ts = _get_input_ts(ops)
  produced_ts = _get_output_ts(ops)
  produced_lookup = frozenset(produced_ts)
  op_lookup = frozenset(ops)

  def _escapes(t):
    # True when at least one consumer of t lies outside of ops.
    return bool(frozenset(t.consumers()) - op_lookup)

  inside_ts = []
  for t in consumed_ts:
    if t in produced_lookup:
      # Ambiguous tensors only stay "inside" when they are not to be
      # reported as outputs.
      if not (ambiguous_ts_are_outputs and _escapes(t)):
        inside_ts.append(t)

  inside_lookup = frozenset(inside_ts)
  outside_input_ts = [t for t in consumed_ts if t not in inside_lookup]
  outside_output_ts = [t for t in produced_ts if t not in inside_lookup]
  return outside_input_ts, outside_output_ts, inside_ts
Exemplo n.º 20
0
def compute_boundary_ts(ops, keep_order=False, ambiguous_are_outputs=True):
  """Compute the tensors at the boundary of a set of ops.

  This function looks at all the tensors connected to the given ops (in/out)
  and classifies them into three categories:
  1) input tensors: tensors whose generating operation is not in ops.
  2) output tensors: tensors whose consumer operations are not in ops
  3) inside tensors: tensors which are neither input nor output tensors.

  Args:
    ops: an object convertible to a list of tf.Operation.
    keep_order: if True, the input and output tensors are returned as lists
      instead of sets.
      NOTE(review): the lists are built by iterating the collections returned
      by get_input_ts / get_output_ts, which are combined with set operators
      here, so the order may not follow `ops` - confirm.
    ambiguous_are_outputs: a tensor can have consumers both inside and outside
      ops. Such tensors are treated as outside tensors if
      ambiguous_are_outputs is True, otherwise they are treated as inside
      tensors.
  Returns:
    A Python set (list if keep_order is True) of input tensors.
    A Python set (list if keep_order is True) of output tensors.
    A Python set of inside tensors.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  # `ops` is a list: use a set for the consumer membership tests below so
  # each test is O(1) rather than a scan of the whole list.
  ops_set = frozenset(ops)
  input_tensors = get_input_ts(ops)
  output_tensors = get_output_ts(ops)
  inside_tensors = input_tensors & output_tensors
  # deal with ambiguous tensors
  if ambiguous_are_outputs:
    # Inside tensors which also have at least one consumer outside of ops.
    inside_and_output_tensors = set()
    for t in inside_tensors:
      if any(op not in ops_set for op in t.consumers()):
        inside_and_output_tensors.add(t)
    output_tensors |= inside_and_output_tensors
    inside_tensors -= inside_and_output_tensors
  outside_input_tensors = input_tensors - inside_tensors
  outside_output_tensors = output_tensors - inside_tensors
  if keep_order:
    outside_input_tensors = [t for t in input_tensors
                             if t in outside_input_tensors]
    outside_output_tensors = [t for t in output_tensors
                              if t in outside_output_tensors]
  return outside_input_tensors, outside_output_tensors, inside_tensors
Exemplo n.º 21
0
def filter_ts(ops, positive_filter=None):
    """Get all the tensors which are input or output of an op in ops.

    Args:
      ops: an object convertible to a list of tf.Operation.
      positive_filter: a function deciding whether to keep a tensor or not.
        If None, no filtering is applied.
    Returns:
      A list of the tf.Tensor accepted by positive_filter, or the whole Python
      set of connected tensors when positive_filter is None.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    ops = util.make_list_of_op(ops)
    connected_ts = set()
    for group in (get_input_ts(ops), get_output_ts(ops)):
        connected_ts.update(group)
    if positive_filter is None:
        return connected_ts
    return [t for t in connected_ts if positive_filter(t)]
Exemplo n.º 22
0
def filter_ts(ops, positive_filter):
  """Get all the tensors which are input or output of an op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
    positive_filter: a function deciding whether to keep a tensor or not.
      If True (the literal object), all the tensors are returned.
  Returns:
    A list of tf.Tensor.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  # Inputs first, then the outputs not already listed.
  connected_ts = _get_input_ts(ops)
  util.concatenate_unique(connected_ts, _get_output_ts(ops))
  # Identity check on purpose: only the literal True disables filtering.
  if positive_filter is True:
    return connected_ts
  return [t for t in connected_ts if positive_filter(t)]
Exemplo n.º 23
0
def filter_ts(ops, positive_filter=None):
  """Get all the tensors which are input or output of an op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
    positive_filter: a function deciding whether to keep a tensor or not.
      If None, no filtering is applied.
  Returns:
    A list of the tf.Tensor accepted by positive_filter, or the whole Python
    set of connected tensors when positive_filter is None.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  connected_ts = set()
  for group in (get_input_ts(ops), get_output_ts(ops)):
    connected_ts.update(group)
  if positive_filter is None:
    return connected_ts
  return [t for t in connected_ts if positive_filter(t)]
Exemplo n.º 24
0
def _get_input_ts(ops):
  """Compute the list of unique input tensors of all the op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
  Returns:
    The list of unique input tensors of all the op in ops, in order of first
    appearance.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  unique_ts = []
  # Companion set giving constant-time duplicate detection.
  seen = set()
  for op in ops:
    for in_t in op.inputs:
      if in_t in seen:
        continue
      seen.add(in_t)
      unique_ts.append(in_t)
  return unique_ts
Exemplo n.º 25
0
def add_control_inputs(op, cops):
    """Add the control inputs cops to op.

    Warning: this function is directly manipulating the internals of the
    tf.Graph.

    Args:
      op: a tf.Operation to which the control inputs are added.
      cops: an object convertible to a list of `tf.Operation`.
    Raises:
      TypeError: if op is not a tf.Operation
      ValueError: if any cop in cops is already a control input of op.
    """
    if not isinstance(op, _tf_ops.Operation):
        # The message must be formatted explicitly: passing the type as a
        # second exception argument would leave the "{}" placeholder unfilled.
        raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
    cops = _util.make_list_of_op(cops, allow_graph=False)
    # Validate everything first so that op is left untouched on error.
    for cop in cops:
        if cop in op.control_inputs:
            raise ValueError("{} is already a control_input of {}".format(
                cop.name, op.name))
    op._add_control_inputs(cops)  # pylint: disable=protected-access
Exemplo n.º 26
0
def add_control_inputs(op, cops):
  """Add the control inputs cops to op.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    op: a tf.Operation to which the control inputs are added.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is already a control input of op.
  """
  if not isinstance(op, _tf_ops.Operation):
    # The message must be formatted explicitly: passing the type as a second
    # exception argument would leave the "{}" placeholder unfilled.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = _util.make_list_of_op(cops, allow_graph=False)
  # Validate everything first so that op is left untouched on error.
  for cop in cops:
    if cop in op.control_inputs:
      raise ValueError("{} is already a control_input of {}".format(cop.name,
                                                                    op.name))
  op._add_control_inputs(cops)  # pylint: disable=protected-access
Exemplo n.º 27
0
    def __init__(self, inside_ops=(), passthrough_ts=()):
        """Create a subgraph containing the given ops and the "passthrough" tensors.

        Args:
          inside_ops: an object convertible to a list of `tf.Operation`. This
            list defines all the operations in the subgraph.
          passthrough_ts: an object convertible to a list of `tf.Tensor`. This
            list defines all the "passthrough" tensors. A passthrough tensor is
            a tensor which goes directly from the input of the subgraph to its
            output, without any intermediate operations. Tensors which are
            already an input, output or inside tensor of `inside_ops` are
            silently ignored.
        Raises:
          TypeError: if inside_ops cannot be converted to a list of
            `tf.Operation` or if `passthrough_ts` cannot be converted to a list
            of `tf.Tensor`.
        """
        inside_ops = util.make_list_of_op(inside_ops)
        passthrough_ts = util.make_list_of_t(passthrough_ts)
        members = inside_ops + passthrough_ts

        # Empty subgraph: no graph attached and every collection is empty.
        if not members:
            self._graph = None
            self._passthrough_ts = []
            self._input_ts = []
            self._output_ts = []
            self._ops = []
            return

        self._graph = util.get_unique_graph(members)
        self._ops = inside_ops

        # Classify the tensors touching the inside ops.
        inputs, outputs, insides = select.compute_boundary_ts(inside_ops)

        # Keep only the genuine passthrough tensors, i.e. those which are not
        # already connected to the inside ops.
        connected_ts = frozenset(inputs + outputs + list(insides))
        passthrough = []
        for t in passthrough_ts:
            if t in connected_ts:
                continue
            passthrough.append(t)
        self._passthrough_ts = passthrough

        # Passthrough tensors are both inputs and outputs of the subgraph.
        self._input_ts = inputs + passthrough
        self._output_ts = outputs + passthrough
Exemplo n.º 28
0
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    op: a tf.Operation from which to remove the control inputs.
    cops: an object convertible to a list of tf.Operation.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, tf_ops.Operation):
    # The message must actually be formatted; passing type(op) as a second
    # exception argument would leave the "{}" placeholder unfilled.
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = util.make_list_of_op(cops, allow_graph=False)
  # Validate everything first so that op is left untouched on error.
  for cop in cops:
    if cop not in op.control_inputs:
      # Name the offending control input first, then the op it was expected on.
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  # pylint: disable=protected-access
  op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
  op._recompute_node_def()
Exemplo n.º 29
0
def remove_control_inputs(op, cops):
    """Remove the control inputs cops from op.

    Warning: this function is directly manipulating the internals of the
    `tf.Graph`.

    Args:
      op: a `tf.Operation` from which to remove the control inputs.
      cops: an object convertible to a list of `tf.Operation`.
    Raises:
      TypeError: if op is not a `tf.Operation`.
      ValueError: if any cop in cops is not a control input of op.
    """
    if not isinstance(op, tf_ops.Operation):
        # Format the message explicitly: passing type(op) as an extra
        # exception argument would leave the "{}" placeholder unfilled.
        raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
    cops = util.make_list_of_op(cops, allow_graph=False)
    # Validate every cop before mutating op, so op is unchanged on error.
    for cop in cops:
        if cop not in op.control_inputs:
            # The control input comes first, the op it was expected on second.
            raise ValueError("{} is not a control_input of {}".format(
                cop.name, op.name))
    # pylint: disable=protected-access
    op._control_inputs = [cop for cop in op._control_inputs if cop not in cops]
    op._recompute_node_def()
Exemplo n.º 30
0
def get_ops_ios(ops, control_outputs=True):
    """Collect the operations directly connected to any op in ops.

    For each op, the result gathers the ops generating its input tensors and
    the consumers of its output tensors. When a control output dictionary is
    available, the ops it lists for that op are gathered as well.

    Args:
      ops: an object convertible to a list of tf.Operation.
      control_outputs: an object convertible to a control output dictionary
        (or None). If the dictionary can be created, it will be used to
        determine the surrounding ops (in addition to the regular inputs and
        outputs).
    Returns:
      A Python set of the tf.Operation surrounding the given ops.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    control_outputs = util.convert_to_control_outputs(ops, control_outputs)
    ops = util.make_list_of_op(ops)
    neighbours = set()
    for current in ops:
        # Producers of the tensors consumed by the current op.
        neighbours.update(tensor.op for tensor in current.inputs)
        # Consumers of the tensors produced by the current op.
        for tensor in current.outputs:
            neighbours.update(tensor.consumers())
        # Ops tied to the current op through a control dependency, if known.
        if control_outputs is None or current not in control_outputs:
            continue
        neighbours.update(control_outputs[current])
    return neighbours
Exemplo n.º 31
0
def get_ops_ios(ops, control_outputs=True):
  """Gather every tf.Operation directly connected to an op in ops.

  Args:
    ops: an object convertible to a list of tf.Operation.
    control_outputs: an object convertible to a control output dictionary
      (or None). If the dictionary can be created, it will be used to determine
      the surrounding ops (in addition to the regular inputs and outputs).
  Returns:
    A Python set holding the generators of the inputs of the given ops, the
    consumers of their outputs and, when a control output dictionary is
    available, the ops that dictionary lists for them.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  control_outputs = util.convert_to_control_outputs(ops, control_outputs)
  ops = util.make_list_of_op(ops)
  use_control_outputs = control_outputs is not None
  surrounding = set()
  for op in ops:
    # Upstream: the ops generating the inputs of op.
    surrounding |= {in_t.op for in_t in op.inputs}
    # Downstream: the ops consuming the outputs of op.
    for out_t in op.outputs:
      surrounding |= set(out_t.consumers())
    # Control dependencies recorded for op, if any.
    if use_control_outputs and op in control_outputs:
      surrounding |= set(control_outputs[op])
  return surrounding
Exemplo n.º 32
0
def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
    """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1.

    This function is the back-bone of the Graph-Editor. It is essentially a
    thin wrapper on top of the tf.Operation._update_input.

    Given a pair of tensors t0, t1 in ts0 x ts1, this function re-routes the
    end of t0 and t1 in one of three possible ways:
    1) The reroute mode is "a<->b" or "b<->a": the tensors' ends are swapped.
    After this operation, the previous consumers of t0 are now consumers of t1
    and vice-versa.
    2) The reroute mode is "a->b": the tensors' end of t0 are re-routed to the
    tensors' end of t1 (which are left dangling). After this operation, the
    previous consumers of t0 are still consuming t0 but the previous consumers
    of t1 are now also consuming t0. The tensor t1 has no consumer.
    3) The reroute mode is "b->a": this mode is the symmetric of the "a->b"
    mode.

    Note that this function is re-routing the end of two tensors, not the
    start. Re-routing the start of two tensors is not supported by this
    library. The reason for that is the following: TensorFlow, by design,
    creates a strong bond between an op and its output tensor. This Graph
    editor follows this design and treats an operation A and its generating
    tensors {t_i} as an entity which cannot be broken. In other words, an op
    cannot be detached from any of its output tensors, ever. But it is
    possible to detach an op from its input tensors, which is what this
    function concerns itself with.

    Warning: this function is directly manipulating the internals of the
    tf.Graph.

    Args:
      ts0: an object convertible to a list of `tf.Tensor`.
      ts1: an object convertible to a list of `tf.Tensor`.
      mode: what to do with those tensors: "a<->b" or "b<->a" for swapping and
        "a->b" or "b->a" for one direction re-routing.
      can_modify: iterable of operations which can be modified. Any operation
        outside can_modify will be left untouched by this function.
      cannot_modify: iterable of operations which cannot be modified.
        Any operation within cannot_modify will be left untouched by this
        function.
    Returns:
      The number of individual modifications made by the function (the sum of
      the values returned by `_reroute_t`).
    Raises:
      TypeError: if `ts0` or `ts1` cannot be converted to a list of
        `tf.Tensor`.
      TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot
        be converted to a list of `tf.Operation`.
    """
    a2b, b2a = _RerouteMode.check(mode)
    ts0 = _util.make_list_of_t(ts0)
    ts1 = _util.make_list_of_t(ts1)
    _check_ts_compatibility(ts0, ts1)
    if cannot_modify is not None:
        cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
    if can_modify is not None:
        can_modify = frozenset(_util.make_list_of_op(can_modify))

    # Snapshot the consumers of every pair before touching the graph: a tensor
    # may appear several times, and rerouting changes who consumes what.
    snapshots = [(set(t0.consumers()), set(t1.consumers()))
                 for t0, t1 in zip(ts0, ts1)]

    update_count = 0
    for (t0, t1), (old_consumers0, old_consumers1) in zip(
            zip(ts0, ts1), snapshots):
        if t0 is t1:
            # Identical tensors are silently ignored.
            continue
        if a2b:
            update_count += _reroute_t(t0, t1, old_consumers1, can_modify,
                                       cannot_modify)
        if b2a:
            update_count += _reroute_t(t1, t0, old_consumers0, can_modify,
                                       cannot_modify)
    return update_count
Exemplo n.º 33
0
def select_ops(*args, **kwargs):
    """Helper to select operations.

    Args:
      *args: list of 1) regular expressions (compiled or not) or  2) (array
        of) tf.Operation. tf.Tensor instances are silently ignored. A regular
        expression whose pattern starts with "(?#ts)" is skipped.
      **kwargs: 'graph': tf.Graph in which to perform the regex query. This is
        required when using regex.
        'positive_filter': an elem is selected only if positive_filter(elem)
          is True. This is optional.
        'restrict_regex': if True, a regular expression is ignored if it
          doesn't start with the substring "(?#ops)".
    Returns:
      list of tf.Operation, without repeating an op already selected by a
      previous argument.
    Raises:
      TypeError: if the optional keyword argument graph is not a tf.Graph
        or if an argument in args is not an (array of) tf.Operation
        or an (array of) tf.Tensor (silently ignored) or a string
        or a regular expression.
      ValueError: if one of the keyword arguments is unexpected or if a
        regular expression is used without passing a graph as a keyword
        argument.
    """
    # get keywords arguments
    graph = None
    positive_filter = None
    restrict_regex = False
    # dict.items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    for k, v in kwargs.items():
        if k == "graph":
            graph = v
            if graph is not None and not isinstance(graph, tf_ops.Graph):
                raise TypeError("Expected a tf.Graph, got: {}".format(
                    type(graph)))
        elif k == "positive_filter":
            positive_filter = v
        elif k == "restrict_regex":
            restrict_regex = v
        else:
            raise ValueError("Wrong keywords argument: {}.".format(k))

    ops = []

    for arg in args:
        if _can_be_regex(arg):
            if graph is None:
                raise ValueError(
                    "Use the keyword argument 'graph' to use regex.")
            regex = _make_regex(arg)
            # Patterns tagged "(?#ts)" target tensors, not operations.
            if regex.pattern.startswith("(?#ts)"):
                continue
            if restrict_regex and not regex.pattern.startswith("(?#ops)"):
                continue
            for op_ in filter_ops_from_regex(graph, regex):
                if op_ in ops:
                    continue
                if positive_filter is None or positive_filter(op_):
                    ops.append(op_)
        else:
            ops_aux = util.make_list_of_op(arg, ignore_ts=True)
            if positive_filter is not None:
                ops_aux = [op for op in ops_aux if positive_filter(op)]
            ops_aux = [op for op in ops_aux if op not in ops]
            ops += ops_aux

    return ops
Exemplo n.º 34
0
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         control_outputs=True):
    """Walk the graph forward from seed_ops and return every visited op.

    Args:
      seed_ops: an iterable of operations from which the forward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the consumers of those tensors.
      inclusive: if True the given seed_ops are also part of the resulting
        set.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within
        the whole graph.
      control_outputs: an object convertible to a control output dictionary
        (see function util.convert_to_control_outputs for more details).
        If the dictionary can be created, it will be used while walking the
        graph forward.
    Returns:
      A Python set of all the tf.Operation ahead of seed_ops.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return set()
    if isinstance(seed_ops[0], tf_ops.Tensor):
        # Tensors were given: start from the ops consuming them.
        seed_ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = get_consuming_ops(seed_ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    control_outputs = util.convert_to_control_outputs(seed_ops,
                                                      control_outputs)

    seed_ops = frozenset(seed_ops)
    if within_ops:
        within_ops = frozenset(
            util.make_list_of_op(within_ops, allow_graph=False))
        seed_ops &= within_ops

    visited = set(seed_ops)
    frontier = set(seed_ops)
    while frontier:
        next_frontier = set()
        for current in frontier:
            # Successors through regular tensors ...
            successors = [
                consumer
                for out_t in current.outputs
                for consumer in out_t.consumers()
            ]
            # ... and through control dependencies, when they are known.
            if control_outputs is not None and current in control_outputs:
                successors.extend(control_outputs[current])
            for candidate in successors:
                if candidate in visited:
                    continue
                if within_ops is None or candidate in within_ops:
                    next_frontier.add(candidate)
        visited |= next_frontier
        frontier = next_frontier

    if inclusive:
        return visited
    return visited - seed_ops
Exemplo n.º 35
0
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None, stop_at_ts=(), control_outputs=None):
    """Do a forward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the forward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the consumers of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_outputs: a `util.ControlOutputs` instance or None.
        If not `None`, it will be used while walking the graph forward.
    Returns:
      A Python list (not a set) of the `tf.Operation` ahead of `seed_ops`;
      each visited op appears once.
    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    _, control_outputs = check_cios(False, control_outputs)
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    seed_ops = frozenset(seed_ops)
    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # Mirror of `result` as a set: membership tests stay O(1) instead of
    # scanning the list for every candidate op.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.outputs:
                # Do not walk past the tensors the caller asked to stop at.
                if new_t in stop_at_ts:
                    continue
                for new_op in new_t.consumers():
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
            if control_outputs is not None:
                for new_op in control_outputs.get(op):
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
Exemplo n.º 36
0
def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
  """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1.

  This function is the back-bone of the Graph-Editor. It is essentially a thin
  wrapper on top of the tf.Operation._update_input.

  Given a pair of tensors t0, t1 in ts0 x ts1, this function re-routes the end
  of t0 and t1 in one of three possible ways:
  1) The reroute mode is "a<->b" or "b<->a": the tensors' ends are swapped.
  After this operation, the previous consumers of t0 are now consumers of t1
  and vice-versa.
  2) The reroute mode is "a->b": the tensors' end of t0 are re-routed to the
  tensors' end of t1 (which are left dangling). After this operation, the
  previous consumers of t0 are still consuming t0 but the previous consumers of
  t1 are now also consuming t0. The tensor t1 has no consumer.
  3) The reroute mode is "b->a": this mode is the symmetric of the "a->b" mode.

  Note that this function is re-routing the end of two tensors, not the start.
  Re-routing the start of two tensors is not supported by this library. The
  reason for that is the following: TensorFlow, by design, creates a strong bond
  between an op and its output tensor. This Graph editor follows this design and
  treats an operation A and its generating tensors {t_i} as an entity which
  cannot be broken. In other words, an op cannot be detached from any of its
  output tensors, ever. But it is possible to detach an op from its input
  tensors, which is what this function concerns itself with.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    mode: what to do with those tensors: "a<->b" or "b<->a" for swapping and
      "a->b" or "b->a" for one direction re-routing.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function (the sum of
    the values returned by `_reroute_t`).
  Raises:
    TypeError: if `ts0` or `ts1` cannot be converted to a list of `tf.Tensor`.
    TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be
      converted to a list of `tf.Operation`.
  """
  a2b, b2a = _RerouteMode.check(mode)
  ts0 = _util.make_list_of_t(ts0)
  ts1 = _util.make_list_of_t(ts1)
  _check_ts_compatibility(ts0, ts1)
  if cannot_modify is not None:
    cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
  if can_modify is not None:
    can_modify = frozenset(_util.make_list_of_op(can_modify))

  pairs = list(zip(ts0, ts1))
  # Record who consumes what up front: the same tensor may show up in several
  # pairs, and each reroute below changes the consumers seen by later pairs.
  consumers_before = []
  for t0, t1 in pairs:
    consumers_before.append((set(t0.consumers()), set(t1.consumers())))

  modifications = 0
  for (t0, t1), (cons0, cons1) in zip(pairs, consumers_before):
    if t0 is not t1:  # Identical tensors are silently ignored.
      if a2b:
        modifications += _reroute_t(t0, t1, cons1, can_modify, cannot_modify)
      if b2a:
        modifications += _reroute_t(t1, t0, cons0, can_modify, cannot_modify)
  return modifications
Exemplo n.º 37
0
def select_ops(*args, **kwargs):
  """Helper to select operations.

  Args:
    *args: list of 1) regular expressions (compiled or not) or  2) (array of)
      tf.Operation. tf.Tensor instances are silently ignored. A regular
      expression whose pattern starts with "(?#ts)" is skipped.
    **kwargs: 'graph': tf.Graph in which to perform the regex query. This is
      required when using regex.
      'positive_filter': an elem is selected only if positive_filter(elem) is
        True. This is optional.
      'restrict_regex': if True, a regular expression is ignored if it doesn't
        start with the substring "(?#ops)".
  Returns:
    list of tf.Operation, without repeating an op already selected by a
    previous argument.
  Raises:
    TypeError: if the optional keyword argument graph is not a tf.Graph
      or if an argument in args is not an (array of) tf.Operation
      or an (array of) tf.Tensor (silently ignored) or a string
      or a regular expression.
    ValueError: if one of the keyword arguments is unexpected or if a regular
      expression is used without passing a graph as a keyword argument.
  """
  # get keywords arguments
  graph = None
  positive_filter = None
  restrict_regex = False
  # dict.items() exists on both Python 2 and 3; iteritems() is Python 2 only.
  for k, v in kwargs.items():
    if k == "graph":
      graph = v
      if graph is not None and not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    elif k == "positive_filter":
      positive_filter = v
    elif k == "restrict_regex":
      restrict_regex = v
    else:
      raise ValueError("Wrong keywords argument: {}.".format(k))

  ops = []

  for arg in args:
    if _can_be_regex(arg):
      if graph is None:
        raise ValueError("Use the keyword argument 'graph' to use regex.")
      regex = _make_regex(arg)
      # Patterns tagged "(?#ts)" target tensors, not operations.
      if regex.pattern.startswith("(?#ts)"):
        continue
      if restrict_regex and not regex.pattern.startswith("(?#ops)"):
        continue
      for op_ in filter_ops_from_regex(graph, regex):
        if op_ in ops:
          continue
        if positive_filter is None or positive_filter(op_):
          ops.append(op_)
    else:
      ops_aux = util.make_list_of_op(arg, ignore_ts=True)
      if positive_filter is not None:
        ops_aux = [op for op in ops_aux if positive_filter(op)]
      ops_aux = [op for op in ops_aux if op not in ops]
      ops += ops_aux

  return ops
Exemplo n.º 38
0
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    inclusive: if True the given seed_ops are also part of the result.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_outputs: a `util.ControlOutputs` instance or None.
      If not `None`, it will be used while walking the graph forward.
  Returns:
    A Python list (not a set) of the `tf.Operation` ahead of `seed_ops`; each
    visited op appears once.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  result = list(seed_ops)
  # Set mirror of `result`, so that membership tests during the walk are O(1)
  # rather than a scan of the whole list.
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        # Do not walk past the tensors the caller asked to stop at.
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    util.concatenate_unique(result, new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
Exemplo n.º 39
0
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          stop_at_ts=(),
                          control_inputs=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the result.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
  Returns:
    A Python list (not a set) of the tf.Operation behind seed_ops; each
    visited op appears once.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_generating_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  seed_ops = frozenset(util.make_list_of_op(seed_ops))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  result = list(seed_ops)
  # Set mirror of `result`, so that membership tests during the walk are O(1)
  # rather than a scan of the whole list.
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.inputs:
        # Do not walk past the tensors the caller asked to stop at.
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    util.concatenate_unique(result, new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result