Ejemplo n.º 1
0
def summarize_graph(graph_def):
    """Summarize the inputs and likely outputs of a TensorFlow graph.

    The GraphDef is imported into a fresh graph. Every node of the imported
    graph is then inspected.

    Args:
      graph_def: The GraphDef to import and inspect.

    Returns:
      A dict with two keys:
        'inputs': maps each Placeholder node name to a dict with
          'type'  -- the placeholder's dtype name, and
          'shape' -- its shape as a compact string with spaces removed and
                     unknown dimensions written as -1, e.g. '(-1,3)'.
        'outputs': list of names of nodes that have no children (per
          `children`) and whose op type and last '/'-separated name
          component are both absent from `unlikely_output_types`.
    """
    placeholders = {}
    outputs = []
    graph = tf_v1.Graph()
    with graph.as_default():  # pylint: disable=not-context-manager
        tf_v1.import_graph_def(graph_def, name='')
    for node in graph.as_graph_def().node:  # pylint: disable=no-member
        if node.op == 'Placeholder':
            shape = tf_v1.TensorShape(node.attr['shape'].shape)
            # str(TensorShape) writes an unknown dimension as '?' under TF1
            # behavior but as 'None' under TF2 behavior; map both to -1 so
            # the summary does not depend on which behavior is active.
            shape_str = (str(shape).replace(' ', '')
                         .replace('?', '-1')
                         .replace('None', '-1'))
            placeholders[node.name] = {
                'type': tf_v1.DType(node.attr['dtype'].type).name,
                'shape': shape_str,
            }
        if len(children(node.name, graph)) == 0:
            # A node with no consumers is an output candidate unless its op
            # type or its last name-scope component is deny-listed.
            if (node.op not in unlikely_output_types
                    and node.name.split('/')[-1] not in unlikely_output_types):
                outputs.append(node.name)
    return {'inputs': placeholders, 'outputs': outputs}
Ejemplo n.º 2
0
def attr_value_to_python_type(
        attr_value,  # type: tf.AttrValue
        attr_name  # type: str
):
    # type: (...) -> Any
    """Inverse of python_type_to_attr_value().

    Converts whichever field of `attr_value` is in use into a Python object
    or built-in type.

    Args:
      attr_value: Protocol buffer version of a node's attribute value.
      attr_name: Name of the attribute. It is used only to build the error
        message when the value cannot be converted.

    Returns:
      A Python object or built-in type corresponding to the field in
      `attr_value` that is in use:
        s      -> str, via tf.compat.as_str
        i      -> attr_value.i
        f      -> attr_value.f
        b      -> attr_value.b
        type   -> tf.DType
        shape  -> tf.TensorShape
        tensor -> the result of tf.make_ndarray
        list   -> the raw `list` field, not yet converted
        func   -> the raw `func` field, not converted

    Raises:
      ValueError: If no field is set, or if the set field is not one of
        those above (e.g. "placeholder").
    """
    # All of these fields are members of AttrValue's `value` oneof, so at most
    # one is set. WhichOneof names it, or returns None when none is set.
    field = attr_value.WhichOneof("value")
    if field == "s":  # str
        # TODO(frreiss): Should we return the binary value here?
        return tf.compat.as_str(attr_value.s)
    if field == "i":  # int
        return attr_value.i
    if field == "f":  # float
        return attr_value.f
    if field == "b":  # bool
        return attr_value.b
    if field == "type":  # DType
        return tf.DType(attr_value.type)
    if field == "shape":  # TensorShape
        # Undocumented behavior of public API: tf.TensorShape constructor accepts
        # a TensorShapeProto.
        return tf.TensorShape(attr_value.shape)
    if field == "tensor":  # TensorProto
        return tf.make_ndarray(attr_value.tensor)
    if field == "list":  # list
        # TODO(frreiss): Handle AttrValues that are lists
        return attr_value.list
    if field == "func":  # func
        return attr_value.func
    # TODO(frreiss): Convert the "placeholder" fields of the union  here
    raise ValueError("Don't know how to convert AttrValue {} to "
                     "a Python object for attribute {}".format(
                         attr_value, attr_name))