Ejemplo n.º 1
0
    def test_backward_walk_ops(self):
        """Checks get_backward_walk_ops honors inclusive/within/stop options.

        Walks backward from h.op and verifies that each of the filters
        (within_ops, within_ops_fn, stop_at_ts, inclusive) prunes the
        result set as expected.
        """
        seed_ops = [self.h.op]
        # Include all ops except for self.g.op
        within_ops = [
            x.op
            for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
        ]
        # For the fn, exclude self.c.op.
        within_ops_fn = lambda op: op not in (self.c.op, )
        stop_at_ts = (self.f, )

        with self.graph.as_default():
            # Backward walk only includes h since we stop at f and g is not within.
            ops = op_selector.get_backward_walk_ops(
                seed_ops,
                inclusive=True,
                within_ops=within_ops,
                within_ops_fn=within_ops_fn,
                stop_at_ts=stop_at_ts)
            self.assertEqual(set(ops), set([self.h.op]))

            # If we do inclusive=False, the result is empty.
            ops = op_selector.get_backward_walk_ops(
                seed_ops,
                inclusive=False,
                within_ops=within_ops,
                within_ops_fn=within_ops_fn,
                stop_at_ts=stop_at_ts)
            self.assertEqual(set(ops), set())

            # Removing stop_at_ts adds f.op, d.op.
            ops = op_selector.get_backward_walk_ops(
                seed_ops,
                inclusive=True,
                within_ops=within_ops,
                within_ops_fn=within_ops_fn)
            self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op]))

            # Not using within_ops_fn adds back ops for a, b, c.
            ops = op_selector.get_backward_walk_ops(seed_ops,
                                                    inclusive=True,
                                                    within_ops=within_ops)
            self.assertEqual(
                set(ops),
                set([
                    self.a.op, self.b.op, self.c.op, self.d.op, self.f.op,
                    self.h.op
                ]))

            # Vanilla backward search via self.h.op includes everything except e.op.
            ops = op_selector.get_backward_walk_ops(seed_ops, inclusive=True)
            self.assertEqual(
                set(ops),
                set([
                    self.a.op, self.b.op, self.c.op, self.d.op, self.f.op,
                    self.g.op, self.h.op
                ]))
Ejemplo n.º 2
0
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False):
  """Walk the graph backward from `seed_ops` and return every visited op.

  This is a thin forwarding wrapper around
  `op_selector.get_backward_walk_ops`.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
  Returns:
    A Python set of all the `tf.Operation` behind `seed_ops`.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  # Forward every option unchanged to the underlying implementation.
  walk_options = {
      'inclusive': inclusive,
      'within_ops': within_ops,
      'within_ops_fn': within_ops_fn,
      'stop_at_ts': stop_at_ts,
      'control_inputs': control_inputs,
  }
  return op_selector.get_backward_walk_ops(seed_ops, **walk_options)
Ejemplo n.º 3
0
def get_dependent_variables(input_ops, output_ops):
    """Finds variables involved in the subgraph between input_ops and output_ops.

    Args:
      input_ops: Flattened list of input ops.
      output_ops: Flattened list of output ops.

    Returns:
      A list of variables found on the backward walk from `output_ops` to
      `input_ops`. Names that could not be resolved to a variable are
      dropped, so the list never contains `None`.
    """

    # avoids the edge-case when input_ops == output_ops.
    output_ops = nest.map_structure(gen_array_ops.identity, output_ops)

    inbetween_ops = op_selector.get_backward_walk_ops(seed_ops=output_ops,
                                                      stop_at_ts=input_ops,
                                                      inclusive=False)
    var_ops = (op for op in inbetween_ops if op.type in VAR_OP_TYPES)
    var_names = (op.name for op in var_ops)
    tf_vars = (get_variable_by_name(var_name) for var_name in var_names)
    # get_variable_by_name may fail and yield None; filter those out so
    # callers never see None entries (consistent with _get_dependent_variables).
    return [v for v in tf_vars if v is not None]
Ejemplo n.º 4
0
def _ensure_servable(input_tensors, names_to_output_tensor_infos):
  """Check that the signature outputs don't depend on unreachable placeholders.

  Args:
    input_tensors: An iterable of `Tensor`s specified as the signature's inputs.
    names_to_output_tensor_infos: An mapping from output names to respective
      `TensorInfo`s corresponding to the signature's output tensors.

  Raises:
    ValueError: If any of the signature's outputs depend on placeholders not
      provided as signature's inputs.
  """
  plain_input_tensors = nest.flatten(input_tensors, expand_composites=True)
  graph = op_selector.get_unique_graph(plain_input_tensors)

  # Resolve each output TensorInfo to its concrete tensor in the graph.
  resolved_outputs = [
      utils.get_tensor_from_tensor_info(info, graph=graph)
      for info in names_to_output_tensor_infos.values()
  ]
  plain_output_tensors = nest.flatten(resolved_outputs, expand_composites=True)

  # Every op reachable backward from the outputs, stopping at the fed inputs.
  dependency_ops = op_selector.get_backward_walk_ops(
      plain_output_tensors, stop_at_ts=plain_input_tensors)

  fed_tensors = object_identity.ObjectIdentitySet(plain_input_tensors)
  for dependency_op in dependency_ops:
    if not _must_be_fed(dependency_op):
      continue
    # OK if every output of this op is itself a fed signature input.
    if all(output in fed_tensors for output in dependency_op.outputs):
      continue
    input_tensor_names = [tensor.name for tensor in plain_input_tensors]
    output_tensor_keys = list(names_to_output_tensor_infos.keys())
    output_tensor_names = [tensor.name for tensor in plain_output_tensors]
    dependency_path = op_selector.show_path(dependency_op,
                                            plain_output_tensors,
                                            plain_input_tensors)
    raise ValueError(
        f'The signature\'s input tensors {input_tensor_names} are '
        f'insufficient to compute its output keys {output_tensor_keys} '
        f'(respectively, tensors {output_tensor_names}) because of the '
        f'dependency on `{dependency_op.name}` which is not given as '
        'a signature input, as illustrated by the following dependency path: '
        f'{dependency_path}')
Ejemplo n.º 5
0
def _get_dependent_variables(input_ops, output_ops):
    """Finds variables involved in the subgraph between input_ops and output_ops.

    Args:
      input_ops: Flattened list of input ops.
      output_ops: Flattened list of output ops.

    Returns:
      A list of variables.
    """

    # Identity-wrap the outputs so the walk is well defined even when
    # input_ops == output_ops.
    output_ops = nest.map_structure(gen_array_ops.identity, output_ops)
    between_ops = op_selector.get_backward_walk_ops(
        seed_ops=output_ops,
        stop_at_ts=input_ops,
        inclusive=False,
        only_differentiable=True)
    # Keep only variable-producing ops and resolve each name to a variable,
    # discarding names that fail to resolve.
    variable_names = [op.name for op in between_ops if op.type in VAR_OP_TYPES]
    resolved = (get_variable_by_name(name) for name in variable_names)
    return [variable for variable in resolved if variable is not None]