Пример #1
0
# Two 56x56 dense operands; they differ in addressing mode and bounding box.
mat_a = DenseMatrix(
    num_rows=56,
    num_cols=56,
    addressing=Addressing.NONE,
    bbox=[0, 0, 20, 9],
)
mat_b = DenseMatrix(
    num_rows=56,
    num_cols=56,
    addressing=Addressing.STRIDED,
    bbox=[0, 0, 9, 20],
)

# Intermediate matrix: output (c) of the first GEMM, input (a) of the second.
tmp1 = generate_tmp_matrix(mat_a, mat_b)

# Chain of two GEMM descriptors; only the second one transposes `b`.
# NOTE(review): `mat_q` is not defined anywhere in this excerpt -- presumably
# it is created before this point; confirm, otherwise this raises NameError.
first_gemm = GemmDescr(trans_a=False, trans_b=False, a=mat_a, b=mat_b, c=tmp1)
second_gemm = GemmDescr(trans_a=False, trans_b=True, a=tmp1, b=mat_b, c=mat_q)
gemm_list = [first_gemm, second_gemm]

# Target description handed to the generator.
vm = vm_factory(name='nvidia', sub_name='sm_60', fp_type=FloatingPointType.FLOAT)

generator = Generator(gemm_list, vm)
generator.generate()

# Print launcher, header and kernel in that order, separated by blank lines.
for emit in (generator.get_launcher, generator.get_header):
    print(emit())
    print()
print(generator.get_kernel())
Пример #2
0
def main():
    """Generate kernel sources and a benchmark driver from an input program.

    Command-line options:
        -i / --input     program file (validated against the '.cf' format)
        -c / --config    YAML configuration file; must provide 'fp_type'
        -a / --arch      GPU architecture name passed to ``vm_factory``
        -s / --sub_arch  sub-architecture name passed to ``vm_factory``

    Steps:
        1. Validate both files with ``does_file_exist``; on ``ValueError``
           the message is printed and the process exits with status -1.
        2. Parse the program and post-process it into a mapping of
           benchmark name -> list of GEMM descriptions.
        3. For every entry generate the kernel, launcher, header and the
           benchmark source.
        4. Write ``main.cu``, ``kernel.cu``, ``kernel.h`` and
           ``cmake_params.cmake`` into the folder returned by
           ``make_tmp_folder(file_name='tmp')``.

    A relative ``--input`` path is resolved against the directory that
    contains this script, whereas ``--config`` is opened exactly as given.
    """
    cmd = argparse.ArgumentParser()
    cmd.add_argument('-i', '--input', type=str, help="input file")
    cmd.add_argument('-c', '--config', type=str, help="config file")
    cmd.add_argument('-a', '--arch', type=str, help='gpu arch (nvidia, amd)')
    cmd.add_argument('-s',
                     '--sub_arch',
                     type=str,
                     help='sub architecture (nvidia, amd)')
    args = cmd.parse_args()

    try:
        does_file_exist(args.input, format='.cf')
        does_file_exist(args.config, format='.yaml')
    except ValueError as err:
        print(f'{err}')
        sys.exit(-1)

    # path.join (instead of gluing the parts with '/') keeps the path valid
    # when dirname(__file__) is empty and lets an absolute --input through.
    curr_dir = path.dirname(__file__)
    with open(path.join(curr_dir, args.input), 'r') as file:
        program = file.read()

    # get AST and append symbol table with temporaries
    parser = Parser()
    ast, symbol_table = parser.parse(translation_unit=program)

    # convert AST to lists of gemms
    symbol_table.add_scope()
    processor = PostProcessor(ast, symbol_table)
    gemm_dicts = processor.process()

    # Context manager guarantees the config file handle is closed.
    with open(args.config, 'r') as stream:
        config = yaml.safe_load(stream)
    vm = vm_factory(name=args.arch,
                    sub_name=args.sub_arch,
                    fp_type=FloatingPointType.str2enum(config['fp_type']))

    kernels = []
    launchers = []
    headers = []
    benchmarks_src = []
    benchmarks_names = []
    for bench_name, gemm_list in gemm_dicts.items():
        gpu_generator = Generator(gemm_list, vm)
        gpu_generator.set_kernel_name(bench_name)
        gpu_generator.generate()

        # collect the generated sources; they are written to disk below
        kernels.append(gpu_generator.get_kernel())
        launchers.append(gpu_generator.get_launcher())
        headers.append(gpu_generator.get_header())

        call_site = Aux.get_call_site(gpu_generator, gemm_list)
        # each benchmark generator gets its own copy of the symbol table
        bench_generator = BenchGenerator(bench_name, deepcopy(symbol_table),
                                         gemm_list, config, call_site,
                                         args.arch)

        benchmarks_names.append(bench_name)
        benchmarks_src.append(bench_generator.generate())

    # generate main file
    entry_point = EnryPointGenerator(benchmarks_names, benchmarks_src, config)
    main_src = entry_point.generate()
    tmp_dir = make_tmp_folder(file_name='tmp')
    with open(path.join(tmp_dir, 'main.cu'), 'w') as file:
        file.write(main_src)

    # write kernels and launchers, interleaved per benchmark, to one file
    with open(path.join(tmp_dir, 'kernel.cu'), 'w') as file:
        file.write('#include "chainforge_aux.h"\n')
        for kernel, launcher in zip(kernels, launchers):
            file.write(kernel)
            file.write(launcher)

    # write all headers to one file, wrapped in an include guard
    with open(path.join(tmp_dir, 'kernel.h'), 'w') as file:
        file.write('#ifndef KERNEL_H\n')
        file.write('#define KERNEL_H\n')
        for header in headers:
            file.write(header)
        file.write('#endif\n')

    # write the values picked up by the CMake side of the build
    with open(path.join(tmp_dir, 'cmake_params.cmake'), 'w') as file:
        # 8 for 'double'; every other fp_type value maps to 4
        real_size = 8 if config['fp_type'] == 'double' else 4
        file.write(f'set(REAL_SIZE {real_size})\n')
        file.write(f'set(ARCH {args.arch})\n')
        file.write(f'set(SM_ARCH {args.sub_arch})\n')