def _map_subgraph(sources, sinks):
  """Captures the subgraph between `sources` and `sinks`.

  Walks a Graph backwards from `sinks` to `sources` and returns any extra
  sources encountered in the subgraph that were not specified in `sources`.

  Arguments:
    sources: An iterable of Tensors where subgraph extraction should stop.
    sinks: An iterable of Operations where the subgraph terminates.

  Returns:
    The set of placeholders upon which `sinks` depend and are not in `sources`.
  """
  stop_at_tensors = object_identity.ObjectIdentitySet(sources)
  ops_to_visit = object_identity.ObjectIdentitySet(sinks)
  visited_ops = object_identity.ObjectIdentitySet()
  potential_extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    visited_ops.add(op)
    if op.type == 'Placeholder':
      potential_extra_sources.update(op.outputs)
    input_ops = [t.op for t in op.inputs if t not in stop_at_tensors]
    for input_op in itertools.chain(input_ops, op.control_inputs):
      if input_op not in visited_ops:
        ops_to_visit.add(input_op)

  return potential_extra_sources.difference(sources)
def retrieve_sources(sinks, ignore_control_dependencies=False):
  """Captures the sources of the subgraph rooted at `sinks`.

  Walks a Graph backwards from `sinks` and returns any sources encountered in
  the subgraph. This util is refactored from `_map_subgraph` in
  tensorflow/.../ops/op_selector.py.

  Arguments:
    sinks: An iterable of Operations where the subgraph terminates.
    ignore_control_dependencies: (Optional) If `True`, ignore any
      `control_inputs` for all ops while walking the graph.

  Returns:
    The set of placeholders upon which `sinks` depend. This could also contain
    placeholders representing `captures` in the graph.
  """
  stop_at_tensors = object_identity.ObjectIdentitySet()
  ops_to_visit = object_identity.ObjectIdentitySet(sinks)
  visited_ops = object_identity.ObjectIdentitySet()
  potential_extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    visited_ops.add(op)
    if op.type == 'Placeholder':
      potential_extra_sources.update(op.outputs)
    input_ops = [t.op for t in op.inputs if t not in stop_at_tensors]
    if not ignore_control_dependencies:
      input_ops = itertools.chain(input_ops, op.control_inputs)
    for input_op in input_ops:
      if input_op not in visited_ops:
        ops_to_visit.add(input_op)

  return potential_extra_sources
def _call_func(self, args, kwargs):
  try:
    vars_at_start = self._template_store.variables()
    trainable_at_start = self._template_store.trainable_variables()
    if self._variables_created:
      result = self._func(*args, **kwargs)
    else:
      # The first time we run, restore variables if necessary (via
      # Trackable).
      with trackable_util.capture_dependencies(template=self):
        result = self._func(*args, **kwargs)

    if self._variables_created:
      # Variables were previously created, implying this is not the first
      # time the template has been called. Check to make sure that no new
      # trainable variables were created this time around.
      trainable_variables = self._template_store.trainable_variables()
      # If a variable that we intend to train is created as a side effect
      # of creating a template, then that is almost certainly an error.
      if len(trainable_at_start) != len(trainable_variables):
        raise ValueError(
            "Trainable variable created when calling a template "
            "after the first time, perhaps you used tf.Variable "
            "when you meant tf.get_variable: %s" %
            list(object_identity.ObjectIdentitySet(trainable_variables) -
                 object_identity.ObjectIdentitySet(trainable_at_start)))

      # Non-trainable tracking variables are a legitimate reason why a new
      # variable would be created, but it is a relatively advanced use-case,
      # so log it.
      variables = self._template_store.variables()
      if len(vars_at_start) != len(variables):
        logging.info(
            "New variables created when calling a template after "
            "the first time, perhaps you used tf.Variable when you "
            "meant tf.get_variable: %s",
            list(object_identity.ObjectIdentitySet(variables) -
                 object_identity.ObjectIdentitySet(vars_at_start)))
    else:
      self._variables_created = True
    return result
  except Exception as exc:
    # Reraise the exception, but append the original definition to the
    # trace.
    args = exc.args
    if not args:
      arg0 = ""
    else:
      arg0 = args[0]
    trace = "".join(
        _skip_common_stack_elements(self._stacktrace,
                                    traceback.format_stack()))
    arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
    new_args = [arg0]
    new_args.extend(args[1:])
    exc.args = tuple(new_args)
    raise
def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`."""
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()
  for t in op.inputs:
    if t.dtype == dtypes_module.resource:
      if utils.op_writes_to_resource(t, op):
        writes.add(t)
      else:
        reads.add(t)

  saturated = False
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
      if updated:
        # Conservatively remove any resources from `reads` that are also
        # writes.
        reads = reads.difference(writes)
      saturated = saturated and not updated

  # Note: A resource handle that is not written to is treated as read-only.
  # We don't have a special way of denoting an unused resource.
  return ([(t, ResourceType.READ_ONLY) for t in reads] +
          [(t, ResourceType.READ_WRITE) for t in writes])
def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Stops if all targets have been found (target is optional).

  Only valid in Symbolic mode, not Eager mode.

  Args:
    inputs: List of tensors.
    targets: List of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs
    themselves).
  """
  inputs = nest.flatten(inputs, expand_composites=True)
  reachable = object_identity.ObjectIdentitySet(inputs)
  if targets:
    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
  queue = inputs[:]

  while queue:
    x = queue.pop()
    if isinstance(x, tuple(_user_convertible_tensor_types)):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      outputs = x.outputs[:] or []
      outputs += x._control_outputs  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tensor(x):
      outputs = x.consumers()
    else:
      if not isinstance(x, str):
        raise TypeError('Expected Operation, Variable, or Tensor, got ' +
                        str(x))
      # Strings have no consumers to traverse.
      outputs = []

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        if targets:
          remaining_targets.discard(y)
        queue.insert(0, y)

    if targets and not remaining_targets:
      return reachable

  return reachable
def testDifference(self):

  class Element(object):
    pass

  a = Element()
  b = Element()
  c = Element()
  set1 = object_identity.ObjectIdentitySet([a, b])
  set2 = object_identity.ObjectIdentitySet([b, c])
  diff_set = set1.difference(set2)
  self.assertIn(a, diff_set)
  self.assertNotIn(b, diff_set)
  self.assertNotIn(c, diff_set)
def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  # Create a ConcreteFunction from the new GraphDef.
  input_tensors = func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  not_converted_inputs = [
      tensor for tensor in func.inputs if tensor not in converted_inputs]
  not_converted_inputs_map = {
      tensor.name: tensor for tensor in not_converted_inputs
  }

  new_input_names = [tensor.name for tensor in not_converted_inputs]
  new_output_names = [tensor.name for tensor in func.outputs]
  new_func = wrap_function.function_from_graph_def(output_graph_def,
                                                   new_input_names,
                                                   new_output_names)

  # Manually propagate shape for input tensors where the shape is not
  # correctly propagated. Scalar shapes are lost when wrapping the function.
  for input_tensor in new_func.inputs:
    input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape)
  return new_func
def testDiscard(self):
  a = object()
  b = object()
  set1 = object_identity.ObjectIdentitySet([a, b])
  set1.discard(a)
  self.assertIn(b, set1)
  self.assertNotIn(a, set1)
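# Supplementary sketch (not from the original test file): it illustrates the
# defining property of ObjectIdentitySet that the tests above rely on:
# membership is decided by object identity (`is`), not by `==`/`hash`, so
# equal-but-distinct objects (even unhashable ones, such as lists) are kept
# as separate members. The method name below is made up for illustration.
def testIdentitySemanticsSketch(self):
  a = [1, 2, 3]
  b = [1, 2, 3]  # equal to `a`, but a different object
  ident_set = object_identity.ObjectIdentitySet([a])
  self.assertIn(a, ident_set)     # found by identity
  self.assertNotIn(b, ident_set)  # equal value, different identity
  ident_set.add(b)
  self.assertEqual(2, len(ident_set))  # both distinct objects are members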
def test_model(self):
  model = build_fba_matting()
  model.compile(
      optimizer='sgd', loss=['mse', None, None, None],
      run_eagerly=test_utils.should_run_eagerly())
  model.fit(
      [
          np.random.random((2, 240, 240, 3)).astype(np.uint8),
          np.random.random((2, 240, 240, 2)).astype(np.uint8),
          np.random.random((2, 240, 240, 6)).astype(np.uint8),
      ],
      [
          np.random.random((2, 240, 240, 7)).astype(np.float32),
          np.random.random((2, 240, 240, 1)).astype(np.float32),
          np.random.random((2, 240, 240, 3)).astype(np.float32),
          np.random.random((2, 240, 240, 3)).astype(np.float32)
      ],
      epochs=1, batch_size=10)

  # test config
  model.get_config()

  # check whether the model variables are present
  # in the trackable list of objects
  checkpointed_objects = object_identity.ObjectIdentitySet(
      trackable_util.list_objects(model))
  for v in model.variables:
    self.assertIn(v, checkpointed_objects)
def extract_outputs_from_subclassing_model(model, output_dict, input_names,
                                           output_names, input_sigature):
  from tensorflow.python.keras.saving import saving_utils as _saving_utils
  from tensorflow.python.util import object_identity
  from ._graph_cvt import convert_variables_to_constants_v2 as _convert_to_constants

  function = _saving_utils.trace_model_call(model, input_sigature)
  concrete_func = function.get_concrete_function()
  for k_, v_ in concrete_func.structured_outputs.items():
    output_names.extend([ts_.name for ts_ in v_.op.outputs])
  output_dict.update(
      build_layer_outputs(model, concrete_func.graph, concrete_func.outputs))
  graph_def, converted_input_indices = _convert_to_constants(
      concrete_func, lower_control_flow=True)
  input_tensors = concrete_func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  input_names.extend([
      tensor.name for tensor in concrete_func.inputs
      if tensor not in converted_inputs
  ])
  with tf.Graph().as_default() as tf_graph:
    tf.import_graph_def(graph_def, name='')
  return tf_graph
def validate_and_slice_inputs(names_to_saveables):
  """Returns the variables and names that will be used for a Saver.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
      v is an operation to save or a BaseSaverBuilder.Saver.

  Returns:
    A list of SaveableObjects.

  Raises:
    TypeError: If any of the keys are not strings or any of the values are
      not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  if not isinstance(names_to_saveables, dict):
    names_to_saveables = op_list_to_dict(names_to_saveables)

  saveables = []
  seen_ops = object_identity.ObjectIdentitySet()
  for name, op in sorted(
      names_to_saveables.items(),
      # Avoid comparing ops, sort only by name.
      key=lambda x: x[0]):
    for converted_saveable_object in saveable_objects_for_op(op, name):
      _add_saveable(saveables, seen_ops, converted_saveable_object)
  return saveables
def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs."""
  result = []
  # A cache to store the read only resource inputs of an Op.
  # Operation -> ObjectIdentitySet of resource handles.
  op_read_only_resource_inputs = {}
  for input_index, t in enumerate(func_graph.inputs):
    if t.dtype != dtypes.resource:
      continue
    read_only = True
    for op in t.consumers():
      if op in op_read_only_resource_inputs:
        if t not in op_read_only_resource_inputs[op]:
          read_only = False
          break
      else:
        indices = _get_read_only_resource_input_indices_op(op)
        op_read_only_resource_inputs[op] = object_identity.ObjectIdentitySet(
            [op.inputs[i] for i in indices])
        if t not in op_read_only_resource_inputs[op]:
          read_only = False
          break
    if read_only:
      result.append(input_index)
  return result
def test_model(self):
  num_classes = 5
  model = build_deeplab_v3_plus(
      classes=num_classes, bone_arch='resnet_50', bone_init='imagenet',
      bone_train=False, aspp_filters=8, aspp_stride=16, low_filters=16,
      decoder_filters=4)
  model.compile(
      optimizer='sgd', loss='sparse_categorical_crossentropy',
      run_eagerly=test_utils.should_run_eagerly())
  model.fit(
      np.random.random((2, 224, 224, 3)).astype(np.uint8),
      np.random.randint(0, num_classes, (2, 224, 224)),
      epochs=1, batch_size=10)

  # test config
  model.get_config()

  # check whether the model variables are present
  # in the trackable list of objects
  checkpointed_objects = object_identity.ObjectIdentitySet(
      trackable_util.list_objects(model))
  for v in model.variables:
    self.assertIn(v, checkpointed_objects)
def _Inputs(op, xs):
  """Returns the inputs of op, crossing closure boundaries where necessary.

  Args:
    op: Operation
    xs: list of Tensors we are differentiating w.r.t.

  Returns:
    A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
    is in a FuncGraph and has captured inputs.
  """
  tensors = object_identity.ObjectIdentitySet(xs)
  if _IsFunction(op.graph):  # pylint: disable=protected-access
    inputs = []
    for t in op.inputs:
      # If we're differentiating w.r.t. `t`, do not attempt to traverse through
      # it to a captured value. The algorithm needs to "see" `t` in this case,
      # even if it's a function input for a captured value, whereas usually
      # we'd like to traverse through these closures as if the captured value
      # was the direct input to op.
      if t not in tensors:
        t = _MaybeCaptured(t)
      inputs.append(t)
    return inputs
  else:
    return op.inputs
def __init__(self, record_initial_resource_uses=False,
             record_uses_of_resource_ids=None):
  self._returned_tensors = object_identity.ObjectIdentitySet()
  self.ops_which_must_run = set()
  self.record_initial_resource_uses = record_initial_resource_uses
  self.record_uses_of_resource_ids = record_uses_of_resource_ids
def _add_checkpoint_values_check(self, trackable_objects, object_graph_proto):
  """Determines which objects have checkpoint values and saves to the proto.

  Args:
    trackable_objects: A list of all trackable objects.
    object_graph_proto: A `TrackableObjectGraph` proto.
  """
  # Trackable -> set of all trackables that depend on it (the "parents").
  # If a trackable has checkpoint values, then all of the parents can be
  # marked as having checkpoint values.
  parents = object_identity.ObjectIdentityDictionary()
  checkpointed_trackables = object_identity.ObjectIdentitySet()

  # First pass: build dictionary of parent objects and initial set of
  # checkpointed trackables.
  for trackable, object_proto in zip(trackable_objects,
                                     object_graph_proto.nodes):
    if (object_proto.attributes or object_proto.slot_variables or
        object_proto.HasField("registered_saver")):
      checkpointed_trackables.add(trackable)
    for child_proto in object_proto.children:
      child = trackable_objects[child_proto.node_id]
      if child not in parents:
        parents[child] = object_identity.ObjectIdentitySet()
      parents[child].add(trackable)

  # Second pass: add all connected parents to set of checkpointed trackables.
  to_visit = object_identity.ObjectIdentitySet()
  to_visit.update(checkpointed_trackables)

  while to_visit:
    trackable = to_visit.pop()
    if trackable not in parents:
      # Some trackables may not have parents (e.g. slot variables).
      continue
    current_parents = parents.pop(trackable)
    checkpointed_trackables.update(current_parents)
    for parent in current_parents:
      if parent in parents:
        to_visit.add(parent)

  for node_id, trackable in enumerate(trackable_objects):
    object_graph_proto.nodes[node_id].has_checkpoint_values.value = bool(
        trackable in checkpointed_trackables)
def _get_feeds(self, unfed_input_keys: Iterable[str]) -> Iterable[tf.Tensor]:
  """Returns set of tensors that will be fed."""
  result = object_identity.ObjectIdentitySet(self._func_graph.inputs)
  for input_key in unfed_input_keys:
    unfed_input_components = _get_component_tensors(
        self._structured_inputs[input_key])
    result = result.difference(unfed_input_components)
  return result
def test_weights_equals_deduplicated_parameter_dict(model):
  """
  Checks that the elements of GPflux's `model.trainable_weights` equal the
  deduplicated values of GPflow's `gpflow.utilities.parameter_dict(model)`.
  """
  # We filter out the parameters of type ResourceVariable.
  # They have been added to the model by the `add_metric` call in the layer.
  parameters = [
      p for p in parameter_dict(model).values()
      if not isinstance(p, ResourceVariable)
  ]
  variables = map(lambda p: p.unconstrained_variable, parameters)
  deduplicate_variables = object_identity.ObjectIdentitySet(variables)

  weights = model.trainable_weights
  assert len(weights) == len(deduplicate_variables)

  weights_set = object_identity.ObjectIdentitySet(weights)
  assert weights_set == deduplicate_variables
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Arguments:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholders upon which init_tensor depends and are not in
    sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  ops_to_visit = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      should_raise = True
    elif op.type == "Placeholder":
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s" %
          (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
    for inp in graph_inputs(op):
      op_outputs[inp].add(op)
      if inp not in visited_ops and inp not in (sources or extra_sources):
        ops_to_visit.append(inp)

  return extra_sources
def _get_feeds(self, unfed_input_keys):
  """Returns set of tensors that will be fed."""
  if self._is_finalized:
    return self._feeds
  result = object_identity.ObjectIdentitySet(self._func_graph.inputs)
  for input_key in unfed_input_keys:
    unfed_input_components = self._get_component_tensors(
        self._structured_inputs[input_key])
    result = result.difference(unfed_input_components)
  return result
def _add_elements_to_collection(elements, collection_list):
  if context.executing_eagerly():
    raise RuntimeError(
        'Using collections from Layers not supported in Eager '
        'mode. Tried to add %s to %s' % (elements, collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    collection_set = object_identity.ObjectIdentitySet(collection)
    for element in elements:
      if element not in collection_set:
        collection.append(element)
def count_params(weights):
  """Count the total number of scalars composing the weights.

  Arguments:
    weights: An iterable containing the weights on which to compute params

  Returns:
    The total number of scalars composing the weights
  """
  return int(
      sum(
          np.prod(p.shape.as_list())
          for p in object_identity.ObjectIdentitySet(weights)))
def count_params(weights):
  """Count the total number of scalars composing the weights.

  Arguments:
    weights: An iterable containing the weights on which to compute params

  Returns:
    The total number of scalars composing the weights
  """
  unique_weights = object_identity.ObjectIdentitySet(weights)
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  return int(sum(np.prod(p) for p in standardized_weight_shapes))
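# Supplementary sketch (not from the original sources): it shows why both
# count_params variants above wrap `weights` in an ObjectIdentitySet before
# summing. A variable that appears twice in the list (e.g. a kernel shared
# between layers) is counted only once. The variable names below are made up
# for illustration.
import numpy as np
import tensorflow as tf
from tensorflow.python.util import object_identity

kernel = tf.Variable(tf.zeros((3, 4)))  # 12 scalars
bias = tf.Variable(tf.zeros((4,)))      # 4 scalars
weights = [kernel, bias, kernel]        # `kernel` listed twice

unique_weights = object_identity.ObjectIdentitySet(weights)
assert len(unique_weights) == 2  # duplicate reference collapses by identity
assert int(sum(np.prod(w.shape.as_list()) for w in unique_weights)) == 16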
def count_trainable_params(model):
  """Count the number of trainable parameters of a tf.keras model.

  Args:
    model: tf.keras model

  Returns:
    Total number of trainable parameters
  """
  weights = model.trainable_weights
  total_trainable_params = int(
      sum(
          np.prod(p.shape)
          for p in object_identity.ObjectIdentitySet(weights)))
  return total_trainable_params
def get_read_write_resource_inputs(op):
  """Returns a tuple of resource reads, writes in op.inputs.

  Args:
    op: Operation

  Returns:
    A 2-tuple of ObjectIdentitySets, the first entry containing read-only
    resource handles and the second containing read-write resource handles in
    `op.inputs`.
  """
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()

  if op.type in RESOURCE_READ_OPS:
    # Add all resource inputs to `reads` and return.
    reads.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Add all resource inputs to `writes` and return.
    writes.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  read_only_index = 0
  for i, t in enumerate(op.inputs):
    if op.inputs[i].dtype != dtypes.resource:
      continue
    if (read_only_index < len(read_only_input_indices) and
        i == read_only_input_indices[read_only_index]):
      reads.add(op.inputs[i])
      read_only_index += 1
    else:
      writes.add(op.inputs[i])
  return (reads, writes)
def get_dependent_input_output_keys(self, input_keys, exclude_output_keys):
  """Determine inputs needed to get outputs excluding exclude_output_keys.

  Args:
    input_keys: A collection of all input keys available to supply to the
      SavedModel.
    exclude_output_keys: A collection of output keys returned by the SavedModel
      that should be excluded.

  Returns:
    A pair of:
      required_input_keys: A subset of the input features to this SavedModel
        that are required to compute the set of output features excluding
        `exclude_output_keys`. It is sorted to be deterministic.
      output_keys: The set of output features excluding `exclude_output_keys`.
        It is sorted to be deterministic.
  """
  # Assert inputs being fed and outputs being excluded are part of the
  # SavedModel.
  if set(input_keys).difference(self._structured_inputs.keys()):
    raise ValueError(
        'Input tensor names contained tensors not in graph: {}'.format(
            input_keys))
  if set(exclude_output_keys).difference(self._structured_outputs.keys()):
    raise ValueError(
        'Excluded outputs contained keys not in graph: {}'.format(
            exclude_output_keys))
  output_keys = (
      set(self._structured_outputs.keys()).difference(exclude_output_keys))

  # Get all the input tensors that are required to evaluate output_keys.
  required_inputs = object_identity.ObjectIdentitySet()
  for key in output_keys:
    required_inputs.update(self._output_to_inputs_map[key])

  # Get all the input feature names that have at least one component tensor
  # in required_inputs.
  required_input_keys = []
  for key, tensor in six.iteritems(self._structured_inputs):
    if any(x in required_inputs for x in self._get_component_tensors(tensor)):
      required_input_keys.append(key)

  return sorted(required_input_keys), sorted(output_keys)
def test_model(self):
  model = build_uper_net(
      classes=2, bone_arch='swin_tiny_224', bone_init='imagenet',
      bone_train=False)
  model.compile(
      optimizer='sgd', loss='mse',
      run_eagerly=test_utils.should_run_eagerly())
  model.fit(
      np.random.random((2, 240, 240, 3)).astype(np.uint8),
      np.random.random((2, 240, 240, 2)).astype(np.float32),
      epochs=1, batch_size=10)

  # test config
  model.get_config()

  # check whether the model variables are present
  # in the trackable list of objects
  checkpointed_objects = object_identity.ObjectIdentitySet(
      trackable_util.list_objects(model))
  for v in model.variables:
    self.assertIn(v, checkpointed_objects)
def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`."""
  resource_inputs = object_identity.ObjectIdentitySet(
      t for t in op.inputs if t.dtype == dtypes_module.resource)
  saturated = False
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op,
                                                             resource_inputs)
      saturated = saturated and not updated
  return resource_inputs
def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = object_identity.ObjectIdentitySet()
  to_visit = layer_list[::-1]
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      yield obj
    else:
      sub_layers = getattr(obj, "layers", None) or []
      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(sub_layers[::-1])
def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = object_identity.ObjectIdentitySet()
  to_visit = layer_list[::-1]
  filtered = []
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      filtered.append(obj)
    elif hasattr(obj, "layers"):
      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(obj.layers[::-1])
  return filtered