Example #1
0
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          within_ops_fn=None):
    """Do a backward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the backward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the generators of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_inputs: if True, control inputs will be used while moving
        backward.
      within_ops_fn: optional callable taking an op and returning True iff
        the op may be visited by the walk. It can be combined with
        within_ops, in which case a discovered op must satisfy both. It is
        applied to ops discovered during the walk, not to the seeds.
    Returns:
      A list, without duplicates, of the tf.Operation reached by the walk.
      The seeds come first (unless inclusive is False), followed by the ops
      discovered wave after wave.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    # Only a non-empty within_ops is normalized and used to filter the seeds.
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        """Return True iff `op` passes both within_ops and within_ops_fn."""
        return ((within_ops is None or op in within_ops) and
                (within_ops_fn is None or within_ops_fn(op)))

    result = list(seed_ops)
    # Mirror of `result` as a set so membership tests in the inner loops are
    # constant time instead of a linear scan of the growing list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
Example #2
0
def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                          stop_at_ts=(), control_inputs=False,
                          within_ops_fn=None):
    """Do a backward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the backward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the generators of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of tf.Operation within which the search is
        restricted. If within_ops is None, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_inputs: if True, control inputs will be used while moving
        backward.
      within_ops_fn: optional callable taking an op and returning True iff
        the op may be visited by the walk. It can be combined with
        within_ops, in which case a discovered op must satisfy both. It is
        applied to ops discovered during the walk, not to the seeds.
    Returns:
      A list, without duplicates, of the tf.Operation reached by the walk.
      The seeds come first (unless inclusive is False), followed by the ops
      discovered wave after wave.
    Raises:
      TypeError: if seed_ops or within_ops cannot be converted to a list of
        tf.Operation.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    # Only a non-empty within_ops is normalized and used to filter the seeds.
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        """Return True iff `op` passes both within_ops and within_ops_fn."""
        return ((within_ops is None or op in within_ops) and
                (within_ops_fn is None or within_ops_fn(op)))

    result = list(seed_ops)
    # Mirror of `result` as a set so membership tests in the inner loops are
    # constant time instead of a linear scan of the growing list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                if new_t.op not in visited and is_within(new_t.op):
                    new_wave.add(new_t.op)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result
Example #3
0
def get_ops_ios(ops, control_inputs=False, control_outputs=None,
                control_ios=None):
  """Collect the ops directly connected to any op in `ops`.

  For each op, the producers of its inputs and the consumers of its outputs
  are gathered; control dependencies are added when enabled.

  Args:
    ops: an object convertible to a list of tf.Operation.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios:  An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to set
      control_inputs to True and control_outputs to the util.ControlOutputs
      instance.
  Returns:
    A list, without duplicates, of the tf.Operation surrounding the given ops.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  control_inputs, control_outputs = check_cios(
      control_inputs, control_outputs, control_ios)
  use_control_outputs = control_outputs is not None
  neighbors = []
  for current in util.make_list_of_op(ops):
    # Ops generating the inputs of the current op.
    producers = [tensor.op for tensor in current.inputs]
    util.concatenate_unique(neighbors, producers)
    # Ops consuming the outputs of the current op.
    for tensor in current.outputs:
      util.concatenate_unique(neighbors, tensor.consumers())
    if use_control_outputs:
      util.concatenate_unique(neighbors, control_outputs.get(current))
    if control_inputs:
      util.concatenate_unique(neighbors, current.control_inputs)
  return neighbors
Example #4
0
    def consumers(self):
        """Return the external consumers of this subgraph view.

        A consumer of a subgraph view is a tf.Operation which consumes one of
        the view's output tensors and is not itself one of the view's ops.

        Returns:
          A list, without duplicates, of the `tf.Operation` consuming the
          outputs of this subgraph view.
        """
        internal = frozenset(self._ops)
        found = []
        for out_t in self._output_ts:
            outside = []
            for candidate in out_t.consumers():
                # Ops belonging to the subgraph are not reported.
                if candidate not in internal:
                    outside.append(candidate)
            util.concatenate_unique(found, outside)
        return found
Example #5
0
  def consumers(self):
    """Return the external consumers of this subgraph view.

    A consumer of a subgraph view is a tf.Operation which consumes one of
    the view's output tensors and is not itself one of the view's ops.

    Returns:
      A list, without duplicates, of the `tf.Operation` consuming the
      outputs of this subgraph view.
    """
    internal = frozenset(self._ops)
    found = []
    for out_t in self._output_ts:
      outside = []
      for candidate in out_t.consumers():
        # Ops belonging to the subgraph are not reported.
        if candidate not in internal:
          outside.append(candidate)
      util.concatenate_unique(found, outside)
    return found
Example #6
0
def filter_ts(ops, positive_filter):
  """Return the input and output tensors of `ops` accepted by a filter.

  The candidate tensors are those returned by `_get_input_ts` followed by
  those returned by `_get_output_ts`, without duplicates.

  Args:
    ops: an object convertible to a list of tf.Operation.
    positive_filter: a function called on each tensor; the tensor is kept
      when the call returns a truthy value. If the literal True is passed
      instead, no filtering is done and all the tensors are returned.
  Returns:
    A list of tf.Tensor.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  op_list = util.make_list_of_op(ops)
  tensors = _get_input_ts(op_list)
  util.concatenate_unique(tensors, _get_output_ts(op_list))
  # `True` is a sentinel meaning "keep everything".
  if positive_filter is True:
    return tensors
  return [tensor for tensor in tensors if positive_filter(tensor)]
def get_walks_union_ops(forward_seed_ops,
                        backward_seed_ops,
                        forward_inclusive=True,
                        backward_inclusive=True,
                        within_ops=None,
                        within_ops_fn=None,
                        control_inputs=False,
                        control_outputs=None,
                        control_ios=None):
  """Combine a forward graph walk and a backward graph walk.

  The forward walk is run first, then the backward walk; both results are
  merged with `util.concatenate_unique`.

  Args:
    forward_seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    backward_seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    forward_inclusive: if True the given forward_seed_ops are also part of the
      result.
    backward_inclusive: if True the given backward_seed_ops are also part of
      the result.
    within_ops: restrict the search within those operations. If within_ops is
      None, the search is done within the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios:  An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to set
      control_inputs to True and control_outputs to the util.ControlOutputs
      instance.
  Returns:
    The value returned by `util.concatenate_unique` for the forward and
      backward walk results, i.e. the tf.Operation found by either walk.
  Raises:
    TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot be
      converted to a list of tf.Operation.
  """
  control_inputs, control_outputs = check_cios(
      control_inputs, control_outputs, control_ios)
  # Restrictions shared by both walks.
  restriction = dict(within_ops=within_ops, within_ops_fn=within_ops_fn)
  ahead = get_forward_walk_ops(
      forward_seed_ops,
      inclusive=forward_inclusive,
      control_outputs=control_outputs,
      **restriction)
  behind = get_backward_walk_ops(
      backward_seed_ops,
      inclusive=backward_inclusive,
      control_inputs=control_inputs,
      **restriction)
  return util.concatenate_unique(ahead, behind)
Example #8
0
def get_walks_union_ops(forward_seed_ops,
                        backward_seed_ops,
                        forward_inclusive=True,
                        backward_inclusive=True,
                        within_ops=None,
                        within_ops_fn=None,
                        control_inputs=False,
                        control_outputs=None,
                        control_ios=None):
  """Combine a forward graph walk and a backward graph walk.

  The forward walk is run first, then the backward walk; both results are
  merged with `util.concatenate_unique`.

  Args:
    forward_seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    backward_seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    forward_inclusive: if True the given forward_seed_ops are also part of the
      result.
    backward_inclusive: if True the given backward_seed_ops are also part of
      the result.
    within_ops: restrict the search within those operations. If within_ops is
      None, the search is done within the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios:  An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to set
      control_inputs to True and control_outputs to the util.ControlOutputs
      instance.
  Returns:
    The value returned by `util.concatenate_unique` for the forward and
      backward walk results, i.e. the tf.Operation found by either walk.
  Raises:
    TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot be
      converted to a list of tf.Operation.
  """
  control_inputs, control_outputs = check_cios(
      control_inputs, control_outputs, control_ios)
  # Restrictions shared by both walks.
  restriction = dict(within_ops=within_ops, within_ops_fn=within_ops_fn)
  ahead = get_forward_walk_ops(
      forward_seed_ops,
      inclusive=forward_inclusive,
      control_outputs=control_outputs,
      **restriction)
  behind = get_backward_walk_ops(
      backward_seed_ops,
      inclusive=backward_inclusive,
      control_inputs=control_inputs,
      **restriction)
  return util.concatenate_unique(ahead, behind)
Example #9
0
 def consumers(self):
   """Return a duplicate-free list of the consumers of the view's outputs.

   Every op returned by `consumers()` on one of `self._output_ts` is
   included; ops are not filtered by subgraph membership.
   """
   collected = []
   for out_t in self._output_ts:
     tensor_consumers = out_t.consumers()
     util.concatenate_unique(collected, tensor_consumers)
   return collected
Example #10
0
 def _remap_outputs_make_unique(self):
   """Deduplicate `self._output_ts` so that each tensor is listed only once.

   The output list is replaced by a fresh list which is then refilled from a
   copy of the previous outputs through `util.concatenate_unique`.
   """
   previous_outputs, self._output_ts = list(self._output_ts), []
   util.concatenate_unique(self._output_ts, previous_outputs)
Example #11
0
 def _remap_outputs_make_unique(self):
     """Deduplicate `self._output_ts` so that each tensor is listed only once.

     The output list is replaced by a fresh list which is then refilled from a
     copy of the previous outputs through `util.concatenate_unique`.
     """
     previous_outputs, self._output_ts = list(self._output_ts), []
     util.concatenate_unique(self._output_ts, previous_outputs)
Example #12
0
 def consumers(self):
     """Return a duplicate-free list of the consumers of the view's outputs.

     Every op returned by `consumers()` on one of `self._output_ts` is
     included; ops are not filtered by subgraph membership.
     """
     collected = []
     for out_t in self._output_ts:
         tensor_consumers = out_t.consumers()
         util.concatenate_unique(collected, tensor_consumers)
     return collected
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         stop_at_ts=(),
                         control_outputs=None,
                         within_ops_fn=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    inclusive: if True the given seed_ops are also part of the result.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_outputs: a `util.ControlOutputs` instance or None.
      If not `None`, it will be used while walking the graph forward.
    within_ops_fn: optional callable taking an op and returning True iff the
      op may be visited by the walk. It can be combined with `within_ops`, in
      which case a discovered op must satisfy both. It is applied to ops
      discovered during the walk, not to the seeds.
  Returns:
    A list, without duplicates, of the `tf.Operation` reached by the walk.
    The seeds come first (unless `inclusive` is False), followed by the ops
    discovered wave after wave.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  # Only a non-empty within_ops is normalized and used to filter the seeds.
  if within_ops:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
    within_ops = frozenset(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    """Return True iff `op` passes both within_ops and within_ops_fn."""
    return ((within_ops is None or op in within_ops) and
            (within_ops_fn is None or within_ops_fn(op)))

  result = list(seed_ops)
  # Mirror of `result` as a set so membership tests in the inner loops are
  # constant time instead of a linear scan of the growing list.
  visited = set(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    util.concatenate_unique(result, new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
Example #14
0
def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None,
                         stop_at_ts=(), control_outputs=None,
                         within_ops_fn=None):
    """Do a forward graph walk and return all the visited ops.

    Args:
      seed_ops: an iterable of operations from which the forward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the consumers of those tensors.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_outputs: a `util.ControlOutputs` instance or None.
        If not `None`, it will be used while walking the graph forward.
      within_ops_fn: optional callable taking an op and returning True iff
        the op may be visited by the walk. It can be combined with
        `within_ops`, in which case a discovered op must satisfy both. It is
        applied to ops discovered during the walk, not to the seeds.
    Returns:
      A list, without duplicates, of the `tf.Operation` reached by the walk.
      The seeds come first (unless `inclusive` is False), followed by the ops
      discovered wave after wave.
    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    _, control_outputs = check_cios(False, control_outputs)
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_consuming_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    seed_ops = frozenset(seed_ops)
    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    # Only a non-empty within_ops is normalized and used to filter the seeds.
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        """Return True iff `op` passes both within_ops and within_ops_fn."""
        return ((within_ops is None or op in within_ops) and
                (within_ops_fn is None or within_ops_fn(op)))

    result = list(seed_ops)
    # Mirror of `result` as a set so membership tests in the inner loops are
    # constant time instead of a linear scan of the growing list.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.outputs:
                if new_t in stop_at_ts:
                    continue
                for new_op in new_t.consumers():
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
            if control_outputs is not None:
                for new_op in control_outputs.get(op):
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        util.concatenate_unique(result, new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result