def main():
    """Generate one cond_take kernel instantiation file per ``DTYPES`` key.

    Command line:
        --type   backend to generate for; only ``cuda`` is accepted.
        output   directory that receives the generated files (created,
                 including parents, when it does not exist yet).

    For every key ``dtype`` of the module-level ``DTYPES`` table a file
    named ``<dtype>.cu`` is written. Each file includes ``../kern.inl`` and
    instantiates ``inst_genidx`` / ``inst_copy`` for
    ``::megdnn::dtype::<DTYPES[dtype][0]>``. The ``dt_float16`` file is
    additionally wrapped in ``#if !MEGDNN_DISABLE_FLOAT16`` / ``#endif``.
    The path of every generated file is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description='generate cond_take impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type', type=str, choices=['cuda'],
                        default='cuda',
                        help='generate cuda cond take kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok=True avoids the check-then-create race of testing isdir()
    # first; it still raises if the path exists but is not a directory.
    os.makedirs(args.output, exist_ok=True)

    # argparse's ``choices`` already restricts --type; this is only a
    # sanity check on that invariant.
    assert args.type == 'cuda'
    cpp_ext = 'cu'

    for dtype in DTYPES:
        fname = os.path.join(args.output, '{}.{}'.format(dtype, cpp_ext))
        dtype_name = DTYPES[dtype][0]
        needs_fp16_guard = dtype == 'dt_float16'

        lines = [
            '// generated by gen_cond_take_kern_impls.py',
            '#include "../kern.inl"',
            '',
        ]
        if needs_fp16_guard:
            lines.append('#if !MEGDNN_DISABLE_FLOAT16')
        lines.extend([
            'namespace megdnn {',
            'namespace cuda {',
            'namespace cond_take {',
            '',
            'inst_genidx(::megdnn::dtype::{})'.format(dtype_name),
            '#undef inst_genidx',
            '',
            'inst_copy(::megdnn::dtype::{})'.format(dtype_name),
            '#undef inst_copy',
            '#undef inst_copy_',
            '',
            '} // cond_take',
            '} // cuda',
            '} // megdnn',
        ])
        if needs_fp16_guard:
            lines.append('#endif')

        with open(fname, 'w') as fout:
            # Every line, including the last, is newline-terminated.
            fout.write('\n'.join(lines) + '\n')

        print('generated {}'.format(fname))

    # Set the directory's access/modification times to "now".
    # NOTE(review): presumably so the build system sees the output
    # directory as freshly updated -- confirm against the build rules.
    os.utime(args.output)
def main():
    """Generate one elemwise kernel implementation file per (mode, ctype).

    Command line:
        --type   ``cuda``, ``hip`` or ``cpp`` (default ``cpp``); selects the
                 extension of the generated files (``cu``, ``cpp.hip`` or
                 ``cpp``).
        output   directory that receives the generated files (created,
                 including parents, when it does not exist yet).

    Iterates over every combination of a key of ``ARITIES`` and a key of
    ``DTYPES``; for each mode in ``MODES[(arity, DTYPES[ctype][1])]`` a file
    ``<mode>_<ctype>.<ext>`` is written that defines ``KERN_IMPL_MODE``,
    ``KERN_IMPL_ARITY`` and ``KERN_IMPL_CTYPE`` and then includes
    ``../kern_impl.inl``. Files for ``dt_float16`` / ``dt_bfloat16`` are
    wrapped in ``#if !MEGDNN_DISABLE_FLOAT16`` / ``#endif``. The path of
    every generated file is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description='generate elemwise impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type', type=str, choices=['cuda', 'hip', 'cpp'],
                        default='cpp',
                        help='generate cuda/hip kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok=True avoids the check-then-create race of testing isdir()
    # first; it still raises if the path exists but is not a directory.
    os.makedirs(args.output, exist_ok=True)

    # argparse's ``choices`` guarantees args.type is one of these keys, so
    # the lookup cannot miss (and does not rely on an assert that would be
    # stripped under ``python -O``).
    cpp_ext = {'cuda': 'cu', 'hip': 'cpp.hip', 'cpp': 'cpp'}[args.type]

    for anum, ctype in itertools.product(ARITIES, DTYPES):
        needs_fp16_guard = ctype in ('dt_float16', 'dt_bfloat16')

        for mode in MODES[(anum, DTYPES[ctype][1])]:
            formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)
            fname = os.path.join(
                args.output, '{}_{}.{}'.format(mode, ctype, cpp_ext))

            lines = ['// generated by gen_elemwise_kern_impls.py']
            if needs_fp16_guard:
                lines.append('#if !MEGDNN_DISABLE_FLOAT16')
            lines.extend([
                '#define KERN_IMPL_MODE(cb) {}'.format(formode),
                '#define KERN_IMPL_ARITY {}'.format(anum),
                '#define KERN_IMPL_CTYPE {}'.format(ctype),
                '#include "../kern_impl.inl"',
            ])
            if needs_fp16_guard:
                lines.append('#endif')

            with open(fname, 'w') as fout:
                # Every line, including the last, is newline-terminated.
                fout.write('\n'.join(lines) + '\n')

            print('generated {}'.format(fname))

    # Set the directory's access/modification times to "now".
    # NOTE(review): presumably so the build system sees the output
    # directory as freshly updated -- confirm against the build rules.
    os.utime(args.output)
def main():
    """Generate one elemwise "special" kernel file per ``DTYPES`` key.

    Command line:
        --type   ``cuda`` or ``hip`` (default ``cuda``); selects the
                 extension of the generated files (``cu`` or ``cpp.hip``).
        output   directory that receives the generated files (created,
                 including parents, when it does not exist yet).

    For every key ``dtype`` of the module-level ``DTYPES`` table a file
    named ``special_<dtype>.<ext>`` is written. Each file includes
    ``../special_kerns.inl``, invokes ``INST`` for
    ``::megdnn::dtype::<DTYPES[dtype][0]>``, undefines ``INST`` and emits
    two closing braces. Files for ``dt_float16`` / ``dt_bfloat16`` are
    wrapped in ``#if !MEGDNN_DISABLE_FLOAT16`` / ``#endif``. The path of
    every generated file is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description='generate elemwise impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type', type=str, choices=['cuda', 'hip'],
                        default='cuda',
                        help='generate cuda/hip elemwise special kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok=True avoids the check-then-create race of testing isdir()
    # first; it still raises if the path exists but is not a directory.
    os.makedirs(args.output, exist_ok=True)

    # argparse's ``choices`` guarantees args.type is one of these keys, so
    # the lookup cannot miss (and does not rely on an assert that would be
    # stripped under ``python -O``).
    cpp_ext = {'cuda': 'cu', 'hip': 'cpp.hip'}[args.type]

    for dtype in DTYPES:
        fname = os.path.join(
            args.output, 'special_{}.{}'.format(dtype, cpp_ext))
        needs_fp16_guard = dtype in ('dt_float16', 'dt_bfloat16')

        lines = ['// generated by gen_elemwise_special_kern_impls.py']
        if needs_fp16_guard:
            lines.append('#if !MEGDNN_DISABLE_FLOAT16')
        lines.extend([
            '#include "../special_kerns.inl"',
            'INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0]),
            '#undef INST',
            # NOTE(review): these two braces presumably close scopes left
            # open by special_kerns.inl -- confirm against that header.
            '}',
            '}',
        ])
        if needs_fp16_guard:
            lines.append('#endif')

        with open(fname, 'w') as fout:
            # Every line, including the last, is newline-terminated.
            fout.write('\n'.join(lines) + '\n')

        print('generated {}'.format(fname))

    # Set the directory's access/modification times to "now".
    # NOTE(review): presumably so the build system sees the output
    # directory as freshly updated -- confirm against the build rules.
    os.utime(args.output)