예제 #1
0
  def dump_xmodel(self):
    """Compile the quantized NNDCT graph(s) into xmodel files.

    Does nothing unless ``self.quant_mode >= 2``. Each graph is compiled into
    ``<self._output_dir>/xmodel/<graph.name>``. When ``self.lstm`` is truthy,
    every ``(direction, graph)`` pair in ``self._cell_graphs`` is compiled with
    a ``direction`` graph attribute; otherwise a single graph is compiled with
    no extra attributes.

    A compilation failure is reported on stdout and stops the remaining
    compilations; the exception is not re-raised.
    """
    if self.quant_mode < 2:
      return

    compiler = CompilerFactory.get_compiler("xmodel")
    xmodel_dir = os.path.join(self._output_dir, "xmodel")
    generic_utils.mkdir_if_not_exist(xmodel_dir)

    compile_args = []
    if self.lstm:
      for direction, cell_graph in self._cell_graphs:
        compile_args.append((cell_graph, {'direction': direction}))
    else:
      # A bare `graph` here would be an unbound local (it is assigned later in
      # this function), so the instance's graph must be referenced explicitly.
      # NOTE(review): assumes the non-LSTM graph is stored as self._graph —
      # confirm against the class constructor.
      compile_args.append((self._graph, {}))

    for graph, attr_kwargs in compile_args:
      for node in graph.nodes:
        # TODO(yuwang): Set out tensor shape in parser.
        # Maybe specify shape for all tensors?
        # Input shape must be specified for xgraph shape inference.
        if node.op.type == OpTypes.INPUT:
          node.out_tensors[0].shape = node.op.attr['shape']
      try:
        compiler.do_compile(
            graph,
            os.path.join(xmodel_dir, graph.name),
            quant_config_info=self.quant_config,
            graph_attr_kwargs=attr_kwargs)
      except Exception as e:
        # Best-effort export: report and stop instead of propagating.
        print('[ERROR] Failed to dump xmodel: {}'.format(e))
        return

    print('[INFO] Successfully convert nndct graph to xmodel!')
예제 #2
0
파일: utils.py 프로젝트: Xilinx/Vitis-AI
def write_proto(path, message, as_text=False):
    """Serialize a protobuf message to ``path``.

    Missing parent directories are created first. A ``path`` without a
    directory component (a bare file name) is written relative to the current
    working directory.

    Args:
      path: Destination file path.
      message: Protobuf message. ``message.SerializeToString()`` is written for
        binary output; ``text_format.MessageToString(message)`` for text.
      as_text: If True, write the human-readable text format instead of the
        binary wire format.
    """
    dir_name = os.path.dirname(path)
    # dirname() yields '' for a bare file name and os.makedirs('') raises, so
    # only create a directory when there is one. This single guarded call
    # replaces a redundant, unguarded second directory-creation step.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    if as_text:
        # Explicit encoding keeps the output independent of the host locale.
        with open(path, "w", encoding="utf-8") as f:
            f.write(text_format.MessageToString(message))
    else:
        with open(path, "wb") as f:
            f.write(message.SerializeToString())
예제 #3
0
def maybe_export_graph(path, graph):
  """Dump ``graph`` to ``path`` for debugging, if export is enabled.

  Export happens only when the ``EXPORT_NNDCT_TF_PARSER_INTERNAL`` environment
  variable is set to a non-empty value. The parent directory of ``path`` is
  created first. A ``tf.Graph`` is converted to its GraphDef and written as a
  binary proto, a ``GraphDef`` is written as a binary proto, and an
  ``ops.Graph`` is exported via ``viz.export_to_netron``. Objects of any
  other type are ignored.
  """
  export_enabled = os.environ.get('EXPORT_NNDCT_TF_PARSER_INTERNAL', '')
  if not export_enabled:
    return
  generic_utils.mkdir_if_not_exist(os.path.dirname(path))

  if isinstance(graph, tf.Graph):
    write_binary_proto(path, graph.as_graph_def())
  elif isinstance(graph, graph_pb2.GraphDef):
    write_binary_proto(path, graph)
  elif isinstance(graph, ops.Graph):
    viz.export_to_netron(path, graph)
예제 #4
0
    def write(self, filepath):
        """Render the generated module and write it to ``filepath``.

        The printer receives a "do not edit" banner, then the imports and
        class definition. The rendered text is written to ``filepath`` (its
        parent directory is created if needed), flushed and fsync'ed.

        Returns:
          dict mapping entity name -> graph node, restricted to the nodes whose
          translated entity reports ``need_init``.
        """
        banner = '# Generated by NNDCT. Do not edit!'
        self._printer.print(banner)
        self._printer.newline()

        self.write_imports()
        self.write_class_def()

        target_dir = os.path.dirname(filepath)
        generic_utils.mkdir_if_not_exist(target_dir)
        with open(filepath, 'w') as out_file:
            out_file.write(self._printer.get())
            # Flush Python's buffer, then ask the OS to persist the bytes.
            out_file.flush()
            os.fsync(out_file.fileno())

        # Lazily pair each node with its entity so get_entity is called once
        # per node, in graph order; later duplicates overwrite earlier ones.
        pairs = ((self._translator.get_entity(n.name), n)
                 for n in self._graph.nodes)
        return {entity.name: node for entity, node in pairs if entity.need_init}