コード例 #1
0
ファイル: generate_project.py プロジェクト: tk26eng/blueoil
def run(input_path: str,
        dest_dir_path: str,
        project_name: str,
        activate_hard_quantization: bool,
        threshold_skipping: bool = False,
        debug: bool = False,
        cache_dma: bool = False) -> None:
    """Import a TensorFlow protobuf graph, optimize it and generate a project.

    Args:
        input_path: Path of the graph file read by ``TensorFlowIO.read``.
        dest_dir_path: Directory that receives all generated artifacts.
        project_name: Base name used for the ``<name>.test`` directory, the
            ``<name>.pb`` optimized graph and the ``<name>.prj`` project.
        activate_hard_quantization: Forwarded unchanged to ``Config``.
        threshold_skipping: Forwarded unchanged to ``Config``.
        debug: Forwarded unchanged to ``Config``.
        cache_dma: Forwarded unchanged to ``Config``.
    """
    output_dlk_test_dir = path.join(dest_dir_path, f'{project_name}.test')
    optimized_pb_path = path.join(dest_dir_path, f'{project_name}.pb')
    output_project_path = path.join(dest_dir_path, f'{project_name}.prj')

    # NOTE(review): these paths are derived from dest_dir_path *as given*
    # (possibly relative), before it is made absolute below — confirm that
    # nothing downstream changes the working directory.
    config = Config(activate_hard_quantization=activate_hard_quantization,
                    threshold_skipping=threshold_skipping,
                    test_dir=output_dlk_test_dir,
                    optimized_pb_path=optimized_pb_path,
                    output_pj_path=output_project_path,
                    debug=debug,
                    cache_dma=cache_dma)

    dest_dir_path = path.abspath(dest_dir_path)
    # make_dirs is called with a list everywhere else; a bare string could be
    # iterated character by character if make_dirs loops over its argument.
    util.make_dirs([dest_dir_path])

    click.echo('import pb file')
    io = TensorFlowIO()
    graph: Graph = io.read(input_path)

    click.echo('optimize graph step: start')
    optimize_graph_step(graph, config)
    click.echo('optimize graph step: done!')

    click.echo('generate code step: start')
    generate_code_step(graph, config)
    click.echo('generate code step: done!')
コード例 #2
0
ファイル: code_generator.py プロジェクト: 2429581027/blueoil
    def generate_files_from_template(self) -> None:
        """Mirror the template tree into the output project directory.

        Every regular file returned by ``util.get_files`` for the template
        root (excluding paths matched by ``'/templates/manual'``) is placed
        at the same relative location under ``self.config.output_pj_path``.
        Files whose basename contains ``'tpl'`` and does not start with a dot
        are rendered through ``self.template.generate``; all others are
        copied verbatim with ``shutil.copy2``.
        """
        root = self.template.root_dir
        candidates = util.get_files(root, excepts='/templates/manual')

        for candidate in candidates:
            candidate_path = Path(candidate)
            if not candidate_path.is_file():
                continue

            rel_path = str(candidate_path.relative_to(root))
            target_path = path.join(self.config.output_pj_path, rel_path)
            target_dir = path.dirname(target_path)

            # Create the destination directory on demand.
            util.make_dirs([target_dir])

            base_name = path.basename(candidate)
            # Hidden files (leading dot) are never treated as templates.
            is_template = 'tpl' in base_name and base_name[0] != '.'
            if not is_template:
                shutil.copy2(candidate, target_path)
                continue

            template_rel_path = str(
                candidate_path.relative_to(self.template.root_dir))
            self.template.generate(template_rel_path, target_dir)
コード例 #3
0
    def generate_inputs(self) -> None:
        """Emit one ``.cpp`` source and one ``.h`` header per constant node.

        For each node in ``self.graph.consts`` the ``consts/input.tpl.cpp``
        and ``consts/input.tpl.h`` templates are rendered via
        ``self.template.manual_generate`` into ``<src_dir>/inputs`` and
        ``<header_dir>/inputs``, named ``<node.name>.cpp`` and
        ``<node.name>.h`` respectively.
        """
        src_out_dir = path.join(self.src_dir, 'inputs')
        header_out_dir = path.join(self.header_dir, 'inputs')
        util.make_dirs([src_out_dir, header_out_dir])

        # (template, output directory, extension) per artifact; the source
        # file is generated before the header for every node.
        jobs = (
            (path.join('consts', 'input.tpl.cpp'), src_out_dir, '.cpp'),
            (path.join('consts', 'input.tpl.h'), header_out_dir, '.h'),
        )

        for const_node in self.graph.consts:
            for template_path, out_dir, ext in jobs:
                self.template.manual_generate(template_path,
                                              out_dir,
                                              new_name=const_node.name + ext,
                                              node=const_node)