def variables_to_save(graph: tf.Graph) -> typing.Sequence:
    """Collect the variables of ``graph`` that should be saved, without duplicates.

    Variables are gathered from the trainable, model and moving-average
    collections (in that order) and de-duplicated by their ``name``.

    Parameters
    ----------
    graph
        graph whose variable collections are read

    Returns
    -------
    vars_to_save
        list with one entry per distinct variable name; a name keeps the
        position of its first occurrence, while the object stored for it is
        the last one seen with that name
    """
    collection_keys = (
        tf.GraphKeys.TRAINABLE_VARIABLES,
        # TODO: tf.GraphKeys.GLOBAL_VARIABLES is not collected here.
        # NOTE(review): the original TODO line adding it reads as disabled --
        # confirm that leaving it out is intended.
        tf.GraphKeys.MODEL_VARIABLES,
        tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
    )
    variables_by_name = {}
    for key in collection_keys:
        for variable in graph.get_collection(key):
            # Same name -> same slot, so a variable that appears in several
            # collections is only returned once.
            variables_by_name[variable.name] = variable
    return list(variables_by_name.values())
def get_nucleotide_names_from_collections(graph: tf.Graph,
                                          collection_list: list) -> list:
    """
    Gather the names of all nodes stored in the given graph collections

    Each collection item may be a single node, a list / tuple of nodes or a
    dict whose values are nodes. For every node, the part of its ``name``
    before the first ``':'`` is taken.

    Parameters
    ----------
    graph
        tensorflow graph; it is first passed through
        ``_maybe_get_default_graph`` (presumably to fall back to the default
        graph -- verify against that helper)
    collection_list
        list of collection names to search

    Returns
    -------
    list_of_nucleotide_names
        names found inside of the collections from collection_list, in
        iteration order; duplicates are not removed
    """
    graph = _maybe_get_default_graph(graph)
    names = []
    for key in collection_list:
        for entry in graph.get_collection(key):
            # Normalise every kind of entry to an iterable of nodes so the
            # name extraction is written only once.
            if isinstance(entry, dict):
                members = entry.values()
            elif isinstance(entry, (list, tuple)):
                members = entry
            else:
                members = (entry,)
            names.extend([member.name.split(':')[0] for member in members])
    return names
def get_asset_annotations(graph: tf.Graph):
    """Obtains the asset annotations in the specified graph.

    Args:
        graph: A `tf.Graph` object.

    Returns:
        A dict that maps asset_keys to asset_filenames. Keys and filenames are
        both reduced with `os.path.basename`. Note that if multiple entries
        for the same key exist, later ones will override earlier ones.

    Raises:
        AssertionError: If the asset key collection and the asset filename
            collection do not have the same length.
    """
    asset_keys = graph.get_collection(_ASSET_KEY_COLLECTION)
    asset_filenames = graph.get_collection(_ASSET_FILENAME_COLLECTION)

    # Raised explicitly instead of using an `assert` statement: asserts are
    # removed when Python runs with -O, and zip() below would then silently
    # drop the unmatched entries. The exception type is kept for callers.
    if len(asset_keys) != len(asset_filenames):
        raise AssertionError(
            'Length of asset key and filename collections must match.')

    # Remove scope.
    return {
        os.path.basename(key): os.path.basename(filename)
        for key, filename in zip(asset_keys, asset_filenames)
    }
def _make_collection_defs(
        tf_g: tf.Graph) -> Iterable[tf.MetaGraphDef.CollectionDefEntry]:
    """
    Serialize the collections of a TensorFlow graph into protobuf entries.

    **NOTE:** Currently this function only captures collections of variables.

    Collections whose name is not exactly a `str` are reported on stdout and
    left out. Items of type `WhileContext` / `CondContext` are reported on
    stdout and not serialized.

    NOTE(review): the "Skipping collection" messages are printed per item and
    the collection's entry is still appended to the result -- confirm that
    this is intended.

    Args:
        tf_g: TensorFlow graph from which to harvest collections

    Returns:
        A list of `tf.MetaGraphDef.CollectionDefEntry` protobufs, one per
        string-named collection, holding the serialized variables of that
        collection.

    Raises:
        NotImplementedError: If a collection contains an item that is neither
            a `tf.Variable` nor a `WhileContext` / `CondContext`.
    """
    entries = []
    for key in tf_g.collections:
        if type(key) is not str:
            print("Skipping non-string collection name {}".format(key))
            continue

        members = tf_g.get_collection(key)
        entry = tf.MetaGraphDef.CollectionDefEntry()
        entry.key = key

        for member in members:
            kind = type(member).__name__
            if isinstance(member, tf.Variable):
                # Let TensorFlow build the protobuf form of the variable; it
                # is stored in the entry as a binary serialized string.
                serialized = member.to_proto().SerializeToString()
                entry.value.bytes_list.value.append(serialized)
            elif kind == "WhileContext":
                # TODO(frreiss): Should we serialize WhileContexts?
                print("Skipping collection {} -- is WhileContext.".format(
                    key))
            elif kind == "CondContext":
                # TODO(frreiss): Should we serialize CondContexts?
                print("Skipping collection {} -- is CondContext.".format(
                    key))
            else:
                raise NotImplementedError(
                    "Can't serialize item '{}' in collection "
                    "'{}' because it is a "
                    "'{}'.".format(member, key, kind))

        entries.append(entry)
    return entries
def count_parameters(graph: tf.Graph):
    """Count the parameters of the trainable variables in ``graph``.

    The size of a variable is the product of ``dim.value`` over the
    dimensions of ``variable.get_shape()``.
    NOTE(review): this relies on every dimension exposing a numeric
    ``.value``; behaviour for unknown dimensions is not handled here -- confirm.

    Args:
        graph: graph whose ``TRAINABLE_VARIABLES`` collection is inspected.

    Returns:
        A tuple ``(total_parameters, by_scope, by_name)`` where
        ``total_parameters`` is the sum over all trainable variables,
        ``by_scope`` is a ``Counter`` keyed by the part of ``variable.op.name``
        before the first ``'/'``, and ``by_name`` is a ``Counter`` keyed by the
        full ``variable.op.name``.
    """
    by_name = Counter()
    by_scope = Counter()
    total = 0
    for var in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        size = 1
        for dim in var.get_shape():
            size *= dim.value
        total += size

        op_name = var.op.name
        by_name[op_name] += size
        # The top-level scope is everything before the first '/'.
        by_scope[op_name.split('/')[0]] += size
    return total, by_scope, by_name
def get_analyzers_fingerprint(
    graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]
) -> Mapping[str, AnalyzersFingerprint]:
    """Computes fingerprints for all analyzers in `graph`.

    The analyzers are the tensor sinks found in the
    `analyzer_nodes.ALL_REPLACEMENTS` collection of `graph`.

    Args:
        graph: a TF Graph.
        structured_inputs: a dict from keys to batches of placeholder graph
            tensors.

    Returns:
        A mapping from analyzer name (the name of the sink's tensor) to an
        `AnalyzersFingerprint` built from the analyzer's source keys and the
        set of paths that define its fingerprint.
    """
    tensor_sinks = graph.get_collection(analyzer_nodes.ALL_REPLACEMENTS)
    # The value for the keys in this dictionary are unused and can be arbitrary.
    sink_tensors_ready = dict.fromkeys(
        (tf_utils.hashable_tensor_or_op(sink.tensor) for sink in tensor_sinks),
        False)
    graph_analyzer = InitializableGraphAnalyzer(
        graph, structured_inputs, list(sink_tensors_ready.items()),
        describe_path_as_analyzer_cache_hash)

    fingerprints = {}
    for sink in tensor_sinks:
        # Retrieve tensors that are inputs to the analyzer's value node.
        visitor = SourcedTensorsVisitor()
        nodes.Traverser(visitor).visit_value_node(sink.future)
        source_keys = _retrieve_source_keys(visitor.sourced_tensors,
                                            structured_inputs)

        # Obtain the fingerprint path of each input tensor; tensors for which
        # no unique path is available (None) are left out.
        candidate_paths = (
            graph_analyzer.get_unique_path(tensor)
            for tensor in visitor.sourced_tensors)
        paths = {path for path in candidate_paths if path is not None}

        fingerprints[str(sink.tensor.name)] = AnalyzersFingerprint(
            source_keys, paths)
    return fingerprints
def get_train_ops(graph: tf.Graph):
    """Return the contents of the ``TRAINABLE_VARIABLES`` collection of ``graph``.

    NOTE(review): despite the function name, this reads the trainable
    *variables* collection, not a train-op collection -- confirm with the
    callers which one is intended before changing it.

    Args:
        graph: graph whose collection is read.

    Returns:
        Whatever ``graph.get_collection`` returns for
        ``tf.GraphKeys.TRAINABLE_VARIABLES``.
    """
    collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
    trainable = graph.get_collection(collection_key)
    return trainable