def _InitOpDefLibrary():
    """Build an OpDefLibrary from the text-format OpList attached to this function.

    The ASCII OpList is read from the `op_list_ascii` attribute stored on
    `_InitOpDefLibrary` itself. The parsed ops are registered through
    `_op_def_registry.register_op_list` before being added to the library.

    Returns:
      An `_op_def_library.OpDefLibrary` holding every op from the list.
    """
    parsed_ops = _op_def_pb2.OpList()
    _text_format.Merge(_InitOpDefLibrary.op_list_ascii, parsed_ops)
    _op_def_registry.register_op_list(parsed_ops)

    library = _op_def_library.OpDefLibrary()
    library.add_op_list(parsed_ops)
    return library
def _InitOpDefLibrary(op_list_proto_bytes):
    """Build an OpDefLibrary from a binary-serialized OpList.

    Args:
      op_list_proto_bytes: Wire-format bytes of an `OpList` proto.

    Returns:
      An `_op_def_library.OpDefLibrary` holding every op from the list, after
      those ops were registered via `_op_def_registry.register_op_list`.
    """
    # FromString == construct a fresh OpList, then parse the bytes into it.
    parsed_ops = _op_def_pb2.OpList.FromString(op_list_proto_bytes)
    _op_def_registry.register_op_list(parsed_ops)

    library = _op_def_library.OpDefLibrary()
    library.add_op_list(parsed_ops)
    return library
def load_op_library(library_filename):
    """Load a TensorFlow plugin and expose its ops through a generated module.

    The filename is handed to a platform-specific dynamic-loading mechanism;
    how the library is located on disk is platform dependent. Ops and kernels
    registered in the library via the `REGISTER_*` macros become available in
    this process; ops whose names clash with an existing op are rejected.

    Args:
      library_filename: Relative or absolute filesystem path to the dynamic
        library file.

    Returns:
      A Python module containing the generated wrappers for the library's
      ops. If a module with the same wrapper-code hash is already present in
      `sys.modules`, that module is returned instead.

    Raises:
      The exception built by `errors_impl._make_specific_exception` for the
      non-zero status code reported by `TF_LoadLibrary`.
    """
    status = py_tf.TF_NewStatus()
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        code = py_tf.TF_GetCode(status)
        if code != 0:
            message = compat.as_text(py_tf.TF_Message(status))
            # pylint: disable=protected-access
            raise errors_impl._make_specific_exception(None, None, message, code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    serialized_ops = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(serialized_ops))
    wrapper_code = py_tf.GetPythonWrappers(serialized_ops)

    # Release C-side memory that is no longer needed once the op list and
    # wrapper code have been extracted.
    py_tf.TF_DeleteLibraryHandle(lib_handle)

    # The MD5 of the generated wrapper code doubles as a unique module name.
    module_name = hashlib.md5(wrapper_code).hexdigest()
    if module_name in sys.modules:
        return sys.modules[module_name]

    module = imp.new_module(module_name)
    exec(wrapper_code, module.__dict__)  # pylint: disable=exec-used
    # NOTE(review): lib_handle was already passed to TF_DeleteLibraryHandle
    # above, so this attribute may refer to a released handle — confirm before
    # making calls through it.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    return module
def load_op_library(library_filename):
    """Load a TensorFlow plugin and expose its ops through a generated module.

    The filename is handed to a platform-specific dynamic-loading mechanism.
    The library is expected to define the symbols "RegisterOps",
    "RegisterKernels" and "GetOpList". Successfully loaded modules are
    memoized per filename in `_OP_LIBRARY_MAP` (guarded by
    `_OP_LIBRARY_MAP_LOCK`).

    Args:
      library_filename: Relative or absolute filesystem path to the dynamic
        library file.

    Returns:
      A Python module containing the generated wrappers for the library's ops.
      If loading reports ALREADY_EXISTS with a "has already been loaded"
      message and this filename was memoized before, the memoized module is
      returned.

    Raises:
      The exception built by `errors._make_specific_exception` for any other
      non-zero status code reported by `TF_LoadLibrary`.
    """
    status = py_tf.TF_NewStatus()
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            with _OP_LIBRARY_MAP_LOCK:
                previously_loaded = (
                    error_code == error_codes_pb2.ALREADY_EXISTS
                    and 'has already been loaded' in error_msg
                    and library_filename in _OP_LIBRARY_MAP)
                if previously_loaded:
                    return _OP_LIBRARY_MAP[library_filename]
            # pylint: disable=protected-access
            raise errors._make_specific_exception(None, None, error_msg, error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    serialized_ops = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(serialized_ops))
    wrapper_code = py_tf.GetPythonWrappers(serialized_ops, len(serialized_ops))

    # The MD5 of the generated wrapper code doubles as a unique module name.
    module_name = hashlib.md5(wrapper_code).hexdigest()
    module = imp.new_module(module_name)
    exec(wrapper_code, module.__dict__)  # pylint: disable=exec-used
    # Keep the library handle around for calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module

    # Remember which module this filename produced.
    with _OP_LIBRARY_MAP_LOCK:
        _OP_LIBRARY_MAP[library_filename] = module
    return module
def get_ops():
    """Return the op definitions bundled in `mmdnn_ops.pbtxt` as plain dicts.

    The pbtxt file is looked up next to this module, parsed as a text-format
    `OpList`, and each `OpDef` is converted with `MessageToDict`.
    """
    ops_file = os.path.join(os.path.dirname(__file__), 'mmdnn_ops.pbtxt')
    op_list = op_def_pb2.OpList()
    with open(ops_file) as handle:
        Merge(handle.read(), op_list)
    return list(map(MessageToDict, op_list.op))
def register_ops_if_needed(graph_ops):
    """Register graph ops absent in op_def_registry, if present in c++ registry.

    Every op in `graph_ops` that the Python registry does not know about is
    looked up in the C++ registry (via `c_api.TF_GetAllOpList`). Ops found
    there are copied into the Python registry; ops found in neither registry
    are logged and then reported by the RuntimeError below.

    Args:
      graph_ops: set with graph op names to register.

    Raises:
      RuntimeError: if `graph_ops` contains ops that are not in either python
        or c++ registry.
    """
    missing_ops = graph_ops - set(op_def_registry.get_registered_ops().keys())
    if not missing_ops:
        return

    p_buffer = c_api.TF_GetAllOpList()
    try:
        cpp_op_list = op_def_pb2.OpList()
        cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
    finally:
        # The C buffer was previously never freed; the parsed proto no longer
        # needs it.
        c_api.TF_DeleteBuffer(p_buffer)
    cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

    missing_op_list = op_def_pb2.OpList()
    for missing_op in missing_ops:
        if missing_op not in cpp_registry_ops:
            tf.logging.info(
                "Op %s is missing from both the python and C++ registry.",
                missing_op)
        else:
            missing_op_list.op.extend([cpp_registry_ops[missing_op]])
            tf.logging.info(
                "Adding op %s from c++ registry to python registry.",
                missing_op)

    op_def_registry.register_op_list(missing_op_list)

    # Note: Only raise missing op ValueError after trying to load ops.
    # This allows the test to exercise all the calls into TensorFlow
    # without having to write a C + python test.
    if not missing_ops <= set(cpp_registry_ops.keys()):
        raise RuntimeError(
            "Graph ops missing from the python registry (%s) are also absent from "
            "the c++ registry." %
            missing_ops.difference(set(cpp_registry_ops.keys())))
def register_prelu_op():
    """Register a virtual PReLU OpDef.

    The OpDef carries only the name 'Prelu' (no inputs, outputs or attrs). It
    exists so that MetaGraph validity checks on TensorFlow 1.X and 2.0 can be
    bypassed.
    """
    prelu_op_def = op_def_pb2.OpDef(name='Prelu')
    op_def_registry.register_op_list(op_def_pb2.OpList(op=[prelu_op_def]))
def __init__(self):
    """Snapshot every op in the C++ registry together with its ApiDefMap.

    Sets `self._api_def_map` (built by `c_api.TF_NewApiDefMap`) and
    `self._op_per_name`, a dict from op name to `OpDef`.
    """
    all_ops = op_def_pb2.OpList()
    op_buffer = c_api.TF_GetAllOpList()
    try:
        all_ops.ParseFromString(c_api.TF_GetBuffer(op_buffer))
        self._api_def_map = c_api.TF_NewApiDefMap(op_buffer)
    finally:
        # Free the C buffer even if parsing or ApiDefMap creation fails.
        c_api.TF_DeleteBuffer(op_buffer)
    self._op_per_name = {op.name: op for op in all_ops.op}
def sync():
    """Synchronize the contents of the Python registry with C++.

    Copies every OpDef from the C++ registry into `_registered_ops` (with
    non-deprecated descriptions stripped), holding `_sync_lock` throughout.
    """
    with _sync_lock:
        p_buffer = c_api.TF_GetAllOpList()
        try:
            cpp_op_list = op_def_pb2.OpList()
            cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
        finally:
            # The C buffer was previously never freed; the parsed proto no
            # longer needs it.
            c_api.TF_DeleteBuffer(p_buffer)
        for op_def in cpp_op_list.op:
            # If an OpList is registered from a gen_*_ops.py, it does not have
            # any descriptions. Strip them here as well to satisfy validation
            # in register_op_list.
            _remove_non_deprecated_descriptions(op_def)
            _registered_ops[op_def.name] = op_def
def sync():
    """Copy every OpDef from the C++ registry into the Python op registry.

    Writes into the dict returned by `op_def_registry.get_registered_ops()`,
    after stripping non-deprecated descriptions from each OpDef.
    """
    p_buffer = c_api.TF_GetAllOpList()
    try:
        cpp_op_list = op_def_pb2.OpList()
        cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
    finally:
        # The C buffer was previously never freed; the parsed proto no longer
        # needs it.
        c_api.TF_DeleteBuffer(p_buffer)

    registered_ops = op_def_registry.get_registered_ops()
    for op_def in cpp_op_list.op:
        # If an OpList is registered from a gen_*_ops.py, it does not have any
        # descriptions. Strip them here as well to satisfy validation in
        # register_op_list.
        _remove_non_deprecated_descriptions(op_def)
        registered_ops[op_def.name] = op_def
def load_op_library(library_filename):
    """Load a TensorFlow plugin and expose its ops through a generated module.

    The filename is handed to a platform-specific dynamic-loading mechanism.
    Ops and kernels registered in the library via the `REGISTER_*` macros
    become available in this process; ops whose names clash with an existing
    op are rejected.

    Args:
      library_filename: Relative or absolute filesystem path to the dynamic
        library file.

    Returns:
      A Python module containing the generated wrappers for the library's
      ops, flagged with `_IS_TENSORFLOW_PLUGIN = True` so AutoGraph recognizes
      it. If a module with the same wrapper-code hash is already present in
      `sys.modules`, that module is returned instead.

    Raises:
      Whatever `py_tf.TF_LoadLibrary` raises when the library cannot be loaded.
    """
    lib_handle = py_tf.TF_LoadLibrary(library_filename)

    serialized_ops = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(serialized_ops))
    wrapper_code = py_tf.GetPythonWrappers(serialized_ops)

    # Release C-side memory that is no longer needed once the op list and
    # wrapper code have been extracted.
    py_tf.TF_DeleteLibraryHandle(lib_handle)

    # The MD5 of the generated wrapper code doubles as a unique module name.
    module_name = hashlib.md5(wrapper_code).hexdigest()
    if module_name in sys.modules:
        return sys.modules[module_name]

    module = imp.new_module(module_name)
    exec(wrapper_code, module.__dict__)  # pylint: disable=exec-used
    # NOTE(review): lib_handle was already passed to TF_DeleteLibraryHandle
    # above, so this attribute may refer to a released handle — confirm.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the ops defined in the library.
    module.OP_LIST = op_list
    # Marker that lets AutoGraph recognize this module as a plugin.
    module._IS_TENSORFLOW_PLUGIN = True  # pylint: disable=protected-access
    sys.modules[module_name] = module
    return module
def register_prelu_op():
    """Register a 'Prelu' OpDef in the Python op registry.

    The OpDef declares a single type attr `T` restricted to DT_FLOAT, so that
    a MetaGraph can be generated for graphs that use an otherwise
    unregistered Prelu op.
    """
    float_only = attr_value_pb2.AttrValue()
    float_only.list.type.append(types_pb2.DT_FLOAT)

    type_attr = op_def_pb2.OpDef.AttrDef(
        name='T', type='type', allowed_values=float_only)
    prelu_op_def = op_def_pb2.OpDef(name='Prelu', attr=[type_attr])

    op_def_registry.register_op_list(op_def_pb2.OpList(op=[prelu_op_def]))
def load_op_library(library_filename):
    """Load a TensorFlow plugin and expose its ops through a generated module.

    The filename is handed to a platform-specific dynamic-loading mechanism.
    The library is expected to define the symbols "RegisterOps",
    "RegisterKernels" and "GetOpList".

    Args:
      library_filename: Relative or absolute filesystem path to the dynamic
        library file.

    Returns:
      A Python module containing the generated wrappers for the library's ops.

    Raises:
      RuntimeError: when `TF_LoadLibrary` reports a non-zero status code; the
        message is the status message.
    """
    status = py_tf.TF_NewStatus()
    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        load_failed = py_tf.TF_GetCode(status) != 0
        if load_failed:
            raise RuntimeError(compat.as_text(py_tf.TF_Message(status)))
    finally:
        py_tf.TF_DeleteStatus(status)

    serialized_ops = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(serialized_ops))
    wrapper_code = py_tf.GetPythonWrappers(serialized_ops, len(serialized_ops))

    # The MD5 of the generated wrapper code doubles as a unique module name.
    module_name = hashlib.md5(wrapper_code).hexdigest()
    module = imp.new_module(module_name)
    exec(wrapper_code, module.__dict__)  # pylint: disable=exec-used
    # Keep the library handle around for calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    return module
def _create_op_def_library(op_protos):
    """Build an OpDefLibrary containing every OpDef in `op_protos`.

    Each op must already be registered. The ops are re-registered together,
    which fails if an op's interface ("op schema") differs from the
    previously registered one.

    Previously only a single op (the loop variable) was wrapped into the
    returned library, so the remaining ops were silently dropped. An empty
    input also broke: it either raised NameError or returned None.

    Args:
      op_protos: Iterable of `OpDef` protos.

    Returns:
      An `_op_def_library.OpDefLibrary` holding all of `op_protos` (empty if
      `op_protos` is empty).

    Raises:
      LookupError: if any op in `op_protos` is not registered.
    """
    # Materialize once: the ops are iterated twice (validate, then extend).
    op_protos = list(op_protos)

    registered_ops = _registry.get_registered_ops()
    for op_proto in op_protos:
        if op_proto.name not in registered_ops:
            raise LookupError("Op with name {0} not registered".format(
                op_proto.name))

    ops_proto = _op_def_pb2.OpList()
    ops_proto.op.extend(op_protos)
    # Fails if the interfaces ("op schemas") don't match between the
    # previously registered op and this one.
    _registry.register_op_list(ops_proto)

    op_def_lib = _op_def_library.OpDefLibrary()
    op_def_lib.add_op_list(ops_proto)
    return op_def_lib
def register_prelu_op():
    """Register a virtual PReLU OpDef.

    The OpDef is named 'Prelu' and declares one type attr `T` whose allowed
    values are limited to DT_FLOAT. It exists so that MetaGraph validity
    checks on TensorFlow 1.X and 2.0 can be bypassed.
    """
    allowed_types = attr_value_pb2.AttrValue()
    allowed_types.list.type.extend([types_pb2.DT_FLOAT])

    t_attr = op_def_pb2.OpDef.AttrDef(name='T', type='type')
    t_attr.allowed_values.CopyFrom(allowed_types)

    prelu_op_def = op_def_pb2.OpDef(name='Prelu')
    prelu_op_def.attr.extend([t_attr])

    op_def_registry.register_op_list(op_def_pb2.OpList(op=[prelu_op_def]))
def testDefaultAttrsRemoved(self):
    """A producer-only attr holding its default value is dropped on import."""
    producer_op_list = op_def_pb2.OpList()
    text_format.Merge(""" op { name: 'OpWithFutureDefaultAttr' attr { name: 'default_int' type: 'int' default_value { i: 456 } } } """, producer_op_list)

    # Attr only in producer_op_list with default value gets removed.
    with ops.Graph().as_default():
        graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'OpWithFutureDefaultAttr' attr { key: 'default_int' value { i: 456 } } } """)
        imported = importer.import_graph_def(
            graph_def,
            return_elements=["A"],
            producer_op_list=producer_op_list)
        expected_error = "Operation 'import/A' has no attr named 'default_int'."
        with self.assertRaisesRegexp(ValueError, expected_error):
            imported[0].get_attr("default_int")
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

    This computes the `stripped_op_list` field of `MetaGraphDef` and similar
    protos. A consumer can pass the result to the C++ function
    `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` of the registered ops used by the graph, sorted by name.

    Raises:
      ValueError: If an unregistered op is used.
    """
    # Python counterpart of StrippedOpListForGraph in C++. The Python op
    # registry can differ from the C++ one, so the logic is duplicated here.
    # TODO(irving): Support taking graphs directly.
    used_ops = ops_used_by_graph_def(graph_def)
    registered_ops = op_def_registry.get_registered_ops()

    # Internal ops used by functions are not registered, so they are allowed
    # explicitly.
    # TODO(irving): Do something better here.
    op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op in used_ops:
        if op in registered_ops or op in op_whitelist:
            continue
        raise ValueError(
            "Op %s is used by the graph, but is not registered" % op)

    stripped = [registered_ops[op] for op in sorted(used_ops)
                if op in registered_ops]
    return op_def_pb2.OpList(op=stripped)
def testDefaultAttrsRemoved(self):
    """Producer-only attrs: dropped when at their default, kept otherwise (non-C-API only)."""
    producer_op_list = op_def_pb2.OpList()
    text_format.Merge(""" op { name: 'OpWithFutureDefaultAttr' attr { name: 'default_int' type: 'int' default_value { i: 456 } } } """, producer_op_list)

    # Attr only in producer_op_list with default value gets removed.
    with ops.Graph().as_default():
        default_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'OpWithFutureDefaultAttr' attr { key: 'default_int' value { i: 456 } } } """)
        imported = importer.import_graph_def(
            default_graph_def,
            return_elements=["A"],
            producer_op_list=producer_op_list)
        # The C API and the legacy Python importer word the error differently.
        error_msg = ("Operation 'import/A' has no attr named 'default_int'."
                     if ops._USE_C_API else "No attr named 'default_int'")
        with self.assertRaisesRegexp(ValueError, error_msg):
            imported[0].get_attr("default_int")

    # Unknown attrs cannot be imported using C API. This test will eventually be
    # deleted.
    if ops._USE_C_API:
        return

    # Attr only in producer_op_list with non-default value is preserved.
    with ops.Graph().as_default():
        custom_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'OpWithFutureDefaultAttr' attr { key: 'default_int' value { i: 987 } } } """)
        imported = importer.import_graph_def(
            custom_graph_def,
            return_elements=["A"],
            producer_op_list=producer_op_list)
        self.assertEqual(987, imported[0].get_attr("default_int"))
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

    This computes the `stripped_op_list` field of `MetaGraphDef` and similar
    protos. A consumer can pass the result to the C++ function
    `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

    Unlike the C++ StrippedOpListForGraph, unregistered ops are silently
    skipped rather than rejected; this supports Prelu fusion in tfjs.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` of the registered ops used by the graph, sorted by name.
    """
    looked_up = (op_def_registry.get(name)
                 for name in sorted(ops_used_by_graph_def(graph_def)))
    return op_def_pb2.OpList(
        op=[op_def for op_def in looked_up if op_def is not None])
import tensorflow as tf
from tensorflow.core.framework import op_def_pb2
from google.protobuf import text_format
from termcolor import colored
import re
import textwrap

# All OpDefs parsed from 'ops.pbtxt' (resolved against the current working
# directory).
ops = op_def_pb2.OpList()
# Read inside `with` so the file is closed deterministically; the previous
# `open(...).read()` left closing the handle to the garbage collector.
with open('ops.pbtxt') as _ops_pbtxt:
    text_format.Merge(_ops_pbtxt.read(), ops)

# Upper-case data type names; index 0 is the empty string.
# NOTE(review): this presumably mirrors the DataType enum by index. If so,
# confirm the 'COMPLEX' entry at index 9 is intended — it shifts every later
# name ('INT64' onward) by one position.
data_type_name = [
    '', 'FLOAT', 'DOUBLE', 'INT32', 'UINT8', 'INT16', 'INT8', 'STRING',
    'COMPLEX64', 'COMPLEX', 'INT64', 'BOOL', 'QINT8', 'QUINT8', 'QINT32',
    'BFLOAT16', 'QINT16', 'QUINT16', 'UINT16', 'COMPLEX128', 'HALF',
    'RESOURCE', 'VARIANT', 'UINT32', 'UINT64'
]

# Reserved words of Lua; identifiers matching one of these get renamed.
lua_keywords = [
    'and', 'break', 'do', 'else', 'elseif', 'end', 'false', 'for', 'function',
    'if', 'in', 'local', 'nil', 'not', 'or', 'repeat', 'return', 'then',
    'true', 'until', 'while'
]


def proc_var_name(name):
    """Return `name`, with '_' appended if it is a Lua keyword."""
    if name in lua_keywords:
        name += '_'
    return name


pattern_arg = ' Args:'
pattern_ret = ' Returns:'
def metadata():
    """Generate `../source/tf-metadata.json` from TensorFlow's op definitions.

    Paths are resolved relative to this script's directory:
      * every `*.pbtxt` ApiDef file under
        `../third_party/source/tensorflow/tensorflow/core/api_def/base_api`
        supplies summaries and descriptions;
      * `../third_party/source/tensorflow/tensorflow/core/ops/ops.pbtxt`
        supplies the op signatures.
    One JSON schema object (name, category, summary, description, attributes,
    inputs, outputs) is written per op.
    """
    categories = {
        'Assign': 'Control',
        'AvgPool': 'Pool',
        'BatchNormWithGlobalNormalization': 'Normalization',
        'BiasAdd': 'Layer',
        'Concat': 'Tensor',
        'ConcatV2': 'Tensor',
        'Const': 'Constant',
        'Conv2D': 'Layer',
        'DepthwiseConv2dNative': 'Layer',
        'Dequantize': 'Tensor',
        'Elu': 'Activation',
        'FusedBatchNorm': 'Normalization',
        'FusedBatchNormV2': 'Normalization',
        'FusedBatchNormV3': 'Normalization',
        'Gather': 'Transform',
        'Identity': 'Control',
        'LeakyRelu': 'Activation',
        'LRN': 'Normalization',
        'LSTMBlockCell': 'Layer',
        'MaxPool': 'Pool',
        'MaxPoolV2': 'Pool',
        'MaxPoolWithArgmax': 'Pool',
        'Pad': 'Tensor',
        'Relu': 'Activation',
        'Relu6': 'Activation',
        'Reshape': 'Shape',
        'Sigmoid': 'Activation',
        'Slice': 'Tensor',
        'Softmax': 'Activation',
        'Split': 'Tensor',
        'Squeeze': 'Shape',
        'StridedSlice': 'Tensor',
        'swish_f32': 'Activation',
        'Variable': 'Control',
        'VariableV2': 'Control',
    }

    def find_multiline(line, colon):
        """Return the heredoc terminator if `line` is `field: <<END`, else None."""
        if colon == -1:
            return None
        rest = line[colon + 1:].lstrip(' ')
        if rest.startswith('<<'):
            return rest[2:]
        return None

    escapes = {
        '\n': '\\n',
        '\r': '\\r',
        '\t': '\\t',
        '"': '\\"',
        "'": "\\'",
        '\\': '\\\\',
    }

    def str_escape(text):
        """Backslash-escape newlines, CRs, tabs, quotes and backslashes."""
        return ''.join(escapes.get(c, c) for c in text)

    def pbtxt_from_multiline(multiline_pbtxt):
        """Rewrite `field: <<END ... END` heredocs as quoted text-format strings."""
        pbtxt = ''
        while len(multiline_pbtxt) > 0:
            index = multiline_pbtxt.find('\n')
            if index == -1:
                pbtxt = pbtxt + multiline_pbtxt
                multiline_pbtxt = ''
                break
            line = multiline_pbtxt[0:index]
            multiline_pbtxt = multiline_pbtxt[index + 1:]
            colon = line.find(':')
            end = find_multiline(line, colon)
            if end is None:
                pbtxt = pbtxt + line + '\n'
                continue
            pbtxt = pbtxt + line[0:colon + 1]
            unescaped = ''
            newline = False
            line = ''
            while len(multiline_pbtxt) > 0:
                index = multiline_pbtxt.find('\n')
                if index == -1:
                    # Final line without a trailing newline: consume all of it.
                    # Slicing with index == -1 used to drop its last character
                    # and leave the buffer unconsumed (looping forever, or
                    # re-emitting the tail in the outer loop).
                    line = multiline_pbtxt
                    multiline_pbtxt = ''
                else:
                    line = multiline_pbtxt[0:index]
                    multiline_pbtxt = multiline_pbtxt[index + 1:]
                if line.startswith(end):
                    line = line[len(end):]
                    break
                if newline:
                    unescaped = unescaped + '\n'
                newline = True
                unescaped = unescaped + line
                line = ''
            pbtxt = pbtxt + '"' + str_escape(unescaped) + '"' + line + '\n'
        return pbtxt

    def read_api_def_map(folder):
        """Map graph_op_name -> ApiDef for every `*.pbtxt` file in `folder`."""
        api_def_map = {}
        for filename in sorted(os.listdir(folder)):
            if not filename.endswith('.pbtxt'):
                continue
            api_defs = api_def_pb2.ApiDefs()
            # Decode explicitly (TensorFlow sources are presumably UTF-8) so the
            # result does not depend on the platform's default locale encoding.
            with io.open(folder + '/' + filename, encoding='utf-8') as handle:
                pbtxt = pbtxt_from_multiline(handle.read())
            text_format.Merge(pbtxt, api_defs)
            for api_def in api_defs.op:
                api_def_map[api_def.graph_op_name] = api_def
        return api_def_map

    def convert_type(type):
        return {'type': 'type', 'value': type}

    def convert_tensor(tensor):
        return {'type': 'tensor', 'value': '?'}

    def convert_shape(shape):
        return {'type': 'shape', 'value': '?'}

    def convert_number(number):
        # JSON cannot hold infinities, so they are emitted as strings.
        if number == float('inf'):
            return 'NaN'
        if number == float('-inf'):
            return '-NaN'
        return number

    attr_type_table = {
        'type': 'type',
        'list(type)': 'type[]',
        'bool': 'boolean',
        'int': 'int64',
        'list(int)': 'int64[]',
        'float': 'float32',
        'list(float)': 'float32[]',
        'string': 'string',
        'list(string)': 'string[]',
        'shape': 'shape',
        'list(shape)': 'shape[]',
        'tensor': 'tensor',
        'func': 'function',
        'list(func)': 'function[]'
    }

    def convert_attr_type(type):
        """Translate an OpDef attr type to its JSON name; unknown types are printed."""
        if type in attr_type_table:
            return attr_type_table[type]
        print(type)
        return type

    def convert_attr_value(attr_value):
        """Convert an AttrValue proto into JSON-friendly Python values."""
        if attr_value.HasField('list'):
            values = []
            attr_value_list = attr_value.list
            for s in attr_value_list.s:
                values.append(s.decode('utf8'))
            for i in attr_value_list.i:
                values.append(i)
            for f in attr_value_list.f:
                values.append(convert_number(f))
            for data_type in attr_value_list.type:
                values.append(convert_type(data_type))
            if len(values) == 0:
                # Any other non-empty list field is unsupported.
                for _, value in attr_value_list.ListFields():
                    if len(value) > 0:
                        raise Exception()
            return values
        if attr_value.HasField('s'):
            return attr_value.s.decode('utf8')
        if attr_value.HasField('i'):
            return attr_value.i
        if attr_value.HasField('f'):
            return convert_number(attr_value.f)
        if attr_value.HasField('b'):
            return attr_value.b
        if attr_value.HasField('type'):
            return convert_type(attr_value.type)
        if attr_value.HasField('tensor'):
            return convert_tensor(attr_value.tensor)
        if attr_value.HasField('shape'):
            return convert_shape(attr_value.shape)
        raise Exception()

    _TYPE_TO_STRING = {
        types_pb2.DataType.DT_HALF: "float16",
        types_pb2.DataType.DT_FLOAT: "float32",
        types_pb2.DataType.DT_DOUBLE: "float64",
        types_pb2.DataType.DT_INT32: "int32",
        types_pb2.DataType.DT_UINT8: "uint8",
        types_pb2.DataType.DT_UINT16: "uint16",
        types_pb2.DataType.DT_UINT32: "uint32",
        types_pb2.DataType.DT_UINT64: "uint64",
        types_pb2.DataType.DT_INT16: "int16",
        types_pb2.DataType.DT_INT8: "int8",
        types_pb2.DataType.DT_STRING: "string",
        types_pb2.DataType.DT_COMPLEX64: "complex64",
        types_pb2.DataType.DT_COMPLEX128: "complex128",
        types_pb2.DataType.DT_INT64: "int64",
        types_pb2.DataType.DT_BOOL: "bool",
        types_pb2.DataType.DT_QINT8: "qint8",
        types_pb2.DataType.DT_QUINT8: "quint8",
        types_pb2.DataType.DT_QINT16: "qint16",
        types_pb2.DataType.DT_QUINT16: "quint16",
        types_pb2.DataType.DT_QINT32: "qint32",
        types_pb2.DataType.DT_BFLOAT16: "bfloat16",
        types_pb2.DataType.DT_RESOURCE: "resource",
        types_pb2.DataType.DT_VARIANT: "variant",
        types_pb2.DataType.DT_HALF_REF: "float16_ref",
        types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
        types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
        types_pb2.DataType.DT_INT32_REF: "int32_ref",
        types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
        types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
        types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
        types_pb2.DataType.DT_INT16_REF: "int16_ref",
        types_pb2.DataType.DT_INT8_REF: "int8_ref",
        types_pb2.DataType.DT_STRING_REF: "string_ref",
        types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
        types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
        types_pb2.DataType.DT_INT64_REF: "int64_ref",
        types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
        types_pb2.DataType.DT_BOOL_REF: "bool_ref",
        types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
        types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
        types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
        types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
        types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
        types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
        types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
        types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
    }

    def format_data_type(data_type):
        if data_type in _TYPE_TO_STRING:
            return _TYPE_TO_STRING[data_type]
        raise Exception()

    def format_attribute_value(value):
        """Render one allowed attr value for the generated description text."""
        if (isinstance(value, dict) and 'type' in value and 'value' in value
                and value['type'] == 'type'):
            return format_data_type(value['value'])
        if isinstance(value, str):
            return value
        # `==` (not `is`) kept on purpose: integer 1/0 also map to true/false.
        if value == True:  # noqa: E712
            return 'true'
        if value == False:  # noqa: E712
            return 'false'
        raise Exception()

    tensorflow_repo_dir = os.path.join(
        os.path.dirname(__file__), '../third_party/source/tensorflow')
    api_def_map = read_api_def_map(
        os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
    input_file = os.path.join(tensorflow_repo_dir,
                              'tensorflow/core/ops/ops.pbtxt')
    ops_list = op_def_pb2.OpList()
    with io.open(input_file, encoding='utf-8') as input_handle:
        text_format.Merge(input_handle.read(), ops_list)

    json_root = []
    for op in ops_list.op:
        json_schema = {}
        json_schema['name'] = op.name
        if op.name in categories:
            json_schema['category'] = categories[op.name]

        api_def = api_def_map.get(op.name, api_def_pb2.ApiDef())
        api_def_attr_map = {attr.name: attr for attr in api_def.attr}
        api_def_in_arg_map = {arg.name: arg for arg in api_def.in_arg}
        api_def_out_arg_map = {arg.name: arg for arg in api_def.out_arg}
        if api_def.summary:
            json_schema['summary'] = api_def.summary
        if api_def.description:
            json_schema['description'] = api_def.description

        for attr in op.attr:
            if 'attributes' not in json_schema:
                json_schema['attributes'] = []
            json_attribute = {}
            json_attribute['name'] = attr.name
            attr_type = convert_attr_type(attr.type)
            # Emit 'type' only when known. The former `del json_attribute['type']`
            # branch raised KeyError because the key had never been set.
            if attr_type:
                json_attribute['type'] = attr_type
            if attr.name in api_def_attr_map:
                api_def_attr = api_def_attr_map[attr.name]
                if api_def_attr.description:
                    json_attribute['description'] = api_def_attr.description
            if attr.has_minimum:
                json_attribute['minimum'] = attr.minimum
            if attr.HasField('allowed_values'):
                allowed_values = convert_attr_value(attr.allowed_values)
                description = (json_attribute['description'] + ' '
                               if 'description' in json_attribute else '')
                description = description + 'Must be one of the following: ' + ', '.join(
                    '`' + format_attribute_value(x) + '`'
                    for x in allowed_values) + '.'
                json_attribute['description'] = description
            if attr.HasField('default_value'):
                json_attribute['default'] = convert_attr_value(attr.default_value)
            json_schema['attributes'].append(json_attribute)

        for input_arg in op.input_arg:
            if 'inputs' not in json_schema:
                json_schema['inputs'] = []
            json_input = {}
            json_input['name'] = input_arg.name
            if input_arg.name in api_def_in_arg_map:
                api_def_in_arg = api_def_in_arg_map[input_arg.name]
                if api_def_in_arg.description:
                    json_input['description'] = api_def_in_arg.description
            if input_arg.number_attr:
                json_input['numberAttr'] = input_arg.number_attr
            if input_arg.type:
                json_input['type'] = input_arg.type
            if input_arg.type_attr:
                json_input['typeAttr'] = input_arg.type_attr
            if input_arg.type_list_attr:
                json_input['typeListAttr'] = input_arg.type_list_attr
            if input_arg.is_ref:
                json_input['isRef'] = True
            json_schema['inputs'].append(json_input)

        for output_arg in op.output_arg:
            if 'outputs' not in json_schema:
                json_schema['outputs'] = []
            json_output = {}
            json_output['name'] = output_arg.name
            if output_arg.name in api_def_out_arg_map:
                api_def_out_arg = api_def_out_arg_map[output_arg.name]
                if api_def_out_arg.description:
                    json_output['description'] = api_def_out_arg.description
            if output_arg.number_attr:
                json_output['numberAttr'] = output_arg.number_attr
            if output_arg.type:
                json_output['type'] = output_arg.type
            elif output_arg.type_attr:
                json_output['typeAttr'] = output_arg.type_attr
            elif output_arg.type_list_attr:
                json_output['typeListAttr'] = output_arg.type_list_attr
            if output_arg.is_ref:
                json_output['isRef'] = True
            json_schema['outputs'].append(json_output)

        json_root.append(json_schema)

    json_file = os.path.join(os.path.dirname(__file__),
                             '../source/tf-metadata.json')
    with io.open(json_file, 'w', newline='') as fout:
        json_data = json.dumps(json_root, sort_keys=False, indent=2)
        # Strip trailing whitespace from each line and force '\n' endings.
        for line in json_data.splitlines():
            fout.write(line.rstrip())
            fout.write('\n')
# Register `_UnknownShape` as the shape function for each op named below.
# NOTE(review): these look like test-only ops (e.g. "Oi", "Otl",
# "OpWithFutureDefaultAttr"); confirm.
ops.RegisterShape("Iif")(_UnknownShape)
ops.RegisterShape("Iii")(_UnknownShape)
ops.RegisterShape("In")(_UnknownShape)
ops.RegisterShape("Iri")(_UnknownShape)
ops.RegisterShape("None")(_UnknownShape)
ops.RegisterShape("Of")(_UnknownShape)
ops.RegisterShape("Oi")(_UnknownShape)
ops.RegisterShape("Oif")(_UnknownShape)
ops.RegisterShape("Oii")(_UnknownShape)
ops.RegisterShape("OpWithDefaultAttr")(_UnknownShape)
ops.RegisterShape("OpWithFutureDefaultAttr")(_UnknownShape)
ops.RegisterShape("Or")(_UnknownShape)
ops.RegisterShape("Otl")(_UnknownShape)
ops.RegisterShape("Unary")(_UnknownShape)

# `_op_list` is filled by merging a text-format OpList literal that defines
# OpDefs for these ops (e.g. 'None', 'Oi', 'Or', 'Of'). The literal continues
# beyond the visible part of this file.
_op_list = op_def_pb2.OpList()
text_format.Merge( """ op { name: 'None' } op { name: 'Oi' output_arg { name: 'a' type: DT_INT32 } } op { name: 'Or' output_arg { name: 'a' type: DT_INT32 is_ref: true } } op { name: 'Of'
def _find_multiline(line, colon):
    """Return the terminator of a ``key: <<TERM`` heredoc opener, or None.

    `colon` is the index of the first ':' in `line` (-1 when there is none).
    Only spaces between the colon and ``<<`` are skipped.
    """
    if colon == -1:
        return None
    value = line[colon + 1:].lstrip(' ')
    if value.startswith('<<'):
        return value[2:]
    return None


# Characters that must be backslash-escaped inside a quoted text-format string.
_PBTXT_ESCAPES = {
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    '"': '\\"',
    '\'': '\\\'',
    '\\': '\\\\',
}


def _str_escape(text):
    """Escape `text` so it can be embedded in a quoted protobuf text-format string."""
    return ''.join(_PBTXT_ESCAPES.get(char, char) for char in text)


def _pbtxt_from_multiline(multiline_pbtxt):
    """Convert api_def heredoc syntax into plain protobuf text format.

    A block of the form::

        description: <<END
        first line
        second line
        END

    is collapsed into ``description:"first line\\nsecond line"`` followed by
    any text that trailed the terminator on its line. All other lines are
    copied through unchanged. A final line that has no trailing newline and
    is not part of a heredoc is copied verbatim without adding a newline.
    A heredoc that reaches end of input without its terminator is closed
    there (previously this looped forever when the input did not end with a
    newline).
    """
    lines = multiline_pbtxt.split('\n')
    # split() yields one extra empty element when the text ends with '\n';
    # drop it so every remaining element is a real line.
    last_terminated = lines[-1] == ''
    if last_terminated:
        lines.pop()
    count = len(lines)
    output = []
    index = 0
    while index < count:
        line = lines[index]
        index += 1
        if index == count and not last_terminated:
            output.append(line)
            break
        colon = line.find(':')
        end = _find_multiline(line, colon)
        if end is None:
            output.append(line + '\n')
            continue
        output.append(line[:colon + 1])
        body = []
        rest = ''
        while index < count:
            candidate = lines[index]
            index += 1
            if candidate.startswith(end):
                rest = candidate[len(end):]
                break
            body.append(candidate)
        output.append('"' + _str_escape('\n'.join(body)) + '"' + rest + '\n')
    return ''.join(output)


def _read_api_def_map(folder):
    """Load every api_def file in `folder` into a dict keyed by graph op name.

    Files are processed in sorted filename order, so a later file overrides an
    earlier one that defines the same op.
    """
    api_def_map = {}
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        # Decode explicitly (UTF-8 assumed for the TensorFlow sources) rather
        # than relying on the platform's locale encoding.
        with io.open(path, 'r', encoding='utf-8') as handle:
            pbtxt = _pbtxt_from_multiline(handle.read())
        api_defs = api_def_pb2.ApiDefs()
        text_format.Merge(pbtxt, api_defs)
        for api_def in api_defs.op:
            api_def_map[api_def.graph_op_name] = api_def
    return api_def_map


def _convert_type(dtype):
    """Wrap a DataType enum value as a metadata ``type`` value."""
    return {'type': 'type', 'value': dtype}


def _convert_tensor(tensor):
    """Return a placeholder for a tensor-valued attr (contents are not exported)."""
    return {'type': 'tensor', 'value': '?'}


def _convert_shape(shape):
    """Return a placeholder for a shape-valued attr (dimensions are not exported)."""
    return {'type': 'shape', 'value': '?'}


def _convert_number(number):
    """Return `number`, replacing +inf / -inf with the strings 'NaN' / '-NaN'.

    NOTE(review): infinities are emitted as 'NaN'/'-NaN', presumably the
    encoding the metadata consumer expects; confirm before changing.
    """
    if number == float('inf'):
        return 'NaN'
    if number == float('-inf'):
        return '-NaN'
    return number


# OpDef attr type strings -> metadata type names.
_ATTR_TYPE_TABLE = {
    'type': 'type',
    'list(type)': 'type[]',
    'bool': 'boolean',
    'int': 'int64',
    'list(int)': 'int64[]',
    'float': 'float32',
    'list(float)': 'float32[]',
    'string': 'string',
    'list(string)': 'string[]',
    'shape': 'shape',
    'list(shape)': 'shape[]',
    'tensor': 'tensor',
    'func': 'function',
    'list(func)': 'function[]',
}


def _convert_attr_type(attr_type):
    """Map an OpDef attr type string to its metadata name.

    Unknown types are reported on stderr and returned unchanged.
    """
    if attr_type in _ATTR_TYPE_TABLE:
        return _ATTR_TYPE_TABLE[attr_type]
    sys.stderr.write("Unknown attribute type '" + attr_type + "'.\n")
    return attr_type


def _convert_attr_value(attr_value):
    """Convert an AttrValue message into a JSON-serializable value.

    List values are flattened in field order s, i, f, type.

    Raises:
      Exception: if the value (or a non-empty list) uses a field that has no
        conversion here.
    """
    if attr_value.HasField('list'):
        attr_list = attr_value.list
        values = []
        values.extend(s.decode('utf8') for s in attr_list.s)
        values.extend(attr_list.i)
        values.extend(_convert_number(f) for f in attr_list.f)
        values.extend(_convert_type(dtype) for dtype in attr_list.type)
        if not values:
            # Fail loudly on list kinds we do not convert (e.g. shape, func)
            # instead of silently exporting an empty list.
            for field, value in attr_list.ListFields():
                if len(value) > 0:
                    raise Exception(
                        "Unsupported list attribute value field '" + field.name + "'.")
        return values
    if attr_value.HasField('s'):
        return attr_value.s.decode('utf8')
    if attr_value.HasField('i'):
        return attr_value.i
    if attr_value.HasField('f'):
        return _convert_number(attr_value.f)
    if attr_value.HasField('b'):
        return attr_value.b
    if attr_value.HasField('type'):
        return _convert_type(attr_value.type)
    if attr_value.HasField('tensor'):
        return _convert_tensor(attr_value.tensor)
    if attr_value.HasField('shape'):
        return _convert_shape(attr_value.shape)
    raise Exception('Unsupported attribute value: ' + str(attr_value))


def _convert_attr(attr, api_def_attr_map):
    """Build the JSON description of one OpDef attr."""
    json_attribute = {'name': attr.name}
    attr_type = _convert_attr_type(attr.type)
    if attr_type:
        json_attribute['type'] = attr_type
    api_def_attr = api_def_attr_map.get(attr.name)
    if api_def_attr is not None and api_def_attr.description:
        json_attribute['description'] = api_def_attr.description
    if attr.has_minimum:
        json_attribute['minimum'] = attr.minimum
    if attr.HasField('allowed_values'):
        json_attribute['allowedValues'] = _convert_attr_value(attr.allowed_values)
    if attr.HasField('default_value'):
        json_attribute['default'] = _convert_attr_value(attr.default_value)
    return json_attribute


def _convert_arg(arg, api_def_arg_map):
    """Build the JSON description of one OpDef input or output argument."""
    json_arg = {'name': arg.name}
    api_def_arg = api_def_arg_map.get(arg.name)
    if api_def_arg is not None and api_def_arg.description:
        json_arg['description'] = api_def_arg.description
    if arg.number_attr:
        json_arg['numberAttr'] = arg.number_attr
    if arg.type:
        json_arg['type'] = arg.type
    if arg.type_attr:
        json_arg['typeAttr'] = arg.type_attr
    if arg.type_list_attr:
        json_arg['typeListAttr'] = arg.type_list_attr
    if arg.is_ref:
        json_arg['isRef'] = True
    return json_arg


def _convert_op(op, api_def, category):
    """Build the JSON schema for one OpDef, merging docs from its ApiDef.

    `category` may be None. The 'attributes', 'inputs' and 'outputs' keys are
    present only when non-empty.
    """
    json_schema = {}
    if category is not None:
        json_schema['category'] = category
    if api_def.summary:
        json_schema['summary'] = api_def.summary
    if api_def.description:
        json_schema['description'] = api_def.description

    api_def_attr_map = {attr.name: attr for attr in api_def.attr}
    api_def_in_arg_map = {arg.name: arg for arg in api_def.in_arg}
    api_def_out_arg_map = {arg.name: arg for arg in api_def.out_arg}

    attributes = [_convert_attr(attr, api_def_attr_map) for attr in op.attr]
    if attributes:
        json_schema['attributes'] = attributes
    inputs = [_convert_arg(arg, api_def_in_arg_map) for arg in op.input_arg]
    if inputs:
        json_schema['inputs'] = inputs
    outputs = [_convert_arg(arg, api_def_out_arg_map) for arg in op.output_arg]
    if outputs:
        json_schema['outputs'] = outputs
    return json_schema


def metadata():
    """Write ``../src/tf-metadata.json`` (relative to this script).

    Reads TensorFlow's ``ops.pbtxt`` and ``api_def/base_api`` from
    ``../third_party/src/tensorflow``. It then emits one ``{'name', 'schema'}``
    entry per op as sorted, 2-space-indented JSON with '\\n' line endings.
    """
    categories = {
        'Const': 'Constant',
        'Conv2D': 'Layer',
        'BiasAdd': 'Layer',
        'DepthwiseConv2dNative': 'Layer',
        'Relu': 'Activation',
        'Relu6': 'Activation',
        'Elu': 'Activation',
        'Softmax': 'Activation',
        'Sigmoid': 'Activation',
        'LRN': 'Normalization',
        'MaxPool': 'Pool',
        'MaxPoolV2': 'Pool',
        'AvgPool': 'Pool',
        'Reshape': 'Shape',
        'Squeeze': 'Shape',
        'ConcatV2': 'Tensor',
        'Split': 'Tensor',
        'Dequantize': 'Tensor',
        'Identity': 'Control',
        'Variable': 'Control',
        'VariableV2': 'Control',
        'Assign': 'Control',
        'BatchNormWithGlobalNormalization': 'Normalization',
        'FusedBatchNorm': 'Normalization',
    }

    tensorflow_repo_dir = os.path.join(
        os.path.dirname(__file__), '../third_party/src/tensorflow')
    api_def_map = _read_api_def_map(
        os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
    input_file = os.path.join(tensorflow_repo_dir, 'tensorflow/core/ops/ops.pbtxt')
    ops_list = op_def_pb2.OpList()
    with io.open(input_file, 'r', encoding='utf-8') as input_handle:
        text_format.Merge(input_handle.read(), ops_list)

    json_root = []
    for op in ops_list.op:
        if op.name in api_def_map:
            api_def = api_def_map[op.name]
        else:
            api_def = api_def_pb2.ApiDef()
        json_schema = _convert_op(op, api_def, categories.get(op.name))
        json_root.append({'name': op.name, 'schema': json_schema})

    # Serialize fully before opening the output, so a failure cannot leave a
    # truncated file behind.
    json_data = json.dumps(json_root, sort_keys=True, indent=2)
    # Python 2's json.dumps leaves a trailing space after ',' when indent is
    # set; strip trailing whitespace so output is identical across versions.
    text = '\n'.join(line.rstrip() for line in json_data.splitlines()) + '\n'
    if sys.version_info[0] < 3:
        text = unicode(text)  # noqa: F821 -- Python 2 only; io.open wants unicode
    json_file = os.path.join(os.path.dirname(__file__), '../src/tf-metadata.json')
    with io.open(json_file, 'w', encoding='utf-8', newline='') as fout:
        fout.write(text)