Esempio n. 1
0
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         control_outputs=True):
  """Do a forward graph walk and return all the visited ops.

  Starting from the seed ops, the walk repeatedly follows data edges (from an
  op to the consumers of its output tensors) and, when a control output
  dictionary is available, control output edges, until no new op is reached.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors. A single (non-iterable) op is
      wrapped into a list.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within
      the whole graph. An empty iterable restricts the search to nothing,
      so the result is empty.
    control_outputs: an object convertible to a control output dictionary
      (see function util.convert_to_control_outputs for more details).
      If the dictionary can be created, it will be used while walking the graph
      forward.
  Returns:
    A Python set of all the tf.Operation ahead of seed_ops.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize once: a one-shot iterable (e.g. a generator) supports neither
  # a meaningful emptiness test nor indexing with seed_ops[0].
  seed_ops = list(seed_ops)
  if not seed_ops:
    return set()
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
  if not seed_ops:
    # e.g. the given tensors have no consumers: there is nothing to walk from.
    return set()

  control_outputs = util.convert_to_control_outputs(seed_ops, control_outputs)

  seed_ops = frozenset(seed_ops)
  # Test against None rather than truthiness: an explicitly empty within_ops
  # must restrict the walk to nothing, not silently lift the restriction.
  if within_ops is not None:
    within_ops = frozenset(util.make_list_of_op(within_ops, allow_graph=False))
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  result = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      # Downstream neighbours through data edges...
      successors = [consumer
                    for t in op.outputs
                    for consumer in t.consumers()]
      # ...and through control edges, when a dictionary is available.
      if control_outputs is not None and op in control_outputs:
        successors.extend(control_outputs[op])
      new_wave.update(new_op for new_op in successors
                      if new_op not in result and is_within(new_op))
    result.update(new_wave)
    wave = new_wave
  if not inclusive:
    result.difference_update(seed_ops)
  return result
Esempio n. 2
0
def get_within_boundary_ops(ops,
                            seed_ops,
                            boundary_ops,
                            inclusive=True,
                            control_outputs=True):
    """Return all the tf.Operation within the given boundary.

    Starting from seed_ops, the walk repeatedly expands to every op connected
    to the current wave (as reported by get_ops_ios). A boundary op is never
    expanded further.

    Args:
      ops: an object convertible to a list of tf.Operation. Those ops define
        the set in which to perform the operation (if a tf.Graph is given, it
        will be converted to the list of all its operations). Only ops that
        belong to this set are returned.
      seed_ops: the operations from which to start expanding.
      boundary_ops: the ops forming the boundary.
      inclusive: if True, the result will also include the boundary ops.
      control_outputs: an object convertible to a control output dictionary
        (or None). If the dictionary can be created, it will be used while
        expanding.
    Returns:
      A set of the tf.Operation of `ops` reachable from seed_ops without
      crossing boundary_ops.
    Raises:
      TypeError: if ops or seed_ops cannot be converted to a list of
        tf.Operation.
      ValueError: if the boundary is intersecting with the seeds.
    """
    ops = util.make_list_of_op(ops)
    control_outputs = util.convert_to_control_outputs(ops, control_outputs)
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
    boundary_ops = set(util.make_list_of_op(boundary_ops))
    res = set(seed_ops)
    if boundary_ops & res:
        raise ValueError("Boundary is intersecting with the seeds.")
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in get_ops_ios(wave, control_outputs):
            if op in res:
                continue
            if op in boundary_ops:
                # Boundary ops stop the expansion; they are only recorded
                # when the caller asked for an inclusive result.
                if inclusive:
                    res.add(op)
            else:
                new_wave.add(op)
        res.update(new_wave)
        wave = new_wave
    # `ops` defines the set the operation is performed in: previously it was
    # only used to build control outputs and never restricted the result.
    return res.intersection(ops)
Esempio n. 3
0
def get_within_boundary_ops(ops,
                            seed_ops,
                            boundary_ops,
                            inclusive=True,
                            control_outputs=True):
  """Return all the tf.Operation within the given boundary.

  Starting from seed_ops, the walk repeatedly expands to every op connected
  to the current wave (as reported by get_ops_ios). A boundary op is never
  expanded further.

  Args:
    ops: an object convertible to a list of tf.Operation. Those ops define the
      set in which to perform the operation (if a tf.Graph is given, it
      will be converted to the list of all its operations). Only ops that
      belong to this set are returned.
    seed_ops: the operations from which to start expanding.
    boundary_ops: the ops forming the boundary.
    inclusive: if True, the result will also include the boundary ops.
    control_outputs: an object convertible to a control output dictionary
      (or None). If the dictionary can be created, it will be used while
      expanding.
  Returns:
    A set of the tf.Operation of `ops` reachable from seed_ops without
    crossing boundary_ops.
  Raises:
    TypeError: if ops or seed_ops cannot be converted to a list of tf.Operation.
    ValueError: if the boundary is intersecting with the seeds.
  """
  ops = util.make_list_of_op(ops)
  control_outputs = util.convert_to_control_outputs(ops, control_outputs)
  seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
  boundary_ops = set(util.make_list_of_op(boundary_ops))
  res = set(seed_ops)
  if boundary_ops & res:
    raise ValueError("Boundary is intersecting with the seeds.")
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in get_ops_ios(wave, control_outputs):
      if op in res:
        continue
      if op in boundary_ops:
        # Boundary ops stop the expansion; they are only recorded when the
        # caller asked for an inclusive result.
        if inclusive:
          res.add(op)
      else:
        new_wave.add(op)
    res.update(new_wave)
    wave = new_wave
  # `ops` defines the set the operation is performed in: previously it was
  # only used to build control outputs and never restricted the result.
  return res.intersection(ops)
Esempio n. 4
0
def get_ops_ios(ops, control_outputs=True):
    """Collect every tf.Operation adjacent to an op in `ops`.

    An op counts as adjacent when it produces one of the inputs of an op in
    `ops`, consumes one of its outputs, or is listed among its control
    outputs.

    Args:
      ops: an object convertible to a list of tf.Operation.
      control_outputs: an object convertible to a control output dictionary
        (or None). If the dictionary can be created, it will be used to
        determine the surrounding ops (in addition to the regular inputs and
        outputs).
    Returns:
      A set with all the tf.Operation surrounding the given ops.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    # Built from the raw `ops` argument, before it is normalized below.
    cios = util.convert_to_control_outputs(ops, control_outputs)
    neighbours = set()
    for current in util.make_list_of_op(ops):
        # Upstream: the producers of this op's input tensors.
        producers = [tensor.op for tensor in current.inputs]
        neighbours.update(producers)
        # Downstream: the consumers of this op's output tensors.
        for tensor in current.outputs:
            neighbours.update(tensor.consumers())
        # Downstream through control edges, when a dictionary exists.
        if cios is not None and current in cios:
            neighbours.update(cios[current])
    return neighbours
Esempio n. 5
0
def get_ops_ios(ops, control_outputs=True):
  """Gather the tf.Operation neighbourhood of the given ops.

  The neighbourhood is made of the producers of the ops' inputs, the
  consumers of the ops' outputs and, if available, the ops' control outputs.

  Args:
    ops: an object convertible to a list of tf.Operation.
    control_outputs: an object convertible to a control output dictionary
      (or None). If the dictionary can be created, it will be used to determine
      the surrounding ops (in addition to the regular inputs and outputs).
  Returns:
    A set with all the tf.Operation surrounding the given ops.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  control_outputs = util.convert_to_control_outputs(ops, control_outputs)
  has_control_outputs = control_outputs is not None
  surrounding = set()
  for op in util.make_list_of_op(ops):
    surrounding.update([t.op for t in op.inputs])
    surrounding.update([consumer
                        for t in op.outputs
                        for consumer in t.consumers()])
    if has_control_outputs and op in control_outputs:
      surrounding.update(control_outputs[op])
  return surrounding