Пример #1
0
def main():
    """Generate one CUDA cond_take kernel instantiation file per dtype.

    Command line: ``[--type cuda] output``. For every key of ``DTYPES`` a
    ``<dtype>.cu`` file is written into ``output`` (created if missing).
    Each file includes ``../kern.inl`` and, inside the
    ``megdnn::cuda::cond_take`` namespace, expands ``inst_genidx`` and
    ``inst_copy`` for ``::megdnn::dtype::<DTYPES[dtype][0]>``, undefining
    the helper macros afterwards. Half-precision dtypes (``dt_float16`` and
    ``dt_bfloat16``) are wrapped in ``#if !MEGDNN_DISABLE_FLOAT16``.
    Finally the modification time of ``output`` is refreshed.
    """
    parser = argparse.ArgumentParser(
        description='generate cond take impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type',
                        type=str,
                        choices=['cuda'],
                        default='cuda',
                        help='generate cuda cond take kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok replaces the separate isdir() check (and its check-then-create
    # race); a non-directory at that path still raises FileExistsError.
    os.makedirs(args.output, exist_ok=True)

    # argparse already restricts --type to 'cuda'; only .cu files are emitted.
    assert args.type == 'cuda'
    cpp_ext = 'cu'

    # Both half-precision dtypes are guarded, consistent with the elemwise
    # generators that wrap dt_float16 and dt_bfloat16 alike.
    half_dtypes = ('dt_float16', 'dt_bfloat16')

    for dtype in DTYPES:
        fname = os.path.join(args.output, '{}.{}'.format(dtype, cpp_ext))
        guarded = dtype in half_dtypes
        with open(fname, 'w') as fout:
            def w(s):
                print(s, file=fout)

            ctype = DTYPES[dtype][0]

            w('// generated by gen_cond_take_kern_impls.py')
            w('#include "../kern.inl"')
            w('')
            if guarded:
                w('#if !MEGDNN_DISABLE_FLOAT16')
            w('namespace megdnn {')
            w('namespace cuda {')
            w('namespace cond_take {')
            w('')

            w('inst_genidx(::megdnn::dtype::{})'.format(ctype))
            w('#undef inst_genidx')
            w('')
            w('inst_copy(::megdnn::dtype::{})'.format(ctype))
            w('#undef inst_copy')
            w('#undef inst_copy_')

            w('')
            w('}  // cond_take')
            w('}  // cuda')
            w('}  // megdnn')
            if guarded:
                w('#endif')

            print('generated {}'.format(fname))

    os.utime(args.output)
def main():
    """Generate one elemwise kernel implementation file per (mode, dtype).

    Command line: ``[--type {cuda,hip,cpp}] output``. The file extension is
    ``cu`` for cuda, ``cpp.hip`` for hip and ``cpp`` for cpp. For every
    combination of an arity in ``ARITIES`` and a dtype in ``DTYPES``, each
    mode listed in ``MODES[(arity, DTYPES[dtype][1])]`` gets a file
    ``<mode>_<dtype>.<ext>`` in ``output`` (created if missing). The file
    defines ``KERN_IMPL_MODE`` / ``KERN_IMPL_ARITY`` / ``KERN_IMPL_CTYPE``
    and includes ``../kern_impl.inl``. Files for ``dt_float16`` and
    ``dt_bfloat16`` are wrapped in ``#if !MEGDNN_DISABLE_FLOAT16``. The
    modification time of ``output`` is refreshed at the end.
    """
    parser = argparse.ArgumentParser(
        description='generate elemwise impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type',
                        type=str,
                        choices=['cuda', 'hip', 'cpp'],
                        default='cpp',
                        # Help now lists every accepted choice (default is cpp).
                        help='generate cuda/hip/cpp kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok replaces the separate isdir() check (and its check-then-create
    # race); a non-directory at that path still raises FileExistsError.
    os.makedirs(args.output, exist_ok=True)

    if args.type == 'cuda':
        cpp_ext = 'cu'
    elif args.type == 'hip':
        cpp_ext = 'cpp.hip'
    else:
        assert args.type == 'cpp'
        cpp_ext = 'cpp'

    half_dtypes = ('dt_float16', 'dt_bfloat16')

    for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()):
        # The guard depends only on the dtype, so compute it once per ctype.
        guarded = ctype in half_dtypes
        # MODES is keyed by (arity, DTYPES[ctype][1]); assumes every such
        # pair is present — TODO confirm against the MODES table.
        for mode in MODES[(anum, DTYPES[ctype][1])]:
            formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)
            fname = os.path.join(args.output,
                                 '{}_{}.{}'.format(mode, ctype, cpp_ext))
            with open(fname, 'w') as fout:
                def w(s):
                    print(s, file=fout)

                w('// generated by gen_elemwise_kern_impls.py')

                if guarded:
                    w('#if !MEGDNN_DISABLE_FLOAT16')

                w('#define KERN_IMPL_MODE(cb) {}'.format(formode))
                w('#define KERN_IMPL_ARITY {}'.format(anum))
                w('#define KERN_IMPL_CTYPE {}'.format(ctype))
                w('#include "../kern_impl.inl"')

                if guarded:
                    w('#endif')

            print('generated {}'.format(fname))

    os.utime(args.output)
def main():
    """Emit ``special_<dtype>.<ext>`` elemwise special-kernel sources.

    Parses ``[--type {cuda,hip}] output`` from the command line. It makes
    sure ``output`` exists as a directory and writes one file per key of
    ``DTYPES``; the extension is ``cu`` for cuda and ``cpp.hip`` for hip.
    Each file includes ``../special_kerns.inl``, expands ``INST`` for
    ``::megdnn::dtype::<DTYPES[dtype][0]>``, undefines it and emits two
    closing braces. Files for ``dt_float16`` / ``dt_bfloat16`` are wrapped
    in ``#if !MEGDNN_DISABLE_FLOAT16``. The mtime of ``output`` is touched
    last.
    """
    cli = argparse.ArgumentParser(
        description='generate elemwise impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument(
        '--type', type=str, choices=['cuda', 'hip'], default='cuda',
        help='generate cuda/hip elemwise special kernel file')
    cli.add_argument('output', help='output directory')
    opts = cli.parse_args()
    out_dir = opts.output

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    is_cuda = opts.type == 'cuda'
    assert is_cuda or opts.type == 'hip'
    ext = 'cu' if is_cuda else 'cpp.hip'

    half_types = ('dt_float16', 'dt_bfloat16')
    for dtype_name in DTYPES:
        target = os.path.join(out_dir, 'special_{}.{}'.format(dtype_name, ext))
        needs_guard = dtype_name in half_types
        with open(target, 'w') as sink:
            body = ['// generated by gen_elemwise_special_kern_impls.py']
            if needs_guard:
                body.append('#if !MEGDNN_DISABLE_FLOAT16')
            body.extend([
                '#include "../special_kerns.inl"',
                'INST(::megdnn::dtype::{})'.format(DTYPES[dtype_name][0]),
                '#undef INST',
                # Two closing braces: presumably they close scopes opened
                # inside ../special_kerns.inl — verify against that file.
                '}',
                '}',
            ])
            if needs_guard:
                body.append('#endif')
            for text in body:
                print(text, file=sink)
            print('generated {}'.format(target))

    os.utime(out_dir)