Exemplo n.º 1
0
def variables_to_save(graph: tf.Graph) -> typing.Sequence:
    """Collect the variables of ``graph`` that should be saved.

    Variables are read from the TRAINABLE_VARIABLES, GLOBAL_VARIABLES,
    MODEL_VARIABLES and MOVING_AVERAGE_VARIABLES collections, in that order.
    A variable appearing in several collections is kept only once, keyed by
    its ``name``. Its position in the result is that of its first occurrence;
    the object kept is the last one seen with that name.

    Args:
      graph: Graph whose collections are read.

    Returns:
      A list of name-de-duplicated variable objects.
    """
    # NOTE(review): the original code carried a bare "TODO" just before the
    # GLOBAL_VARIABLES collection; its intent is unknown -- revisit whether
    # that collection belongs here.
    collection_keys = (
        tf.GraphKeys.TRAINABLE_VARIABLES,
        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.GraphKeys.MODEL_VARIABLES,
        tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
    )
    by_name = {}
    for key in collection_keys:
        for var in graph.get_collection(key):
            by_name[var.name] = var
    return list(by_name.values())
Exemplo n.º 2
0
def get_nucleotide_names_from_collections(graph: tf.Graph,
                                          collection_list: list) -> list:
    """
    Collect node names (without output index) from several graph collections.

    Each collection entry may be a dict (its values are used), a list/tuple
    (its elements are used) or a single object. For every resulting object
    the part of ``.name`` before the first ``':'`` is recorded.

    Parameters
    ----------
    graph
        tensorflow graph; passed through ``_maybe_get_default_graph`` first
    collection_list
        names of the collections to search, visited in order

    Returns
    -------
    list_of_nucleotide_names
        node names found inside the collections of collection_list, in
        encounter order (duplicates are kept)

    """
    def _strip_output_index(node):
        # "scope/op:0" -> "scope/op"
        return node.name.split(':')[0]

    graph = _maybe_get_default_graph(graph)
    names = []
    for collection_name in collection_list:
        for entry in graph.get_collection(collection_name):
            if isinstance(entry, dict):
                members = entry.values()
            elif isinstance(entry, (list, tuple)):
                members = entry
            else:
                members = (entry,)
            names.extend(_strip_output_index(node) for node in members)
    return names
Exemplo n.º 3
0
def get_asset_annotations(graph: tf.Graph):
    """Obtains the asset annotations in the specified graph.

    Args:
      graph: A `tf.Graph` object.

    Returns:
      A dict that maps asset_keys to asset_filenames. Both the key and the
      filename are reduced with `os.path.basename` (stripping any scope or
      directory prefix). Note that if multiple entries for the same key
      exist, later ones will override earlier ones.

    Raises:
      ValueError: If the asset key and asset filename collections do not have
        the same length.
    """
    asset_key_collection = graph.get_collection(_ASSET_KEY_COLLECTION)
    asset_filename_collection = graph.get_collection(
        _ASSET_FILENAME_COLLECTION)
    # The two collections are parallel lists paired by position. Validate
    # explicitly instead of with `assert`, which is stripped under `python -O`
    # and would let `zip` silently truncate to the shorter collection.
    if len(asset_key_collection) != len(asset_filename_collection):
        raise ValueError(
            'Length of asset key and filename collections must match.')
    # Remove scope.
    return {
        os.path.basename(key): os.path.basename(filename)
        for key, filename in zip(asset_key_collection,
                                 asset_filename_collection)
    }
Exemplo n.º 4
0
def _make_collection_defs(
        tf_g: tf.Graph) -> Iterable[tf.MetaGraphDef.CollectionDefEntry]:
    """Serialize all the collections in a TensorFlow graph.

    **NOTE:** Currently this function only captures collections of variables.

    Collections whose name is not a string are skipped entirely (with a
    message on stdout). Within a string-named collection, each
    ``tf.Variable`` is serialized via ``to_proto()`` into the entry's
    ``bytes_list``. ``WhileContext`` / ``CondContext`` items are skipped with
    a message, and any other item type raises.

    Args:
      tf_g: TensorFlow graph from which to harvest collections.

    Returns:
      A list of `tf.MetaGraphDef.CollectionDefEntry` protobufs containing the
      serialized contents of the collections, one per string-named collection.

    Raises:
      NotImplementedError: If a collection contains an item that is neither a
        `tf.Variable` nor a WhileContext/CondContext.
    """
    # Control-flow context types are matched by class name so no private
    # TensorFlow module has to be imported.
    skipped_item_types = ("WhileContext", "CondContext")
    ret = []
    for collection_name in tf_g.collections:
        if not isinstance(collection_name, str):
            print("Skipping non-string collection name {}".format(
                collection_name))
            continue
        collection_items = tf_g.get_collection(collection_name)
        collection_proto = tf.MetaGraphDef.CollectionDefEntry()
        collection_proto.key = collection_name
        for item in collection_items:
            item_type_name = type(item).__name__
            if isinstance(item, tf.Variable):
                # Ask TensorFlow to generate the protobuf version of this
                # variable; it is stored as a binary-serialized message.
                var_proto = item.to_proto()
                collection_proto.value.bytes_list.value.append(
                    var_proto.SerializeToString())
            elif item_type_name in skipped_item_types:
                # TODO(frreiss): Should we serialize WhileContexts and
                # CondContexts?
                # NOTE: despite the message, only this item is skipped; the
                # collection entry itself is still appended below.
                print("Skipping collection {} -- is {}.".format(
                    collection_name, item_type_name))
            else:
                raise NotImplementedError(
                    "Can't serialize item '{}' in collection "
                    "'{}' because it is a "
                    "'{}'.".format(item, collection_name, item_type_name))

        ret.append(collection_proto)
    return ret
Exemplo n.º 5
0
def count_parameters(graph: tf.Graph):
    """Count the scalar parameters held by the graph's trainable variables.

    Args:
      graph: Graph whose ``TRAINABLE_VARIABLES`` collection is inspected.

    Returns:
      A ``(total_parameters, by_scope, by_name)`` tuple. ``total_parameters``
      is the sum of the element counts of all trainable variables.
      ``by_name`` is a ``Counter`` keyed by each variable's op name, and
      ``by_scope`` is a ``Counter`` keyed by the first ``/``-separated
      component of that op name.

    Raises:
      TypeError: If a variable's shape has a dimension of unknown size
        (its size is ``None`` and cannot be multiplied).
    """
    total_parameters = 0
    by_scope = Counter()
    by_name = Counter()

    for variable in graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        num_params = 1
        for dim in variable.get_shape():
            # TF1 `Dimension` objects expose their size via `.value`. When
            # shape iteration yields plain ints instead (TF2 shape behaviour),
            # use the value directly.
            num_params *= getattr(dim, 'value', dim)
        total_parameters += num_params

        op_name = variable.op.name
        by_name[op_name] += num_params
        by_scope[op_name.split('/')[0]] += num_params

    return total_parameters, by_scope, by_name
Exemplo n.º 6
0
def get_analyzers_fingerprint(
    graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]
) -> Mapping[str, AnalyzersFingerprint]:
  """Computes fingerprints for all analyzers in `graph`.

  Every tensor sink in the `analyzer_nodes.ALL_REPLACEMENTS` collection is
  fingerprinted. The fingerprint pairs the source keys feeding the sink's
  value node with the set of unique (non-None) paths reported by an
  `InitializableGraphAnalyzer` for those source tensors.

  Args:
    graph: a TF Graph.
    structured_inputs: a dict from keys to batches of placeholder graph tensors.

  Returns:
    A mapping from the sink tensor's name to its `AnalyzersFingerprint`.
  """
  sinks = graph.get_collection(analyzer_nodes.ALL_REPLACEMENTS)

  # Register every sink tensor with the analyzer. The boolean payload is
  # never read, so a constant False is used.
  ready_by_sink = {}
  for sink in sinks:
    ready_by_sink[tf_utils.hashable_tensor_or_op(sink.tensor)] = False
  analyzer = InitializableGraphAnalyzer(
      graph, structured_inputs, list(ready_by_sink.items()),
      describe_path_as_analyzer_cache_hash)

  fingerprints = {}
  for sink in sinks:
    # Collect the tensors that feed this analyzer's value node.
    visitor = SourcedTensorsVisitor()
    nodes.Traverser(visitor).visit_value_node(sink.future)
    sourced = visitor.sourced_tensors
    keys = _retrieve_source_keys(sourced, structured_inputs)

    # One path per input tensor; tensors without a unique path are dropped.
    unique_paths = {analyzer.get_unique_path(t) for t in sourced}
    unique_paths.discard(None)

    fingerprints[str(sink.tensor.name)] = AnalyzersFingerprint(
        keys, unique_paths)
  return fingerprints
Exemplo n.º 7
0
def get_train_ops(graph: tf.Graph):
    """Return the contents of ``graph``'s TRAINABLE_VARIABLES collection.

    NOTE(review): despite the function name, this returns the trainable
    *variables*, not training ops. TF1 keeps training ops under a separate
    collection key (presumably ``tf.GraphKeys.TRAIN_OP``) -- confirm which
    one callers actually expect before renaming or changing this.
    """
    trainable_variables = graph.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES)
    return trainable_variables