def _CreatePythonApiDef(self, base_api_def, endpoint_names):
    """Return a Python-only ApiDef override for `base_api_def`, or None.

    The Python endpoint names are compared, as a set, with the endpoint names
    of `base_api_def` after each of those has been passed through
    `self._GenerateLowerCaseOpName`. If both sets are equal, no override is
    needed.

    Args:
      base_api_def: (api_def_pb2.ApiDef) base ApiDef instance.
      endpoint_names: List of Python endpoint names.

    Returns:
      A new api_def_pb2.ApiDef that carries `base_api_def.graph_op_name` and
      one endpoint per entry of `endpoint_names` (added in sorted order), or
      None when the endpoint names already agree with `base_api_def`.
    """
    requested = set(endpoint_names)
    existing = set(
        self._GenerateLowerCaseOpName(endpoint.name)
        for endpoint in base_api_def.endpoint)
    if requested == existing:
        # The base endpoints already match the Python endpoints exactly.
        return None

    override = api_def_pb2.ApiDef()
    override.graph_op_name = base_api_def.graph_op_name
    for name in sorted(endpoint_names):
        override.endpoint.add().name = name
    return override
def get_api_def(self, op_name):
    """Fetch the ApiDef stored for `op_name` in the wrapped C ApiDefMap.

    Args:
      op_name: Name of the op to look up. It is passed to
        `c_api.TF_ApiDefMapGet` together with its `len()`.

    Returns:
      An api_def_pb2.ApiDef parsed from the buffer that
      `c_api.TF_ApiDefMapGet` returned.
    """
    result = api_def_pb2.ApiDef()
    c_buffer = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
        serialized = c_api.TF_GetBuffer(c_buffer)
        result.ParseFromString(serialized)
    finally:
        # Release the C-side buffer whether or not parsing succeeded.
        c_api.TF_DeleteBuffer(c_buffer)
    return result
def _AddHiddenOpOverrides(self, name_to_base_api_def, api_def_map):
    """Adds ApiDef overrides to api_def_map for hidden Python ops.

    Handles every op name returned by `_GetHiddenOps()` as follows:
      - If the name is unknown, a warning is logged and the name is skipped.
      - If its base ApiDef is already HIDDEN, nothing is added.
      - Otherwise, a minimal ApiDef (graph_op_name plus HIDDEN visibility) is
        appended to the entry of `api_def_map` keyed by the upper-cased first
        character of the op name.

    Args:
      name_to_base_api_def: Map from op name to base api_def_pb2.ApiDef.
      api_def_map: Map from first op name character (in caps) to
        api_def_pb2.ApiDefs for Python API overrides. It is assumed to
        already contain an entry for every first letter that is used.
        TODO confirm with the caller.
    """
    for hidden_op in _GetHiddenOps():
        if hidden_op not in name_to_base_api_def:
            # Pass the argument separately so logging formats it lazily.
            logging.warning('Unexpected hidden op name: %s', hidden_op)
            continue
        base_api_def = name_to_base_api_def[hidden_op]
        if base_api_def.visibility == api_def_pb2.ApiDef.HIDDEN:
            # Already hidden in the base API; no override needed.
            continue
        api_def = api_def_pb2.ApiDef()
        api_def.graph_op_name = base_api_def.graph_op_name
        api_def.visibility = api_def_pb2.ApiDef.HIDDEN
        api_def_map[api_def.graph_op_name[0].upper()].op.extend([api_def])
def metadata():
    """Generate ``tf-metadata.json`` from TensorFlow's op registry.

    Inputs, relative to ``<tf>`` = ``../third_party/source/tensorflow`` (next
    to this file):
      - ApiDef documentation from
        ``<tf>/tensorflow/core/api_def/base_api/*.pbtxt``.
      - OpDefs from ``<tf>/tensorflow/core/ops/ops.pbtxt``.

    The two sources are merged per op. The result is written as a JSON list
    of op schemas to ``../source/tf-metadata.json``.

    Each schema may contain ``name``, ``category``, ``summary``,
    ``description``, ``attributes``, ``inputs`` and ``outputs``. Keys keep
    their insertion order, because the file is dumped with
    ``sort_keys=False``.
    """
    import math

    categories = {
        'Assign': 'Control',
        'AvgPool': 'Pool',
        'BatchNormWithGlobalNormalization': 'Normalization',
        'BiasAdd': 'Layer',
        'Concat': 'Tensor',
        'ConcatV2': 'Tensor',
        'Const': 'Constant',
        'Conv2D': 'Layer',
        'DepthwiseConv2dNative': 'Layer',
        'Dequantize': 'Tensor',
        'Elu': 'Activation',
        'FusedBatchNorm': 'Normalization',
        'FusedBatchNormV2': 'Normalization',
        'FusedBatchNormV3': 'Normalization',
        'Gather': 'Transform',
        'Identity': 'Control',
        'LeakyRelu': 'Activation',
        'LRN': 'Normalization',
        'LSTMBlockCell': 'Layer',
        'MaxPool': 'Pool',
        'MaxPoolV2': 'Pool',
        'MaxPoolWithArgmax': 'Pool',
        'Pad': 'Tensor',
        'Relu': 'Activation',
        'Relu6': 'Activation',
        'Reshape': 'Shape',
        'Sigmoid': 'Activation',
        'Slice': 'Tensor',
        'Softmax': 'Activation',
        'Split': 'Tensor',
        'Squeeze': 'Shape',
        'StridedSlice': 'Tensor',
        'swish_f32': 'Activation',
        'Variable': 'Control',
        'VariableV2': 'Control',
    }

    # Characters that must be escaped inside a double-quoted pbtxt string.
    pbtxt_escapes = {
        '\n': '\\n',
        '\r': '\\r',
        '\t': '\\t',
        '"': '\\"',
        "'": "\\'",
        '\\': '\\\\',
    }

    def find_multiline(line, colon):
        """Return TERM if `line` is ``key: <<TERM``, otherwise None.

        `colon` is the index of the first ':' in `line` (-1 if absent).
        """
        if colon == -1:
            return None
        value = line[colon + 1:].lstrip(' ')
        if value.startswith('<<'):
            return value[2:]
        return None

    def str_escape(text):
        """Escape `text` for embedding in a double-quoted pbtxt string."""
        return ''.join(pbtxt_escapes.get(c, c) for c in text)

    def pbtxt_from_multiline(multiline_pbtxt):
        """Rewrite ``key: <<TERM ... TERM`` blocks as quoted pbtxt strings.

        Handling of one block:
          - The body lines between the opening line and the first line that
            starts with TERM are joined with newlines, escaped and quoted.
          - Any text after TERM on the closing line is kept after the
            closing quote.

        All other lines pass through unchanged.
        """
        output = []
        remaining = multiline_pbtxt
        while remaining:
            index = remaining.find('\n')
            if index == -1:
                output.append(remaining)
                break
            line = remaining[:index]
            remaining = remaining[index + 1:]
            colon = line.find(':')
            end = find_multiline(line, colon)
            if end is None:
                output.append(line + '\n')
                continue
            header = line[:colon + 1]
            body = []
            suffix = ''
            while remaining:
                index = remaining.find('\n')
                if index == -1:
                    # Final line without a trailing newline: consume all of it
                    # so the loop always makes progress.
                    body_line, remaining = remaining, ''
                else:
                    body_line = remaining[:index]
                    remaining = remaining[index + 1:]
                if body_line.startswith(end):
                    suffix = body_line[len(end):]
                    break
                body.append(body_line)
            output.append(
                header + '"' + str_escape('\n'.join(body)) + '"' + suffix + '\n')
        return ''.join(output)

    def read_api_def_map(folder):
        """Map graph_op_name -> ApiDef for every ``.pbtxt`` file in `folder`."""
        api_def_map = {}
        for filename in sorted(os.listdir(folder)):
            if not filename.endswith('.pbtxt'):
                continue
            api_defs = api_def_pb2.ApiDefs()
            with open(os.path.join(folder, filename), encoding='utf-8') as handle:
                multiline_pbtxt = handle.read()
            text_format.Merge(pbtxt_from_multiline(multiline_pbtxt), api_defs)
            for api_def in api_defs.op:
                api_def_map[api_def.graph_op_name] = api_def
        return api_def_map

    def convert_type(data_type):
        return {'type': 'type', 'value': data_type}

    def convert_tensor(tensor):
        return {'type': 'tensor', 'value': '?'}

    def convert_shape(shape):
        return {'type': 'shape', 'value': '?'}

    def convert_number(number):
        """Make non-finite floats JSON-safe.

        json.dumps would otherwise emit bare Infinity/NaN tokens, which are
        not valid JSON.
        """
        if math.isinf(number):
            return 'Infinity' if number > 0 else '-Infinity'
        if math.isnan(number):
            return 'NaN'
        return number

    attr_type_table = {
        'type': 'type',
        'list(type)': 'type[]',
        'bool': 'boolean',
        'int': 'int64',
        'list(int)': 'int64[]',
        'float': 'float32',
        'list(float)': 'float32[]',
        'string': 'string',
        'list(string)': 'string[]',
        'shape': 'shape',
        'list(shape)': 'shape[]',
        'tensor': 'tensor',
        'func': 'function',
        'list(func)': 'function[]'
    }

    def convert_attr_type(attr_type):
        """Map an OpDef attr type to its metadata name (unknown types pass through)."""
        if attr_type in attr_type_table:
            return attr_type_table[attr_type]
        print(attr_type)  # Surface attr types missing from attr_type_table.
        return attr_type

    def convert_attr_value(attr_value):
        """Convert an AttrValue proto into a JSON-serializable value."""
        if attr_value.HasField('list'):
            attr_value_list = attr_value.list
            values = []
            values.extend(s.decode('utf8') for s in attr_value_list.s)
            values.extend(attr_value_list.i)
            values.extend(convert_number(f) for f in attr_value_list.f)
            values.extend(convert_type(t) for t in attr_value_list.type)
            if not values:
                # Only s/i/f/type lists are supported; refuse any other
                # populated list kind rather than silently dropping it.
                for field, value in attr_value_list.ListFields():
                    if len(value) > 0:
                        raise Exception(
                            'Unsupported list attribute value: ' + field.name)
            return values
        if attr_value.HasField('s'):
            return attr_value.s.decode('utf8')
        if attr_value.HasField('i'):
            return attr_value.i
        if attr_value.HasField('f'):
            return convert_number(attr_value.f)
        if attr_value.HasField('b'):
            return attr_value.b
        if attr_value.HasField('type'):
            return convert_type(attr_value.type)
        if attr_value.HasField('tensor'):
            return convert_tensor(attr_value.tensor)
        if attr_value.HasField('shape'):
            return convert_shape(attr_value.shape)
        raise Exception('Unsupported attribute value: {}'.format(attr_value))

    _TYPE_TO_STRING = {
        types_pb2.DataType.DT_HALF: "float16",
        types_pb2.DataType.DT_FLOAT: "float32",
        types_pb2.DataType.DT_DOUBLE: "float64",
        types_pb2.DataType.DT_INT32: "int32",
        types_pb2.DataType.DT_UINT8: "uint8",
        types_pb2.DataType.DT_UINT16: "uint16",
        types_pb2.DataType.DT_UINT32: "uint32",
        types_pb2.DataType.DT_UINT64: "uint64",
        types_pb2.DataType.DT_INT16: "int16",
        types_pb2.DataType.DT_INT8: "int8",
        types_pb2.DataType.DT_STRING: "string",
        types_pb2.DataType.DT_COMPLEX64: "complex64",
        types_pb2.DataType.DT_COMPLEX128: "complex128",
        types_pb2.DataType.DT_INT64: "int64",
        types_pb2.DataType.DT_BOOL: "bool",
        types_pb2.DataType.DT_QINT8: "qint8",
        types_pb2.DataType.DT_QUINT8: "quint8",
        types_pb2.DataType.DT_QINT16: "qint16",
        types_pb2.DataType.DT_QUINT16: "quint16",
        types_pb2.DataType.DT_QINT32: "qint32",
        types_pb2.DataType.DT_BFLOAT16: "bfloat16",
        types_pb2.DataType.DT_RESOURCE: "resource",
        types_pb2.DataType.DT_VARIANT: "variant",
        types_pb2.DataType.DT_HALF_REF: "float16_ref",
        types_pb2.DataType.DT_FLOAT_REF: "float32_ref",
        types_pb2.DataType.DT_DOUBLE_REF: "float64_ref",
        types_pb2.DataType.DT_INT32_REF: "int32_ref",
        types_pb2.DataType.DT_UINT32_REF: "uint32_ref",
        types_pb2.DataType.DT_UINT8_REF: "uint8_ref",
        types_pb2.DataType.DT_UINT16_REF: "uint16_ref",
        types_pb2.DataType.DT_INT16_REF: "int16_ref",
        types_pb2.DataType.DT_INT8_REF: "int8_ref",
        types_pb2.DataType.DT_STRING_REF: "string_ref",
        types_pb2.DataType.DT_COMPLEX64_REF: "complex64_ref",
        types_pb2.DataType.DT_COMPLEX128_REF: "complex128_ref",
        types_pb2.DataType.DT_INT64_REF: "int64_ref",
        types_pb2.DataType.DT_UINT64_REF: "uint64_ref",
        types_pb2.DataType.DT_BOOL_REF: "bool_ref",
        types_pb2.DataType.DT_QINT8_REF: "qint8_ref",
        types_pb2.DataType.DT_QUINT8_REF: "quint8_ref",
        types_pb2.DataType.DT_QINT16_REF: "qint16_ref",
        types_pb2.DataType.DT_QUINT16_REF: "quint16_ref",
        types_pb2.DataType.DT_QINT32_REF: "qint32_ref",
        types_pb2.DataType.DT_BFLOAT16_REF: "bfloat16_ref",
        types_pb2.DataType.DT_RESOURCE_REF: "resource_ref",
        types_pb2.DataType.DT_VARIANT_REF: "variant_ref",
    }

    def format_data_type(data_type):
        if data_type in _TYPE_TO_STRING:
            return _TYPE_TO_STRING[data_type]
        raise Exception('Unsupported data type: {}'.format(data_type))

    def format_attribute_value(value):
        """Render one converted allowed value as display text."""
        if isinstance(value, dict) and value.get('type') == 'type' and 'value' in value:
            return format_data_type(value['value'])
        if isinstance(value, str):
            return value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        raise Exception('Unsupported attribute value: {!r}'.format(value))

    def index_by_name(items):
        """Map item.name -> item (later duplicates win)."""
        return {item.name: item for item in items}

    def convert_attribute(attr, api_def_attr_map):
        """Build the JSON entry for one OpDef attr, documented from its ApiDef."""
        json_attribute = {'name': attr.name}
        attr_type = convert_attr_type(attr.type)
        if attr_type:
            json_attribute['type'] = attr_type
        if attr.name in api_def_attr_map:
            api_def_attr = api_def_attr_map[attr.name]
            if api_def_attr.description:
                json_attribute['description'] = api_def_attr.description
        if attr.has_minimum:
            json_attribute['minimum'] = attr.minimum
        if attr.HasField('allowed_values'):
            allowed_values = convert_attr_value(attr.allowed_values)
            formatted = ', '.join(
                '`' + format_attribute_value(x) + '`' for x in allowed_values)
            prefix = ''
            if 'description' in json_attribute:
                prefix = json_attribute['description'] + ' '
            json_attribute['description'] = (
                prefix + 'Must be one of the following: ' + formatted + '.')
        if attr.HasField('default_value'):
            json_attribute['default'] = convert_attr_value(attr.default_value)
        return json_attribute

    def convert_arg(arg, api_def_arg_map):
        """Build the JSON entry for one OpDef input/output arg."""
        json_arg = {'name': arg.name}
        if arg.name in api_def_arg_map:
            api_def_arg = api_def_arg_map[arg.name]
            if api_def_arg.description:
                json_arg['description'] = api_def_arg.description
        if arg.number_attr:
            json_arg['numberAttr'] = arg.number_attr
        # Presumably at most one of type/type_attr/type_list_attr is set per
        # arg (TODO confirm against OpDef); each is emitted if present.
        if arg.type:
            json_arg['type'] = arg.type
        if arg.type_attr:
            json_arg['typeAttr'] = arg.type_attr
        if arg.type_list_attr:
            json_arg['typeListAttr'] = arg.type_list_attr
        if arg.is_ref:
            json_arg['isRef'] = True
        return json_arg

    tensorflow_repo_dir = os.path.join(os.path.dirname(__file__),
                                       '../third_party/source/tensorflow')
    api_def_map = read_api_def_map(
        os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
    input_file = os.path.join(tensorflow_repo_dir, 'tensorflow/core/ops/ops.pbtxt')
    ops_list = op_def_pb2.OpList()
    with open(input_file, encoding='utf-8') as input_handle:
        text_format.Merge(input_handle.read(), ops_list)

    json_root = []
    for op in ops_list.op:
        json_schema = {'name': op.name}
        if op.name in categories:
            json_schema['category'] = categories[op.name]
        if op.name in api_def_map:
            api_def = api_def_map[op.name]
        else:
            api_def = api_def_pb2.ApiDef()
        if api_def.summary:
            json_schema['summary'] = api_def.summary
        if api_def.description:
            json_schema['description'] = api_def.description
        api_def_attr_map = index_by_name(api_def.attr)
        attributes = [convert_attribute(attr, api_def_attr_map) for attr in op.attr]
        if attributes:
            json_schema['attributes'] = attributes
        api_def_in_arg_map = index_by_name(api_def.in_arg)
        inputs = [convert_arg(arg, api_def_in_arg_map) for arg in op.input_arg]
        if inputs:
            json_schema['inputs'] = inputs
        api_def_out_arg_map = index_by_name(api_def.out_arg)
        outputs = [convert_arg(arg, api_def_out_arg_map) for arg in op.output_arg]
        if outputs:
            json_schema['outputs'] = outputs
        json_root.append(json_schema)

    # Serialize before opening the file so a failure cannot truncate it.
    json_data = json.dumps(json_root, sort_keys=False, indent=2)
    json_file = os.path.join(os.path.dirname(__file__), '../source/tf-metadata.json')
    with io.open(json_file, 'w', newline='') as fout:
        for line in json_data.splitlines():
            # Drop trailing whitespace that some json.dumps versions leave
            # after ',' separators when indenting.
            fout.write(line.rstrip() + '\n')
def metadata():
    """Generate ``tf-metadata.json`` from TensorFlow's op registry.

    Inputs, relative to ``<tf>`` = ``../third_party/src/tensorflow`` (next
    to this file):
      - ApiDef documentation from
        ``<tf>/tensorflow/core/api_def/base_api/*.pbtxt``.
      - OpDefs from ``<tf>/tensorflow/core/ops/ops.pbtxt``.

    The two sources are merged per op. The result is written to
    ``../src/tf-metadata.json`` as a JSON list of
    ``{'name': <op name>, 'schema': {...}}`` entries, with keys sorted.
    """
    import math

    categories = {
        'Const': 'Constant',
        'Conv2D': 'Layer',
        'BiasAdd': 'Layer',
        'DepthwiseConv2dNative': 'Layer',
        'Relu': 'Activation',
        'Relu6': 'Activation',
        'Elu': 'Activation',
        'Softmax': 'Activation',
        'Sigmoid': 'Activation',
        'LRN': 'Normalization',
        'MaxPool': 'Pool',
        'MaxPoolV2': 'Pool',
        'AvgPool': 'Pool',
        'Reshape': 'Shape',
        'Squeeze': 'Shape',
        'ConcatV2': 'Tensor',
        'Split': 'Tensor',
        'Dequantize': 'Tensor',
        'Identity': 'Control',
        'Variable': 'Control',
        'VariableV2': 'Control',
        'Assign': 'Control',
        'BatchNormWithGlobalNormalization': 'Normalization',
        'FusedBatchNorm': 'Normalization',
    }

    # Characters that must be escaped inside a double-quoted pbtxt string.
    pbtxt_escapes = {
        '\n': '\\n',
        '\r': '\\r',
        '\t': '\\t',
        '"': '\\"',
        "'": "\\'",
        '\\': '\\\\',
    }

    def find_multiline(line, colon):
        """Return TERM if `line` is ``key: <<TERM``, otherwise None.

        `colon` is the index of the first ':' in `line` (-1 if absent).
        """
        if colon == -1:
            return None
        value = line[colon + 1:].lstrip(' ')
        if value.startswith('<<'):
            return value[2:]
        return None

    def str_escape(text):
        """Escape `text` for embedding in a double-quoted pbtxt string."""
        return ''.join(pbtxt_escapes.get(c, c) for c in text)

    def pbtxt_from_multiline(multiline_pbtxt):
        """Rewrite ``key: <<TERM ... TERM`` blocks as quoted pbtxt strings.

        Handling of one block:
          - The body lines between the opening line and the first line that
            starts with TERM are joined with newlines, escaped and quoted.
          - Any text after TERM on the closing line is kept after the
            closing quote.

        All other lines pass through unchanged.
        """
        output = []
        remaining = multiline_pbtxt
        while remaining:
            index = remaining.find('\n')
            if index == -1:
                output.append(remaining)
                break
            line = remaining[:index]
            remaining = remaining[index + 1:]
            colon = line.find(':')
            end = find_multiline(line, colon)
            if end is None:
                output.append(line + '\n')
                continue
            header = line[:colon + 1]
            body = []
            suffix = ''
            while remaining:
                index = remaining.find('\n')
                if index == -1:
                    # Final line without a trailing newline: consume all of it
                    # so the loop always makes progress.
                    body_line, remaining = remaining, ''
                else:
                    body_line = remaining[:index]
                    remaining = remaining[index + 1:]
                if body_line.startswith(end):
                    suffix = body_line[len(end):]
                    break
                body.append(body_line)
            output.append(
                header + '"' + str_escape('\n'.join(body)) + '"' + suffix + '\n')
        return ''.join(output)

    def read_api_def_map(folder):
        """Map graph_op_name -> ApiDef for every ``.pbtxt`` file in `folder`."""
        api_def_map = {}
        for filename in sorted(os.listdir(folder)):
            # Skip non-ApiDef files (e.g. build or readme files) that
            # text_format cannot parse.
            if not filename.endswith('.pbtxt'):
                continue
            api_defs = api_def_pb2.ApiDefs()
            # io.open gives explicit UTF-8 decoding on both Python 2 and 3.
            with io.open(os.path.join(folder, filename), encoding='utf-8') as handle:
                multiline_pbtxt = handle.read()
            text_format.Merge(pbtxt_from_multiline(multiline_pbtxt), api_defs)
            for api_def in api_defs.op:
                api_def_map[api_def.graph_op_name] = api_def
        return api_def_map

    def convert_type(data_type):
        return {'type': 'type', 'value': data_type}

    def convert_tensor(tensor):
        return {'type': 'tensor', 'value': '?'}

    def convert_shape(shape):
        return {'type': 'shape', 'value': '?'}

    def convert_number(number):
        """Make non-finite floats JSON-safe.

        json.dumps would otherwise emit bare Infinity/NaN tokens, which are
        not valid JSON.
        """
        if math.isinf(number):
            return 'Infinity' if number > 0 else '-Infinity'
        if math.isnan(number):
            return 'NaN'
        return number

    attr_type_table = {
        'type': 'type',
        'list(type)': 'type[]',
        'bool': 'boolean',
        'int': 'int64',
        'list(int)': 'int64[]',
        'float': 'float32',
        'list(float)': 'float32[]',
        'string': 'string',
        'list(string)': 'string[]',
        'shape': 'shape',
        'list(shape)': 'shape[]',
        'tensor': 'tensor',
        'func': 'function',
        'list(func)': 'function[]'
    }

    def convert_attr_type(attr_type):
        """Map an OpDef attr type to its metadata name (unknown types pass through)."""
        if attr_type in attr_type_table:
            return attr_type_table[attr_type]
        print(attr_type)  # Surface attr types missing from attr_type_table.
        return attr_type

    def convert_attr_value(attr_value):
        """Convert an AttrValue proto into a JSON-serializable value."""
        if attr_value.HasField('list'):
            attr_value_list = attr_value.list
            values = []
            values.extend(s.decode('utf8') for s in attr_value_list.s)
            values.extend(attr_value_list.i)
            values.extend(convert_number(f) for f in attr_value_list.f)
            values.extend(convert_type(t) for t in attr_value_list.type)
            if not values:
                # Only s/i/f/type lists are supported; refuse any other
                # populated list kind rather than silently dropping it.
                for field, value in attr_value_list.ListFields():
                    if len(value) > 0:
                        raise Exception(
                            'Unsupported list attribute value: ' + field.name)
            return values
        if attr_value.HasField('s'):
            return attr_value.s.decode('utf8')
        if attr_value.HasField('i'):
            return attr_value.i
        if attr_value.HasField('f'):
            return convert_number(attr_value.f)
        if attr_value.HasField('b'):
            return attr_value.b
        if attr_value.HasField('type'):
            return convert_type(attr_value.type)
        if attr_value.HasField('tensor'):
            return convert_tensor(attr_value.tensor)
        if attr_value.HasField('shape'):
            return convert_shape(attr_value.shape)
        raise Exception('Unsupported attribute value: {}'.format(attr_value))

    def index_by_name(items):
        """Map item.name -> item (later duplicates win)."""
        return dict((item.name, item) for item in items)

    def convert_attribute(attr, api_def_attr_map):
        """Build the JSON entry for one OpDef attr, documented from its ApiDef."""
        json_attribute = {'name': attr.name}
        attr_type = convert_attr_type(attr.type)
        if attr_type:
            json_attribute['type'] = attr_type
        if attr.name in api_def_attr_map:
            api_def_attr = api_def_attr_map[attr.name]
            if api_def_attr.description:
                json_attribute['description'] = api_def_attr.description
        if attr.has_minimum:
            json_attribute['minimum'] = attr.minimum
        if attr.HasField('allowed_values'):
            json_attribute['allowedValues'] = convert_attr_value(attr.allowed_values)
        if attr.HasField('default_value'):
            json_attribute['default'] = convert_attr_value(attr.default_value)
        return json_attribute

    def convert_arg(arg, api_def_arg_map):
        """Build the JSON entry for one OpDef input/output arg."""
        json_arg = {'name': arg.name}
        if arg.name in api_def_arg_map:
            api_def_arg = api_def_arg_map[arg.name]
            if api_def_arg.description:
                json_arg['description'] = api_def_arg.description
        if arg.number_attr:
            json_arg['numberAttr'] = arg.number_attr
        # Presumably at most one of type/type_attr/type_list_attr is set per
        # arg (TODO confirm against OpDef); each is emitted if present.
        if arg.type:
            json_arg['type'] = arg.type
        if arg.type_attr:
            json_arg['typeAttr'] = arg.type_attr
        if arg.type_list_attr:
            json_arg['typeListAttr'] = arg.type_list_attr
        if arg.is_ref:
            json_arg['isRef'] = True
        return json_arg

    tensorflow_repo_dir = os.path.join(os.path.dirname(__file__),
                                       '../third_party/src/tensorflow')
    api_def_map = read_api_def_map(
        os.path.join(tensorflow_repo_dir, 'tensorflow/core/api_def/base_api'))
    input_file = os.path.join(tensorflow_repo_dir, 'tensorflow/core/ops/ops.pbtxt')
    ops_list = op_def_pb2.OpList()
    with io.open(input_file, encoding='utf-8') as input_handle:
        text_format.Merge(input_handle.read(), ops_list)

    json_root = []
    for op in ops_list.op:
        json_schema = {}
        if op.name in categories:
            json_schema['category'] = categories[op.name]
        if op.name in api_def_map:
            api_def = api_def_map[op.name]
        else:
            api_def = api_def_pb2.ApiDef()
        if api_def.summary:
            json_schema['summary'] = api_def.summary
        if api_def.description:
            json_schema['description'] = api_def.description
        api_def_attr_map = index_by_name(api_def.attr)
        attributes = [convert_attribute(attr, api_def_attr_map) for attr in op.attr]
        if attributes:
            json_schema['attributes'] = attributes
        api_def_in_arg_map = index_by_name(api_def.in_arg)
        inputs = [convert_arg(arg, api_def_in_arg_map) for arg in op.input_arg]
        if inputs:
            json_schema['inputs'] = inputs
        api_def_out_arg_map = index_by_name(api_def.out_arg)
        outputs = [convert_arg(arg, api_def_out_arg_map) for arg in op.output_arg]
        if outputs:
            json_schema['outputs'] = outputs
        json_root.append({'name': op.name, 'schema': json_schema})

    # Serialize before opening the file so a failure cannot truncate it.
    json_data = json.dumps(json_root, sort_keys=True, indent=2)
    json_file = os.path.join(os.path.dirname(__file__), '../src/tf-metadata.json')
    with io.open(json_file, 'w', newline='') as fout:
        for line in json_data.splitlines():
            # Drop trailing whitespace that some json.dumps versions leave
            # after ',' separators when indenting.
            line = line.rstrip()
            if sys.version_info[0] < 3:
                line = unicode(line)
            # Write line and newline together so Python 2 receives a single
            # unicode string.
            fout.write(line + '\n')
'../third_party/tensorflow/tensorflow/core/api_def/base_api') input_file = '../third_party/tensorflow/tensorflow/core/ops/ops.pbtxt' ops_list = op_def_pb2.OpList() with open(input_file) as input_handle: text_format.Merge(input_handle.read(), ops_list) json_root = [] for op in ops_list.op: # print(op.name) json_schema = {} if op.name in categories: json_schema['category'] = categories[op.name] api_def = api_def_pb2.ApiDef() if op.name in api_def_map: api_def = api_def_map[op.name] # if op.deprecation.version != 0: # print('[' + op.name + ']') # print(op.deprecation.version) # print(op.deprecation.explanation) api_def_attr_map = {} for attr in api_def.attr: api_def_attr_map[attr.name] = attr api_def_in_arg_map = {} for in_arg in api_def.in_arg: api_def_in_arg_map[in_arg.name] = in_arg api_def_out_arg_map = {} for out_arg in api_def.out_arg: api_def_out_arg_map[out_arg.name] = out_arg