def test_backward_walk_ops(self):
  """Checks how within_ops, within_ops_fn and stop_at_ts prune a backward walk."""
  seed_ops = [self.h.op]
  # Include all ops except for self.g.op.
  within_ops = [
      x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
  ]

  # For the fn, exclude self.c.op.
  def within_ops_fn(op):
    return op not in (self.c.op,)

  stop_at_ts = (self.f,)

  with self.graph.as_default():
    # Backward walk only includes h since we stop at f and g is not within.
    ops = op_selector.get_backward_walk_ops(
        seed_ops,
        inclusive=True,
        within_ops=within_ops,
        within_ops_fn=within_ops_fn,
        stop_at_ts=stop_at_ts)
    self.assertEqual(set(ops), {self.h.op})

    # If we do inclusive=False, the result is empty.
    ops = op_selector.get_backward_walk_ops(
        seed_ops,
        inclusive=False,
        within_ops=within_ops,
        within_ops_fn=within_ops_fn,
        stop_at_ts=stop_at_ts)
    self.assertEqual(set(ops), set())

    # Removing stop_at_ts adds f.op, d.op.
    ops = op_selector.get_backward_walk_ops(
        seed_ops,
        inclusive=True,
        within_ops=within_ops,
        within_ops_fn=within_ops_fn)
    self.assertEqual(set(ops), {self.d.op, self.f.op, self.h.op})

    # Not using within_ops_fn adds back ops for a, b, c.
    ops = op_selector.get_backward_walk_ops(
        seed_ops, inclusive=True, within_ops=within_ops)
    self.assertEqual(
        set(ops),
        {self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op})

    # An unrestricted backward search from self.h.op includes everything
    # except e.op.
    ops = op_selector.get_backward_walk_ops(seed_ops, inclusive=True)
    self.assertEqual(
        set(ops),
        {self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op,
         self.h.op})
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  This is a thin wrapper that forwards its arguments to
  `op_selector.get_backward_walk_ops`.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
    only_differentiable: forwarded to `op_selector.get_backward_walk_ops`.
      It is forwarded only when truthy, so callers that leave it at the
      default issue exactly the same call as before this option existed.
      Presumably it restricts the walk to differentiable paths; see
      `op_selector` for the exact semantics.

  Returns:
    The `tf.Operation`s behind `seed_ops`, exactly as returned by
    `op_selector.get_backward_walk_ops`.

  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  extra_kwargs = {}
  if only_differentiable:
    # Only pass the option through when requested, so the default call path
    # is unchanged for existing callers.
    extra_kwargs["only_differentiable"] = only_differentiable
  return op_selector.get_backward_walk_ops(
      seed_ops,
      inclusive=inclusive,
      within_ops=within_ops,
      within_ops_fn=within_ops_fn,
      stop_at_ts=stop_at_ts,
      control_inputs=control_inputs,
      **extra_kwargs)
def get_dependent_variables(input_ops, output_ops):
  """Finds variables involved in the subgraph between input_ops and output_ops.

  Args:
    input_ops: tensors at which the backward walk stops (passed as
      `stop_at_ts`).
    output_ops: a (possibly nested) structure of tensors from which the
      backward walk starts.

  Returns:
    A list holding, for each op of a type in `VAR_OP_TYPES` that lies strictly
    behind the outputs, the value of `get_variable_by_name(op.name)`.
  """
  # Wrapping each output in an identity avoids the edge-case when
  # input_ops == output_ops.
  seeds = nest.map_structure(gen_array_ops.identity, output_ops)
  walked = op_selector.get_backward_walk_ops(
      seed_ops=seeds, stop_at_ts=input_ops, inclusive=False)
  return [
      get_variable_by_name(op.name)
      for op in walked
      if op.type in VAR_OP_TYPES
  ]
def _ensure_servable(input_tensors, names_to_output_tensor_infos):
  """Verify every signature output is computable from the signature inputs.

  Walks backward from the flattened output tensors, stopping at the flattened
  input tensors. The signature is rejected if any op reached on the way must
  be fed (according to `_must_be_fed`) and has an output that is not one of
  the inputs.

  Args:
    input_tensors: An iterable of `Tensor`s specified as the signature's inputs.
    names_to_output_tensor_infos: A mapping from output names to the
      `TensorInfo`s of the signature's output tensors.

  Raises:
    ValueError: If some output depends on an op that must be fed but whose
      outputs are not all provided as signature inputs. The message includes
      the dependency path reported by `op_selector.show_path`.
  """
  flat_inputs = nest.flatten(input_tensors, expand_composites=True)
  graph = op_selector.get_unique_graph(flat_inputs)
  resolved_outputs = [
      utils.get_tensor_from_tensor_info(info, graph=graph)
      for info in names_to_output_tensor_infos.values()
  ]
  flat_outputs = nest.flatten(resolved_outputs, expand_composites=True)

  walked_ops = op_selector.get_backward_walk_ops(
      flat_outputs, stop_at_ts=flat_inputs)
  fed = object_identity.ObjectIdentitySet(flat_inputs)

  for op in walked_ops:
    if not _must_be_fed(op):
      continue
    if all(out in fed for out in op.outputs):
      continue
    # Found an op that needs feeding but is not fully covered by the inputs.
    input_names = [t.name for t in flat_inputs]
    output_keys = list(names_to_output_tensor_infos.keys())
    output_names = [t.name for t in flat_outputs]
    path = op_selector.show_path(op, flat_outputs, flat_inputs)
    raise ValueError(
        f"The signature's input tensors {input_names} are "
        f"insufficient to compute its output keys {output_keys} "
        f"(respectively, tensors {output_names}) because of the "
        f"dependency on `{op.name}` which is not given as "
        "a signature input, as illustrated by the following dependency path: "
        f"{path}")
def _get_dependent_variables(input_ops, output_ops):
  """Collects the variables used between input_ops and output_ops.

  Args:
    input_ops: Flattened list of input ops; the backward walk stops at them
      (they are passed as `stop_at_ts`).
    output_ops: Flattened list of output ops; the backward walk starts here.

  Returns:
    A list of variables: for each op of a type in `VAR_OP_TYPES` found by the
    walk (called with only_differentiable=True), the result of
    `get_variable_by_name` on its name, with `None` results dropped.
  """
  # Wrapping each output in an identity avoids the edge-case when
  # input_ops == output_ops.
  outputs = nest.map_structure(gen_array_ops.identity, output_ops)
  between = op_selector.get_backward_walk_ops(
      seed_ops=outputs,
      stop_at_ts=input_ops,
      inclusive=False,
      only_differentiable=True)
  found = []
  for op in between:
    if op.type not in VAR_OP_TYPES:
      continue
    variable = get_variable_by_name(op.name)
    # get_variable_by_name may yield None; such entries are skipped.
    if variable is not None:
      found.append(variable)
  return found