Exemple #1
0
def get_within_boundary_ops(ops,
                            seed_ops,
                            boundary_ops=(),
                            inclusive=True,
                            control_inputs=False,
                            control_outputs=None,
                            control_ios=None):
    """Collect the ops reachable from the seeds without crossing the boundary.

    The graph is explored wave by wave starting at `seed_ops`: the neighbours
    of the current wave are obtained with `get_ops_ios`, and every neighbour
    that is not a boundary op becomes part of the next wave. Boundary ops are
    never expanded.

    Args:
      ops: an object convertible to a list of `tf.Operation`. These ops define
        the set in which the search is performed (if a `tf.Graph` is given, it
        is converted to the list of all its operations). The result is a
        filtered copy of this list, so its ordering is preserved.
      seed_ops: the operations from which the expansion starts.
      boundary_ops: the ops forming the boundary.
      inclusive: if `True`, boundary ops reached by the walk are part of the
        result as well.
      control_inputs: a boolean indicating whether control inputs are followed.
      control_outputs: an instance of `util.ControlOutputs` or `None`. If not
        `None`, control outputs are followed.
      control_ios: an instance of `util.ControlOutputs` or `None`. If not
        `None`, both control inputs and control outputs are followed. This is
        equivalent to setting `control_inputs` to `True` and `control_outputs`
        to the `util.ControlOutputs` instance.
    Returns:
      The list of the elements of `ops` which were reached by the walk.
    Raises:
      TypeError: if `ops` or `seed_ops` cannot be converted to a list of
        `tf.Operation`.
      ValueError: if the boundary is intersecting with the seeds.
    """
    control_inputs, control_outputs = check_cios(control_inputs,
                                                 control_outputs, control_ios)
    ops = util.make_list_of_op(ops)
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
    boundary = set(util.make_list_of_op(boundary_ops))
    reached = set(seed_ops)
    if not boundary.isdisjoint(reached):
        raise ValueError("Boundary is intersecting with the seeds.")

    frontier = set(seed_ops)
    while frontier:
        neighbours = get_ops_ios(frontier, control_inputs, control_outputs)
        next_frontier = set()
        for candidate in neighbours:
            if candidate in reached:
                continue
            if candidate not in boundary:
                # Regular op: it will be expanded during the next wave.
                next_frontier.add(candidate)
            elif inclusive:
                # Boundary op: recorded when requested, but never expanded.
                reached.add(candidate)
        reached |= next_frontier
        frontier = next_frontier

    # Filter `ops` rather than returning the set so the caller's order is kept.
    return [op for op in ops if op in reached]
Exemple #2
0
def get_ops_ios(ops,
                control_inputs=False,
                control_outputs=None,
                control_ios=None):
    """Return the operations directly connected to at least one op in `ops`.

    For every op, the producers of its input tensors and the consumers of its
    output tensors are gathered. Control dependencies are followed only when
    enabled through the arguments below.

    Args:
      ops: an object convertible to a list of `tf.Operation`.
      control_inputs: a boolean indicating whether control inputs are followed.
      control_outputs: an instance of `util.ControlOutputs` or `None`. If not
        `None`, control outputs are followed.
      control_ios: an instance of `util.ControlOutputs` or `None`. If not
        `None`, both control inputs and control outputs are followed. This is
        equivalent to setting `control_inputs` to `True` and `control_outputs`
        to the `util.ControlOutputs` instance.
    Returns:
      A list of the `tf.Operation` surrounding the given ops, accumulated with
      `util.concatenate_unique`.
    Raises:
      TypeError: if `ops` cannot be converted to a list of `tf.Operation`.
    """
    control_inputs, control_outputs = check_cios(control_inputs,
                                                 control_outputs, control_ios)
    follow_control_outputs = control_outputs is not None
    neighbours = []
    for op in util.make_list_of_op(ops):
        # Data dependencies: producers of the inputs, consumers of the outputs.
        producers = [tensor.op for tensor in op.inputs]
        util.concatenate_unique(neighbours, producers)
        for tensor in op.outputs:
            util.concatenate_unique(neighbours, tensor.consumers())
        # Control dependencies, only when enabled.
        if follow_control_outputs:
            util.concatenate_unique(neighbours, control_outputs.get(op))
        if control_inputs:
            util.concatenate_unique(neighbours, op.control_inputs)
    return neighbours
Exemple #3
0
def compute_boundary_ts(ops):
    """Classify the tensors touching `ops` into inputs, outputs and insides.

    Every tensor connected to the given ops (as an input or as an output) is
    put into one or more of these categories:
    1) input tensors: tensors whose generating operation is not in `ops`.
    2) output tensors: tensors with at least one consumer outside of `ops`.
    3) inside tensors: tensors both generated and consumed by ops in `ops`.

    A tensor consumed both inside and outside of `ops` is reported as an inside
    tensor and as an output tensor.

    Args:
      ops: an object convertible to a list of `pge.Node`.
    Returns:
      A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where:
        `outside_input_ts` is a Python list of input tensors;
        `outside_output_ts` is a Python list of output tensors;
        `inside_ts` is a Python list of inside tensors.
      Since a tensor can be both an inside tensor and an output tensor,
      `outside_output_ts` and `inside_ts` might intersect.
    Raises:
      TypeError: if `ops` cannot be converted to a list of operations.
    """
    ops = util.make_list_of_op(ops)
    input_ts = _get_input_ts(ops)
    output_ts = _get_output_ts(ops)
    produced_here = frozenset(output_ts)
    members = frozenset(ops)

    # Inside tensors are consumed by `ops` and also produced by `ops`.
    inside_ts = [t for t in input_ts if t in produced_here]

    # Among those, keep the ones whose consumers all belong to `ops`; they must
    # not be reported as outputs.
    # NOTE(review): `t.consumers` is read as an attribute (not called), so it is
    # assumed to be an iterable of ops here - confirm against the tensor class.
    strictly_inside = frozenset(
        t for t in inside_ts if frozenset(t.consumers) <= members)
    inside_lookup = frozenset(inside_ts)

    outside_input_ts = [t for t in input_ts if t not in inside_lookup]
    outside_output_ts = [t for t in output_ts if t not in strictly_inside]
    return outside_input_ts, outside_output_ts, inside_ts
Exemple #4
0
def _get_output_ts(ops):
    """Gather the output tensors of every op in `ops`.

    The outputs are concatenated op after op, in the order in which the ops
    are listed; no extra de-duplication is performed here.

    Args:
      ops: an object convertible to a list of tf.Operation.
    Returns:
      The list of the output tensors of all the ops in `ops`.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    return [
        tensor for op in util.make_list_of_op(ops) for tensor in op.outputs
    ]
Exemple #5
0
def filter_ops_from_regex(ops, regex):
    """Select the operations whose name is matched by a regular expression.

    The expression is applied with `search`, so it must be anchored explicitly
    when a full-name match is wanted.

    Args:
      ops: an object convertible to a list of `tf.Operation`.
      regex: a regular expression matching the operation's name.
        For example, `"^foo(/.*)?$"` will match all the operations in the "foo"
        scope.
    Returns:
      A list of `tf.Operation`.
    Raises:
      TypeError: if ops cannot be converted to a list of `tf.Operation`.
    """
    ops = util.make_list_of_op(ops)
    pattern = make_regex(regex)

    def name_matches(op):
        # A match object is truthy, `None` (no match) is falsy.
        return pattern.search(op.name)

    return filter_ops(ops, name_matches)
Exemple #6
0
def filter_ops(ops, positive_filter):
    """Keep only the ops accepted by `positive_filter`.

    Args:
      ops: an object convertible to a list of tf.Operation.
      positive_filter: a function deciding whether to keep an operation or not.
        If it is the literal `True`, all the operations are returned.
    Returns:
      A list of selected tf.Operation.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    candidates = util.make_list_of_op(ops)
    # Identity test on purpose: only the constant `True` disables filtering,
    # any other value is called as a predicate.
    if positive_filter is True:
        return candidates
    return [op for op in candidates if positive_filter(op)]
Exemple #7
0
def filter_ts_from_regex(ops, regex):
    r"""Select the tensors connected to `ops` whose name matches a regex.

    The expression is applied with `search` to the name of every input and
    output tensor of the given ops.

    Args:
      ops: an object convertible to a list of tf.Operation.
      regex: a regular expression matching the tensors' name.
        For example, "^foo(/.*)?:\d+$" will match all the tensors in the "foo"
        scope.
    Returns:
      A list of tf.Tensor.
    Raises:
      TypeError: if ops cannot be converted to a list of tf.Operation.
    """
    ops = util.make_list_of_op(ops)
    pattern = make_regex(regex)

    def name_matches(tensor):
        # A match object is truthy, `None` (no match) is falsy.
        return pattern.search(tensor.name)

    return filter_ts(ops, positive_filter=name_matches)
Exemple #8
0
def filter_ts(ops, positive_filter):
    """Return the input and output tensors of `ops` accepted by a filter.

    The candidates are the input tensors of the ops followed by their output
    tensors, merged with `util.concatenate_unique`.

    Args:
      ops: an object convertible to a list of `tf.Operation`.
      positive_filter: a function deciding whether to keep a tensor or not.
        If it is the literal `True`, all the tensors are returned.
    Returns:
      A list of `tf.Tensor`.
    Raises:
      TypeError: if ops cannot be converted to a list of `tf.Operation`.
    """
    ops = util.make_list_of_op(ops)
    tensors = _get_input_ts(ops)
    util.concatenate_unique(tensors, _get_output_ts(ops))
    # Identity test on purpose: only the constant `True` disables filtering.
    if positive_filter is True:
        return tensors
    return [tensor for tensor in tensors if positive_filter(tensor)]
Exemple #9
0
def _get_input_ts(ops):
    """Gather the distinct input tensors of every op in `ops`.

    Tensors are reported in the order in which they are first encountered; a
    tensor feeding several ops appears only once.

    Args:
      ops: an object convertible to a list of `tf.Operation`.
    Returns:
      The list of unique input tensors of all the ops in `ops`.
    Raises:
      TypeError: if ops cannot be converted to a list of `tf.Operation`.
    """
    seen = set()
    unique_ts = []
    for op in util.make_list_of_op(ops):
        for tensor in op.inputs:
            if tensor in seen:
                continue
            seen.add(tensor)
            unique_ts.append(tensor)
    return unique_ts
Exemple #10
0
  def __init__(self, inside_ops=(), passthrough_ts=()):
    """Build a subgraph from a set of ops plus optional "passthrough" tensors.

    Args:
      inside_ops: an object convertible to a list of `tf.Operation`. This list
        defines all the operations in the subgraph.
      passthrough_ts: an object convertible to a list of `tf.Tensor`. This list
        defines the "passthrough" tensors. A passthrough tensor is a tensor
        which goes directly from the input of the subgraph to its output,
        without any intermediate operation. Tensors which are already an
        input, an output or an inside tensor of `inside_ops` are silently
        ignored.
    Raises:
      TypeError: if inside_ops cannot be converted to a list of `tf.Operation`
        or if `passthrough_ts` cannot be converted to a list of `tf.Tensor`.
    """
    inside_ops = util.make_list_of_op(inside_ops)
    passthrough_ts = util.make_list_of_t(passthrough_ts)
    members = inside_ops + passthrough_ts

    if not members:
      # Nothing was given: the view is empty and bound to no graph.
      self._graph = None
      self._ops = []
      self._passthrough_ts = []
      self._input_ts = []
      self._output_ts = []
      return

    self._graph = util.get_unique_graph(members)
    self._ops = inside_ops

    # Tensors entering, leaving and staying inside the selected ops.
    inputs, outputs, insides = select.compute_boundary_ts(inside_ops)

    # A requested passthrough tensor is kept only when it is not already
    # connected to the inside ops.
    connected_ts = frozenset(inputs + outputs + list(insides))
    genuine_passthrough = [t for t in passthrough_ts if t not in connected_ts]
    self._passthrough_ts = genuine_passthrough

    # Passthrough tensors are both inputs and outputs of the subgraph.
    self._input_ts = inputs + genuine_passthrough
    self._output_ts = outputs + genuine_passthrough
Exemple #11
0
def select_ops(*args, **kwargs):
    """Helper to select operations.

    Positional arguments are processed in order and the selected operations
    are accumulated without duplicates.

    Args:
      *args: list of 1) regular expressions (compiled or not) or 2) (array of)
        `tf.Operation`. `tf.Tensor` instances are silently ignored. A regular
        expression whose pattern starts with "(?#ts)" is always skipped.
      **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This
        is required when using regex.
        'positive_filter': an elem is selected only if `positive_filter(elem)`
          is `True`. This is optional.
        'restrict_ops_regex': when set, a regular expression is ignored if its
          pattern doesn't start with the substring "(?#ops)".
        'restrict_ts_regex': accepted but has no effect in this function.
    Returns:
      A list of `tf.Operation`.
    Raises:
      TypeError: if the optional keyword argument graph is not a `tf.Graph`
        or if an argument in args is not an (array of) `tf.Operation`
        or an (array of) `tf.Tensor` (silently ignored) or a string
        or a regular expression.
      ValueError: if one of the keyword arguments is unexpected or if a regular
        expression is used without passing a graph as a keyword argument.
    """
    # Decode the keyword arguments.
    graph = None
    positive_filter = None
    restrict_ops_regex = False
    for key, value in iteritems(kwargs):
        if key == "graph":
            if value is not None and not isinstance(value, tf_ops.Graph):
                raise TypeError("Expected a tf.Graph, got: {}".format(
                    type(value)))
            graph = value
        elif key == "positive_filter":
            positive_filter = value
        elif key == "restrict_ops_regex":
            restrict_ops_regex = value
        elif key != "restrict_ts_regex":
            raise ValueError("Wrong keywords argument: {}.".format(key))

    selected = []
    for arg in args:
        if not can_be_regex(arg):
            # Explicit operations: filter first, then drop what is already
            # selected. The comprehension is fully built before extending, so
            # membership is tested against the ops selected by earlier
            # arguments only.
            candidates = util.make_list_of_op(arg, ignore_ts=True)
            if positive_filter is not None:
                candidates = [op for op in candidates if positive_filter(op)]
            selected.extend([op for op in candidates if op not in selected])
            continue

        # Regular expression: query the graph.
        if graph is None:
            raise ValueError("Use the keyword argument 'graph' to use regex.")
        regex = make_regex(arg)
        pattern = regex.pattern
        if pattern.startswith("(?#ts)"):
            # Expression explicitly reserved for tensors.
            continue
        if restrict_ops_regex and not pattern.startswith("(?#ops)"):
            continue
        for op in filter_ops_from_regex(graph, regex):
            if op in selected:
                continue
            if positive_filter is None or positive_filter(op):
                selected.append(op)

    return selected
Exemple #12
0
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False):
    """Do a backward graph walk and return all the visited ops.

    The walk proceeds wave by wave: the ops generating the inputs of the
    current wave (and, optionally, its control inputs) form the next wave,
    until no new op is discovered.

    Args:
      seed_ops: an iterable of operations from which the backward graph
        walk starts. If a list of tensors is given instead, the seed_ops are
        set to be the generators of those tensors. A single non-iterable
        element is wrapped in a list.
      inclusive: if True the given seed_ops are also part of the result.
      within_ops: an iterable of `tf.Operation` within which the search is
        restricted. If `within_ops` is `None`, the search is performed within
        the whole graph. Note that the seeds are intersected with
        `within_ops` only when it is non-empty; an empty `within_ops` leaves
        the seeds untouched but rejects every other op.
      within_ops_fn: if provided, a function on ops that should return True
        iff the op is within the graph traversal. This can be used along
        within_ops, in which case an op is within if it is also in
        within_ops. It is not applied to the seeds themselves.
      stop_at_ts: an iterable of tensors at which the graph walk stops.
      control_inputs: if True, control inputs will be used while moving
        backward.
    Returns:
      A Python list, without duplicates, of the `tf.Operation` behind
      `seed_ops`: the seeds first (unless `inclusive` is False), followed by
      the ops discovered by each successive wave. An empty list is returned
      when `seed_ops` is empty.
    Raises:
      TypeError: if `seed_ops` or `within_ops` cannot be converted to a list
        of `tf.Operation`.
    """
    if not util.is_iterable(seed_ops):
        seed_ops = [seed_ops]
    if not seed_ops:
        return []
    if isinstance(seed_ops[0], tf_ops.Tensor):
        # Tensors were given: start from the ops generating them.
        ts = util.make_list_of_t(seed_ops, allow_graph=False)
        seed_ops = util.get_generating_ops(ts)
    else:
        seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

    stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
    seed_ops = frozenset(util.make_list_of_op(seed_ops))
    if within_ops:
        within_ops = util.make_list_of_op(within_ops, allow_graph=False)
        within_ops = frozenset(within_ops)
        seed_ops &= within_ops

    def is_within(op):
        return (within_ops is None or op
                in within_ops) and (within_ops_fn is None or within_ops_fn(op))

    result = list(seed_ops)
    # `visited` mirrors `result` so membership tests are O(1) instead of a
    # linear scan of the list for every input of every op.
    visited = set(seed_ops)
    wave = set(seed_ops)
    while wave:
        new_wave = set()
        for op in wave:
            for new_t in op.inputs:
                if new_t in stop_at_ts:
                    continue
                producer = new_t.op
                if producer not in visited and is_within(producer):
                    new_wave.add(producer)
            if control_inputs:
                for new_op in op.control_inputs:
                    if new_op not in visited and is_within(new_op):
                        new_wave.add(new_op)
        # Every op in `new_wave` was checked against `visited`, which is only
        # updated here, so the wave can be appended without another
        # uniqueness pass.
        result.extend(new_wave)
        visited.update(new_wave)
        wave = new_wave
    if not inclusive:
        result = [op for op in result if op not in seed_ops]
    return result