Esempio n. 1
0
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          stop_at_ts=(),
                          control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the backward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the generators of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops: an
        input tensor listed here is not followed back to its producing op.
      control_inputs: if True, control inputs will be used while moving
        backward.
    Returns:
      A Python list of the tf.Operation behind seed_ops, seeds first and then
      the ops of each successive wave. An empty list is returned when
      seed_ops is empty.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    # Only the first element is inspected to decide between tensors and ops.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        # Seeds outside the allowed region are dropped before walking.
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # `visited` mirrors `result` as a set so the "already seen" test in the
    # inner loops is a hash lookup instead of a linear scan of the list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        # NOTE(review): util.concatenate_unique is expected to append the new
        # ops to `result` in place (its return value is not used) -- confirm.
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
Esempio n. 2
0
def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None, stop_at_ts=(), control_inputs=False):
    """Walk the graph backward from `seed_ops` and return the visited ops.

    Args:
      seed_ops: an iterable of operations from which the backward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the generators of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops; the
        producer of such an input tensor is not reached through it.
      control_inputs: if True, control inputs will be used while moving
        backward.
    Returns:
      A Python list (not a set) of the tf.Operation behind seed_ops. The
      seeds come first, followed by the ops found in each later wave. An
      empty list is returned for empty seed_ops.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    # The type of the first seed alone selects the tensor or the op path.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        # Keep only the seeds that lie inside the permitted region.
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # Set twin of `result`: avoids an O(len(result)) list scan for every
    # input edge examined below.
    seen = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in seen and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in seen and is_within(new_op):
                        new_wave.add(new_op)
        # NOTE(review): relies on util.concatenate_unique mutating `result`
        # in place, as the original code did -- confirm against util.
        util.concatenate_unique(result, new_wave)
        seen.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
Esempio n. 3
0
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         control_outputs=True):
  """Walk the graph forward from `seed_ops` and collect every op reached.

  Args:
    seed_ops: an iterable of operations where the forward walk begins. When
      the first element is a tensor, the seeds become the consumers of the
      given tensors instead.
    inclusive: if True, the seed ops themselves are kept in the result.
    within_ops: an iterable of tf.Operation limiting the search region. If
      None, the whole graph is searched.
    control_outputs: an object convertible to a control output dictionary
      (see util.convert_to_control_outputs). When a dictionary is obtained,
      its entries are followed during the walk as well.
  Returns:
    A Python set of all the tf.Operation ahead of seed_ops.
  Raises:
    TypeError: if seed_ops or within_ops cannot be converted to a list of
      tf.Operation.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return set()

  # The first element decides whether the seeds are tensors or operations.
  if isinstance(seed_ops[0], tf_ops.Tensor):
    seed_ops = get_consuming_ops(
        util.make_list_of_t(seed_ops, allow_graph=False))
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  control_outputs = util.convert_to_control_outputs(seed_ops, control_outputs)

  seed_ops = frozenset(seed_ops)
  if within_ops:
    within_ops = frozenset(
        util.make_list_of_op(within_ops, allow_graph=False))
    seed_ops &= within_ops

  visited = set(seed_ops)
  frontier = set(seed_ops)

  def admit(candidate, bucket):
    # Queue `candidate` only if it is new and inside the allowed region.
    if candidate in visited:
      return
    if within_ops is None or candidate in within_ops:
      bucket.add(candidate)

  while frontier:
    discovered = set()
    for current in frontier:
      # Data edges: every consumer of every output tensor.
      for out_t in current.outputs:
        for consumer in out_t.consumers():
          admit(consumer, discovered)
      # Control edges, when a control output mapping is available.
      if control_outputs is None or current not in control_outputs:
        continue
      for dependent in control_outputs[current]:
        admit(dependent, discovered)
    visited.update(discovered)
    frontier = discovered

  if inclusive:
    return visited
  visited.difference_update(seed_ops)
  return visited
Esempio n. 4
0
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    inclusive: if True the given seed_ops are also part of the result.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops: the
      consumers of an output tensor listed here are not followed through it.
    control_outputs: a `util.ControlOutputs` instance or None.
      If not `None`, it will be used while walking the graph forward.
  Returns:
    A Python list of the `tf.Operation` ahead of `seed_ops`, seeds first and
    then the ops of each successive wave. An empty list is returned when
    `seed_ops` is empty.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return []
  # Only the first element is inspected to decide between tensors and ops.
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    # Seeds outside the allowed region are dropped before walking.
    seed_ops &= within_ops

  def is_within(op):
    return within_ops is None or op in within_ops

  result = list(seed_ops)
  # `visited` mirrors `result` as a set so the "already seen" test in the
  # inner loops is a hash lookup instead of a linear scan of the list.
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # NOTE(review): util.concatenate_unique is expected to append the new ops
    # to `result` in place (its return value is not used) -- confirm.
    util.concatenate_unique(result, new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
Esempio n. 5
0
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None, stop_at_ts=(), control_outputs=None):
    """Walk the graph forward from `seed_ops` and return the visited ops.

    Args:
      seed_ops: an iterable of operations from which the forward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the consumers of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops; the
        consumers of such an output tensor are not reached through it.
      control_outputs: a `util.ControlOutputs` instance or None.
        If not `None`, it will be used while walking the graph forward.
    Returns:
      A Python list (not a set) of the `tf.Operation` ahead of `seed_ops`.
      The seeds come first, followed by the ops found in each later wave. An
      empty list is returned for empty `seed_ops`.
    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    _, control_outputs = check_cios(False, control_outputs)
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    # The type of the first seed alone selects the tensor or the op path.
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    seed_ops = frozenset(seed_ops)
    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        # Keep only the seeds that lie inside the permitted region.
        seed_ops &= within_ops

    def is_within(op):
        return within_ops is None or op in within_ops

    result = list(seed_ops)
    # Set twin of `result`: avoids an O(len(result)) list scan for every
    # consumer edge examined below.
    seen = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.outputs:
                if new_t in stop_at_ts:
                    continue
                for new_op in new_t.consumers():
                    if new_op not in seen and is_within(new_op):
                        new_wave.add(new_op)
            if control_outputs is not None:
                for new_op in control_outputs.get(op):
                    if new_op not in seen and is_within(new_op):
                        new_wave.add(new_op)
        # NOTE(review): relies on util.concatenate_unique mutating `result`
        # in place, as the original code did -- confirm against util.
        util.concatenate_unique(result, new_wave)
        seen.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result