def dump_xmodel(self):
    """Compile the quantized NNDCT graph(s) into xmodel files.

    Does nothing unless ``self.quant_mode`` is at least 2. Otherwise each
    graph is compiled with the "xmodel" compiler obtained from
    ``CompilerFactory`` into ``<self._output_dir>/xmodel/<graph.name>``.

    When ``self.lstm`` is true, every ``(direction, graph)`` pair from
    ``self._cell_graphs`` is compiled, with ``direction`` passed through
    ``graph_attr_kwargs``. Otherwise a single graph is compiled with no
    extra graph attributes.

    A compilation failure is reported with an ``[ERROR]`` message on stdout
    and aborts the remaining exports; the exception is not re-raised.
    """
    if self.quant_mode < 2:
        return

    compiler = CompilerFactory.get_compiler("xmodel")
    xmodel_dir = os.path.join(self._output_dir, "xmodel")
    generic_utils.mkdir_if_not_exist(xmodel_dir)

    # Each entry is (graph, graph_attr_kwargs) for one do_compile call.
    if self.lstm:
        compile_args = [(graph, {'direction': direction})
                        for direction, graph in self._cell_graphs]
    else:
        # NOTE(review): assumes the single (non-LSTM) graph is stored on
        # ``self._graph``, matching the ``self._cell_graphs`` naming — confirm
        # against the enclosing class.
        compile_args = [(self._graph, {})]

    for graph, attr_kwargs in compile_args:
        for node in graph.nodes:
            # TODO(yuwang): Set out tensor shape in parser.
            # Maybe specify shape for all tensors?
            # Input shape must be specified for xgraph shape inference.
            if node.op.type == OpTypes.INPUT:
                node.out_tensors[0].shape = node.op.attr['shape']
        try:
            compiler.do_compile(
                graph,
                os.path.join(xmodel_dir, graph.name),
                quant_config_info=self.quant_config,
                graph_attr_kwargs=attr_kwargs)
        except Exception as e:
            # Best-effort export: report and stop rather than crash the caller.
            print('[ERROR] Failed to dump xmodel: {}'.format(e))
            return
    print('[INFO] Successfully convert nndct graph to xmodel!')
def write_proto(path, message, as_text=False):
    """Serialize a protobuf ``message`` to the file at ``path``.

    Args:
      path: Destination file path. Missing parent directories (including
        intermediate ones) are created.
      message: Protobuf message to write.
      as_text: If True, write the text representation produced by
        ``text_format.MessageToString``; otherwise write the bytes returned
        by ``message.SerializeToString()``.
    """
    dir_name = os.path.dirname(path)
    # dirname is '' for a bare filename (file goes in the CWD); only create
    # a directory when there is one. makedirs also creates missing parents.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    if as_text:
        with open(path, "w", encoding="utf-8") as f:
            f.write(text_format.MessageToString(message))
    else:
        with open(path, "wb") as f:
            f.write(message.SerializeToString())
def maybe_export_graph(path, graph):
    """Dump ``graph`` to ``path`` for debugging when the export flag is set.

    Active only if the ``EXPORT_NNDCT_TF_PARSER_INTERNAL`` environment
    variable holds a non-empty value. Dispatch by type:

    * ``tf.Graph`` -- converted with ``as_graph_def()`` and written through
      ``write_binary_proto``;
    * ``graph_pb2.GraphDef`` -- written through ``write_binary_proto``;
    * ``ops.Graph`` -- exported through ``viz.export_to_netron``.

    Objects of any other type are ignored without error.
    """
    export_flag = os.environ.get('EXPORT_NNDCT_TF_PARSER_INTERNAL', '')
    if not export_flag:
        return

    generic_utils.mkdir_if_not_exist(os.path.dirname(path))

    if isinstance(graph, tf.Graph):
        write_binary_proto(path, graph.as_graph_def())
        return
    if isinstance(graph, graph_pb2.GraphDef):
        write_binary_proto(path, graph)
        return
    if isinstance(graph, ops.Graph):
        viz.export_to_netron(path, graph)
    # Unsupported graph types fall through and are deliberately not exported.
def write(self, filepath):
    """Render the generated module and write it to ``filepath``.

    Prints a "Generated by NNDCT" header line and a blank line through
    ``self._printer``, then calls ``self.write_imports()`` and
    ``self.write_class_def()``. The accumulated text from
    ``self._printer.get()`` is written to ``filepath`` (its directory is
    created via ``generic_utils.mkdir_if_not_exist``), flushed and fsynced.

    Returns:
      dict mapping ``entity.name`` to the graph node, for every node in
      ``self._graph.nodes`` whose translator entity has ``need_init`` set.
    """
    self._printer.print('# Generated by NNDCT. Do not edit!')
    self._printer.newline()
    self.write_imports()
    self.write_class_def()

    generic_utils.mkdir_if_not_exist(os.path.dirname(filepath))
    # The output is presumably Python source, which the interpreter decodes
    # as UTF-8 by default (PEP 3120). Write UTF-8 explicitly rather than
    # relying on the platform's locale encoding.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(self._printer.get())
        # Force the contents to disk before returning.
        f.flush()
        os.fsync(f.fileno())

    layer_to_node = {}
    for node in self._graph.nodes:
        entity = self._translator.get_entity(node.name)
        if entity.need_init:
            layer_to_node[entity.name] = node
    return layer_to_node