def _get_attribute(attribute_proto):
    """Decode one attribute protobuf message into a ``(name, value)`` pair.

    The field layout (ref_attr_name, f/i/s/t/g, floats/ints/strings/tensors/graphs)
    presumably follows ONNX ``AttributeProto`` -- confirm against the caller.

    Singular fields are tested first with ``HasField`` in the order f, i, s, t, g.
    If none is set, the first non-empty repeated field among floats, ints,
    strings, tensors, graphs is converted element by element. If nothing is set,
    the value is an empty list.

    :raises ParseException: if ``ref_attr_name`` is set (not expected in the main graph).
    """
    if attribute_proto.HasField('ref_attr_name'):
        raise ParseException('Unexpected ref_attr_name in main graph')

    attr_name = _fixstr(attribute_proto.name)

    # (field name, converter) pairs for singular fields, in priority order.
    singular_converters = (
        ('f', float),
        ('i', utils.anyint_to_int),
        ('s', utils.anystr_to_str),
        ('t', _get_tensor),
        ('g', _get_graph),
    )
    for field_name, convert in singular_converters:
        if attribute_proto.HasField(field_name):
            return attr_name, convert(getattr(attribute_proto, field_name))

    # (field name, per-element converter) pairs for repeated fields, in priority order.
    repeated_converters = (
        ('floats', float),
        ('ints', utils.anyint_to_int),
        ('strings', utils.anystr_to_str),
        ('tensors', _get_tensor),
        ('graphs', _get_graph),
    )
    for field_name, convert in repeated_converters:
        elements = getattr(attribute_proto, field_name)
        if elements:
            return attr_name, [convert(element) for element in elements]

    return attr_name, []
def _get_attribute(attribute_proto):
    """Decode one attribute protobuf message into a ``(name, value)`` pair.

    The field layout presumably follows ONNX ``AttributeProto`` -- confirm
    against the caller. Tensor and graph attributes (single or list) are decoded
    with ``_get_tensor`` / ``_get_graph``.

    Lookup order: singular f, i, s, t, g (via ``HasField``), then the first
    non-empty of floats, ints, strings, tensors, graphs. With nothing set, the
    value is ``[]``.

    :raises ParseException: if ``ref_attr_name`` is set (not expected in the main graph).
    """
    if attribute_proto.HasField('ref_attr_name'):
        raise ParseException('Unexpected ref_attr_name in main graph')

    attr_name = _fixstr(attribute_proto.name)

    if attribute_proto.HasField('f'):
        return attr_name, float(attribute_proto.f)
    if attribute_proto.HasField('i'):
        return attr_name, utils.anyint_to_int(attribute_proto.i)
    if attribute_proto.HasField('s'):
        return attr_name, utils.anystr_to_str(attribute_proto.s)
    if attribute_proto.HasField('t'):
        return attr_name, _get_tensor(attribute_proto.t)
    if attribute_proto.HasField('g'):
        return attr_name, _get_graph(attribute_proto.g)

    if attribute_proto.floats:
        return attr_name, [float(x) for x in attribute_proto.floats]
    if attribute_proto.ints:
        return attr_name, [utils.anyint_to_int(x) for x in attribute_proto.ints]
    if attribute_proto.strings:
        return attr_name, [utils.anystr_to_str(x) for x in attribute_proto.strings]
    if attribute_proto.tensors:
        return attr_name, [_get_tensor(x) for x in attribute_proto.tensors]
    if attribute_proto.graphs:
        return attr_name, [_get_graph(x) for x in attribute_proto.graphs]

    return attr_name, []
def run_test(output_name, output_nodes, recreate=True):
    """Save the current TF graph as protobuf, run activation tests, and parse the NNEF result.

    :param output_name: sub-path under ``out/pb`` for the protobuf; with '/' replaced
                        by '_' (after stripping a trailing '/') it also names the network.
    :param output_nodes: passed through to ``save_protobuf``.
    :param recreate: passed through to ``save_protobuf``.
    :return: result of ``nnef.parse_file`` on the NHWC export's ``graph.nnef``.
    """
    batch_size = 1
    # Unknown (None) placeholder dimensions are replaced by batch_size.
    source_shapes = {
        ph.name: [int(d.value) if d.value is not None else batch_size for d in ph.shape.dims]
        for ph in get_placeholders()
    }

    pb_path = os.path.join('out', 'pb', output_name)
    network_name = output_name.rstrip('/').replace('/', '_')

    # The with-statement closes the session on exit, so no explicit close() is needed.
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        save_protobuf(pb_path, output_nodes, sess, recreate)

    for prefer_nhwc in [True, False]:
        # A fresh random input is drawn for each layout run.
        feed_dict = {
            utils.anystr_to_str(k): np.random.random(v)
            for k, v in six.iteritems(source_shapes)
        }
        test_activations(filename=os.path.join(pb_path, 'test.pb'),
                         source_shapes=source_shapes,
                         feed_dict=feed_dict,
                         prefer_nhwc=prefer_nhwc,
                         network_name=network_name,
                         delete_after_each=False,
                         export_io_only=True)

    return nnef.parse_file(
        os.path.join('out', network_name + '_nhwc', 'nnef', network_name + '_nnef', 'graph.nnef'))
def _get_attributes(attr_map_proto, graph): attributes = {} for name, value in attr_map_proto.items(): if not name.startswith('_'): field = value.WhichOneof('value') value = getattr(value, field) attributes[utils.anystr_to_str(name)] = _get_attribute(field, value, graph) return attributes
def write(
        tf_graph,  # type: TFGraph
        file_path,  # type: str
        write_weights=True,  # type: bool
        custom_op_protos=None,  # type: typing.Optional[typing.List[OpProto]]
        custom_imports=None  # type: str
):
    # type: (...) -> typing.Optional[conversion_info.ConversionInfo]
    """Write ``tf_graph`` as TensorFlow Python source, plus a checkpoint if it has variables.

    Tensors are temporarily renamed to the names produced by ``_generate_names``
    while the source is printed to ``file_path``. If the graph has variables and
    ``write_weights`` is true, their values are saved to
    ``<file_path>.checkpoint/<basename(file_path)>.ckpt``.

    The original tensor names are always restored, and the temporary source
    operations are always removed, even if writing fails.

    :return: the result of ``_get_rename_info(tf_graph, names_to_write)``.
    """
    generate_source_operations(tf_graph)
    try:
        # Inside the try so a sort failure still removes the source operations.
        tf_graph.sort()
        names_to_write = _generate_names(tf_graph=tf_graph,
                                         custom_imports=custom_imports,
                                         custom_op_protos=custom_op_protos)
        old_names = {}
        try:
            for tensor in tf_graph.tensors:
                old_names[tensor] = tensor.name
                tensor.name = utils.anystr_to_str(names_to_write[tensor])

            # os.path.dirname returns '' for a bare file name; os.makedirs('') would raise.
            directory = os.path.dirname(file_path)
            if directory and not os.path.exists(directory):
                os.makedirs(directory)

            with open(file_path, "w") as f:
                _print(tf_graph,
                       file_handle=f,
                       custom_op_protos=custom_op_protos,
                       custom_imports=custom_imports)

            with open(file_path, "r") as f:
                tf_source = f.read()

            if tf_graph.list_variables() and write_weights:
                checkpoint_dir = file_path + ".checkpoint"
                checkpoint_path = os.path.join(checkpoint_dir, os.path.basename(file_path) + ".ckpt")
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                _create_checkpoint_with_values(
                    net_fun=_tfsource_to_function(tf_source, tf_graph.name),
                    file_name=checkpoint_path,
                    variable_value_by_name={
                        t.name: t.data
                        for t in tf_graph.tensors
                        if t.is_variable and t.name
                    })
        finally:
            # Restore only the names that were actually saved. On failure, the
            # caller's graph is not left with generated names, and a partial
            # rename does not raise KeyError here.
            for tensor, old_name in old_names.items():
                tensor.name = old_name
        return _get_rename_info(tf_graph, names_to_write)
    finally:
        remove_source_operations(tf_graph)
def _get_attribute(field, value, graph): if field == 'i' or field == 'f' or field == 'b' or field == 'placeholder': if utils.is_anyint(value): return utils.anyint_to_int(value) return value elif field == 's': return utils.anystr_to_str(value.decode()) elif field == 'shape': return _get_shape(value) elif field == 'type': return _get_dtype(value) elif field == 'tensor': return _get_tensor(value, graph) elif field == 'func': return _get_func(value) elif field == 'list': field, items = _get_nonempty_items(value, fields=['i', 'f', 'b', 's', 'shape', 'type', 'tensor', 'func']) if items is None: return [] return [_get_attribute(field, item, graph) for item in items] assert False
def _normalize_types(arg):
    """Convert TensorFlow, NumPy and Python-2 specific values into plain Python values.

    Conversions, checked in this order:
      * any integer type       -> ``int`` (via ``utils.anyint_to_int``)
      * any string type        -> ``str`` (via ``utils.anystr_to_str``)
      * ``np.ndarray``         -> nested list (``tolist``)
      * ``tf.TensorShape``     -> ``None`` for unknown rank, else list of ``int``/``None``
      * ``tf.Dimension``       -> its ``value``
      * ``tf.DType``           -> its ``name``
      * NumPy scalar           -> Python scalar (``item``)
    Anything else is returned unchanged.
    """
    numpy_scalar_types = (np.int8, np.int16, np.int32, np.int64,
                          np.uint8, np.uint16, np.uint32, np.uint64,
                          np.float16, np.float32, np.float64, np.bool_)

    if utils.is_anyint(arg):
        return utils.anyint_to_int(arg)
    if utils.is_anystr(arg):
        return utils.anystr_to_str(arg)
    if isinstance(arg, np.ndarray):
        return arg.tolist()
    if isinstance(arg, tf.TensorShape):
        if arg.dims is None:
            return None
        return [int(dim) if dim is not None else None for dim in arg.as_list()]
    if isinstance(arg, tf.Dimension):
        return arg.value
    if isinstance(arg, tf.DType):
        return arg.name
    if isinstance(arg, numpy_scalar_types):
        return arg.item()
    return arg
def read_tf_graph_from_protobuf(filename):
    """Read a serialized TensorFlow GraphDef file into a TFGraph.

    Each node gets one TFTensor per output: output 0 is named after the node,
    output i > 0 is named 'node:i'. After construction, names without ':' get a
    ':0' suffix.
      * 'Placeholder' nodes become graph inputs; shape/dtype come from attributes.
      * 'Const' nodes store their 'value' attribute as tensor data.
      * Every other node becomes a TFOperation.
    Graph outputs are the outputs of operations none of whose outputs have
    consumers. They are named 'output', or 'output0', 'output1', ... when there
    is more than one.

    :raises utils.NNEFToolsException: if a node input names a tensor that no
        node in the GraphDef produces.
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as file:
        graph_def.ParseFromString(file.read())

    graph = TFGraph()
    attrib_graph = TFGraph()  # just a graph to contain the tensors that are in attributes, no need to return this

    # Per-node data, keyed by node name (unique within a GraphDef).
    attributes_by_node_name = {}
    outputs_by_node_name = {}
    detected_output_count = _detect_output_counts(graph_def)
    # Pass 1: convert attributes and create the output tensors of every node.
    for node in graph_def.node:
        outputs = []
        attributes = _get_attributes(node.attr, attrib_graph)
        # _OutputCount maps op type to an int, or to the name of the attribute
        # that holds the output count.
        output_count = _OutputCount.get(node.op, 1)
        if isinstance(output_count, str):
            output_count = attributes[output_count]
        # NOTE(review): looked up by op type (node.op); confirm that
        # _detect_output_counts is keyed by op type and not by node name.
        output_count = max(output_count, detected_output_count.get(node.op, 1))
        assert isinstance(output_count, int)
        if output_count >= 1:
            output = TFTensor(graph, utils.anystr_to_str(node.name))
            outputs.append(output)
            for i in range(1, output_count):
                tensor_name = utils.anystr_to_str(node.name) + ':' + str(i)
                output = TFTensor(graph, tensor_name)
                outputs.append(output)

        outputs_by_node_name[node.name] = outputs
        attributes_by_node_name[node.name] = attributes

    tensor_by_name = {tensor.name: tensor
                      for outputs in six.itervalues(outputs_by_node_name)
                      for tensor in outputs}

    # Pass 2: fill placeholders/constants and connect operations.
    placeholders = []
    for node in graph_def.node:
        attributes = attributes_by_node_name[node.name]
        outputs = outputs_by_node_name[node.name]

        if node.op == 'Placeholder':
            assert len(outputs) == 1
            tensor = outputs[0]
            tensor.shape = attributes['shape'] if 'shape' in attributes else None
            tensor.dtype = attributes['dtype'] if 'dtype' in attributes else None
            placeholders.append(tensor)
        elif node.op == 'Const':
            assert len(outputs) == 1
            tensor = outputs[0]
            value = attributes['value']
            # A tensor-valued attribute arrives as a TFTensor of attrib_graph; copy its contents.
            if isinstance(value, TFTensor):
                tensor.shape = value.shape
                tensor.dtype = value.dtype
                tensor.data = value.data
            else:
                tensor.data = value
        else:
            # Drop control-dependency inputs ('^name') and map 'name:0' to 'name',
            # the key output 0 was registered under above.
            input_names = [name[:-2] if name.endswith(':0') else name
                           for name in node.input if not name.startswith('^')]
            for name in input_names:
                if name not in tensor_by_name:
                    print('Info: List of node types in graph: {}\n'.format(
                        sorted(list({node.op for node in graph_def.node}))))
                    raise utils.NNEFToolsException(
                        "Tensor {} is used, but it is not clear which operation produced it. "
                        "Probably the graph has unsupported dynamic operations.".format(name))
            inputs = tuple([tensor_by_name[name] for name in input_names])
            TFOperation(graph,
                        name=utils.anystr_to_str(node.op),
                        inputs=inputs,
                        outputs=outputs,
                        attribs=attributes)

    # Canonicalize output-0 names to 'node:0'.
    for tensor in graph.tensors:
        if tensor.name is not None and ':' not in tensor.name:
            tensor.name += ':0'

    # Inputs are keyed by placeholder node name (tensor name without ':<index>').
    graph.inputs = OrderedDict([(tensor.name.split(':')[0], tensor) for tensor in placeholders])

    # An operation's outputs are graph outputs only if none of them is consumed.
    graph_outputs = []
    for op in graph.operations:
        if all(len(output.consumers) == 0 for output in op.outputs):
            for output in op.outputs:
                graph_outputs.append(output)

    graph.outputs = OrderedDict([('output' + str(i) if len(graph_outputs) > 1 else 'output', tensor)
                                 for i, tensor in enumerate(graph_outputs)])

    return graph
def read_tflite_graph_from_flatbuffers(filename):
    """Read a TFLite flatbuffer model with exactly one subgraph into a TFGraph.

    Every subgraph tensor becomes a TFTensor carrying its shape, dtype name,
    buffer data (if any) and quantization. A tensor with data is labelled with
    its own name. Every operator becomes a TFOperation named after its builtin
    operator code, with its builtin options enumerated as attributes. CUSTOM
    operators additionally get their custom code and raw custom-option bytes,
    under ``_custom_op_type_key`` and ``_custom_op_options_key``. Operand index
    -1 (an omitted optional operand) is skipped.

    :raises NotImplementedError: if the model does not have exactly one subgraph.
    """
    with open(filename, 'rb') as f:
        buf = bytearray(f.read())  # not named `bytes`, to avoid shadowing the builtin
    model = tflite_fb.Model.GetRootAsModel(buf, 0)

    if model.SubgraphsLength() != 1:
        raise NotImplementedError('graphs with multiple sub-graphs are not supported')

    subgraph = model.Subgraphs(0)
    graph_name = subgraph.Name()
    graph = TFGraph(graph_name.decode() if graph_name is not None else None)

    tensors = []
    for tensor_index in range(subgraph.TensorsLength()):
        tensor = subgraph.Tensors(tensor_index)
        name = tensor.Name().decode()
        shape = [tensor.Shape(dim) for dim in range(tensor.ShapeLength())]
        dtype = _TensorTypeNameByValue[tensor.Type()]
        buffer = model.Buffers(tensor.Buffer())
        data = _get_data_as_ndarray(buffer, _TensorDtypeAsNumpy[tensor.Type()], shape)
        quant = _get_quantization(tensor)
        label = name if data is not None else None
        tensors.append(TFTensor(graph,
                                utils.anystr_to_str(name),
                                shape,
                                dtype,
                                data,
                                utils.anystr_to_str(label) if label is not None else None,
                                quant))

    for op_index in range(subgraph.OperatorsLength()):
        operator = subgraph.Operators(op_index)
        operator_code = model.OperatorCodes(operator.OpcodeIndex())
        name = _BuiltinOperatorNameByValue[operator_code.BuiltinCode()]
        options = operator.BuiltinOptions()
        options_class = _BuiltinOptionsClasses[operator.BuiltinOptionsType()]

        inputs = [tensors[operator.Inputs(k)]
                  for k in range(operator.InputsLength())
                  if operator.Inputs(k) != -1]
        outputs = [tensors[operator.Outputs(k)]
                   for k in range(operator.OutputsLength())
                   if operator.Outputs(k) != -1]

        if options_class is not None:
            options_object = options_class()
            options_object.Init(options.Bytes, options.Pos)
            attribs = _enumerate_attributes(options_class, options_object)
        else:
            attribs = {}

        if operator_code.BuiltinCode() == tflite_fb.BuiltinOperator.CUSTOM:
            assert _custom_op_type_key not in attribs, \
                "'{}' shall not be set as an attribute".format(_custom_op_type_key)
            attribs[_custom_op_type_key] = operator_code.CustomCode().decode('ascii')
            assert _custom_op_options_key not in attribs, \
                "'{}' shall not be set as an attribute".format(_custom_op_options_key)
            custom_options = operator.CustomOptionsAsNumpy()
            # flatc-generated Python `...AsNumpy()` accessors return the int 0 (not an
            # empty array) when the vector is absent, so `.tolist()` would fail there.
            if isinstance(custom_options, int):
                attribs[_custom_op_options_key] = []
            else:
                attribs[_custom_op_options_key] = custom_options.tolist()

        TFOperation(graph, name, inputs, outputs, attribs)

    graph.inputs = [tensors[subgraph.Inputs(k)] for k in range(subgraph.InputsLength())]
    graph.outputs = [tensors[subgraph.Outputs(k)] for k in range(subgraph.OutputsLength())]
    return graph
def get_feed_dict():
    """Build a feed dict of uniform random inputs for every placeholder.

    Keys are the placeholder names converted to native ``str``. Values are
    ``np.random.random`` arrays with each placeholder's shape.
    """
    return {
        utils.anystr_to_str(placeholder.name): np.random.random(placeholder.shape)
        for placeholder in get_placeholders()
    }
def fixstr(s):
    """Return ``s`` converted to a native ``str``; ``None`` is passed through unchanged."""
    if s is None:
        return None
    return utils.anystr_to_str(s)
def read_tf_graph_from_protobuf(filename):
    """Read a serialized TensorFlow GraphDef file into a TFGraph.

    Each node gets one TFTensor per output: output 0 is named after the node,
    output i > 0 is named 'node:i'. After construction, names without ':' get a
    ':0' suffix.
      * 'Placeholder' nodes become graph inputs; shape/dtype come from attributes.
      * 'Const' nodes store their 'value' attribute as tensor data.
      * Every other node becomes a TFOperation. Control-dependency inputs
        ('^name') are ignored, and 'name:0' refers to the node's first output.
    Graph outputs are the outputs of operations none of whose outputs have
    consumers. They are named 'output', or 'output0', 'output1', ... when there
    is more than one.

    :raises KeyError: if a node input names a tensor that no node produces.
    """
    graph_def = tf_pb.GraphDef()
    with open(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())

    graph = TFGraph()
    # just a graph to contain the tensors that are in attributes
    # no need to return this
    attrib_graph = TFGraph()

    # Keyed by node name, which is unique within a GraphDef. id(node) is not a
    # safe key: ids are only unique among simultaneously-alive objects, and the
    # protobuf wrappers yielded while iterating graph_def.node need not stay
    # alive (or be the same objects) between the two passes below.
    attributes_by_node_name = {}
    outputs_by_node_name = {}

    # Pass 1: convert attributes and create the output tensors of every node.
    for node in graph_def.node:
        outputs = []
        attributes = _get_attributes(node.attr, attrib_graph)
        # _OutputCount maps op type to an int, or to the name of the attribute
        # that holds the output count.
        output_count = _OutputCount.get(node.op, 1)
        if isinstance(output_count, str):
            output_count = attributes[output_count]
        assert isinstance(output_count, int)
        if output_count >= 1:
            outputs.append(TFTensor(graph, utils.anystr_to_str(node.name)))
            for i in range(1, output_count):
                tensor_name = utils.anystr_to_str(node.name) + ':' + str(i)
                outputs.append(TFTensor(graph, tensor_name))
        outputs_by_node_name[node.name] = outputs
        attributes_by_node_name[node.name] = attributes

    tensor_by_name = {
        tensor.name: tensor
        for outputs in six.itervalues(outputs_by_node_name)
        for tensor in outputs
    }

    # Pass 2: fill placeholders/constants and connect operations.
    placeholders = []
    for node in graph_def.node:
        attributes = attributes_by_node_name[node.name]
        outputs = outputs_by_node_name[node.name]

        if node.op == 'Placeholder':
            assert len(outputs) == 1
            tensor = outputs[0]
            tensor.shape = attributes['shape'] if 'shape' in attributes else None
            tensor.dtype = attributes['dtype'] if 'dtype' in attributes else None
            placeholders.append(tensor)
        elif node.op == 'Const':
            assert len(outputs) == 1
            tensor = outputs[0]
            value = attributes['value']
            # A tensor-valued attribute arrives as a TFTensor of attrib_graph; copy its contents.
            if isinstance(value, TFTensor):
                tensor.shape = value.shape
                tensor.dtype = value.dtype
                tensor.data = value.data
            else:
                tensor.data = value
        else:
            # Skip control-dependency inputs ('^name'), which are not data tensors, and
            # map 'name:0' to 'name', the key output 0 was registered under above.
            input_names = [input_name[:-2] if input_name.endswith(':0') else input_name
                           for input_name in node.input
                           if not input_name.startswith('^')]
            inputs = tuple([tensor_by_name[input_name] for input_name in input_names])
            TFOperation(graph,
                        name=utils.anystr_to_str(node.op),
                        inputs=inputs,
                        outputs=outputs,
                        attribs=attributes)

    # Canonicalize output-0 names to 'node:0'.
    for tensor in graph.tensors:
        if tensor.name is not None and ':' not in tensor.name:
            tensor.name += ':0'

    # Inputs are keyed by placeholder node name (tensor name without ':<index>').
    graph.inputs = OrderedDict([(tensor.name.split(':')[0], tensor) for tensor in placeholders])

    # An operation's outputs are graph outputs only if none of them is consumed.
    graph_outputs = []
    for op in graph.operations:
        if all(len(output.consumers) == 0 for output in op.outputs):
            graph_outputs.extend(op.outputs)

    graph.outputs = OrderedDict([
        ('output' + str(i) if len(graph_outputs) > 1 else 'output', tensor)
        for i, tensor in enumerate(graph_outputs)
    ])
    return graph