def test_tvm_integration(model_name, batch_size, seq_length, layout, ctx):
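    """Compile a GluonNLP backbone with TVM and check its outputs against the MXNet reference."""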
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    tvm_recommended_flags = get_ec2_tvm_flags()
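    # Pick the recommended TVM build flags for the matching EC2 instance family ('g4' for GPU, 'c4' for CPU).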
    if ctx.device_type == 'gpu':
        flags = tvm_recommended_flags['g4']
    elif ctx.device_type == 'cpu':
        flags = tvm_recommended_flags['c4']
        if model_name != 'google_albert_base_v2':
            # On the CPU target we only exercise google_albert_base_v2; skip all other models.
            return
    else:
        raise NotImplementedError
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(
            model_name, root=root)
        cfg.defrost()
        cfg.MODEL.layout = layout
        cfg.freeze()
        model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path)
        model.hybridize()
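        # Hybridize so the first forward pass records a symbolic graph (read later via model._cached_graph).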
        if layout == 'NT':
            token_ids = mx.np.random.randint(0,
                                             cfg.MODEL.vocab_size,
                                             (batch_size, seq_length),
                                             dtype=np.int32)
            token_types = mx.np.random.randint(0,
                                               2, (batch_size, seq_length),
                                               dtype=np.int32)
            valid_length = mx.np.random.randint(seq_length // 2,
                                                seq_length, (batch_size, ),
                                                dtype=np.int32)
        else:
            token_ids = mx.np.random.randint(0,
                                             cfg.MODEL.vocab_size,
                                             (seq_length, batch_size),
                                             dtype=np.int32)
            token_types = mx.np.random.randint(0,
                                               2, (seq_length, batch_size),
                                               dtype=np.int32)
            valid_length = mx.np.random.randint(seq_length // 2,
                                                seq_length, (batch_size, ),
                                                dtype=np.int32)
        if 'bart' in model_name:
            mx_out = model(token_ids, valid_length, token_ids, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': valid_length.shape,
                'data2': token_ids.shape,
                'data3': valid_length.shape,
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': valid_length.dtype.name,
                'data2': token_ids.dtype.name,
                'data3': valid_length.dtype.name,
            }
        elif 'roberta' in model_name or 'xlmr' in model_name:
            mx_out = model(token_ids, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': valid_length.shape,
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': valid_length.dtype.name,
            }
        else:
            mx_out = model(token_ids, token_types, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': token_types.shape,
                'data2': valid_length.shape
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': token_types.dtype.name,
                'data2': valid_length.dtype.name
            }
        sym = model._cached_graph[1]
        params = {}
        for v in model.collect_params().values():
            params[v._var_name] = tvm.nd.array(v.data().asnumpy())
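        # Convert the cached MXNet symbol and its parameters into a Relay module.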
        mod, params = relay.frontend.from_mxnet(sym,
                                                shape=shape_dict,
                                                dtype=dtype_dict,
                                                arg_params=params)
        target = flags['target']
        use_gpu = flags['use_gpu']
        opt_level = flags['opt_level']
        required_pass = flags['required_pass']
        with tvm.transform.PassContext(opt_level=opt_level,
                                       required_pass=required_pass):
            lib = relay.build(mod, target, params=params)
        if use_gpu:
            ctx = tvm.gpu()
        else:
            ctx = tvm.cpu()
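        # Wrap the compiled library in a graph runtime module on the target device.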
        rt = graph_runtime.GraphModule(lib["default"](ctx))
        if 'bart' in model_name:
            rt.set_input(data0=token_ids,
                         data1=valid_length,
                         data2=token_ids,
                         data3=valid_length)
        elif 'roberta' in model_name or 'xlmr' in model_name:
            rt.set_input(data0=token_ids, data1=valid_length)
        else:
            rt.set_input(data0=token_ids,
                         data1=token_types,
                         data2=valid_length)
        rt.run()
        for i in range(rt.get_num_outputs()):
            out = rt.get_output(i)
            if rt.get_num_outputs() == 1:
                mx_out_gt = mx_out.asnumpy()
            else:
                mx_out_gt = mx_out[i].asnumpy()
            npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=1e-3, atol=1e-1)
Example #2
    def _inference_speed_memory(self, model_name: str, batch_size: int, sequence_length: int)\
            -> Tuple[float, Memory]:
        if self._use_fp16:
            os.environ['MXNET_FC_TRUE_FP16'] = '1'
            dtype = 'float16'
        else:
            dtype = 'float32'
        if self._use_gpu:
            ctx = mxnet.gpu()
        else:
            ctx = mxnet.cpu()
        model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
        cfg.defrost()
        cfg.MODEL.layout = self._layout
        if model_cls.__name__ not in ['BartModel']:
            cfg.MODEL.compute_layout = self._compute_layout
        cfg.freeze()
        if model_cls.__name__ in ['BartModel']:
            model = model_cls.from_cfg(cfg, extract_feature=True, dtype=dtype)
        else:
            model = model_cls.from_cfg(cfg, dtype=dtype)
        model.load_parameters(backbone_param_path, ctx=ctx, cast_dtype=True)
        model.cast(dtype)
        model.hybridize(static_alloc=True, static_shape=True)
        vocab_size = cfg.MODEL.vocab_size
        if self._layout == 'NT':
            input_ids = mxnet.np.random.randint(0, vocab_size, (batch_size, sequence_length),
                                                dtype=np.int32, ctx=ctx)
            token_types = mxnet.np.zeros((batch_size, sequence_length), dtype=np.int32, ctx=ctx)
            valid_length = mxnet.np.full((batch_size,), sequence_length,
                                         dtype=np.int32, ctx=ctx)
        elif self._layout == 'TN':
            input_ids = mxnet.np.random.randint(0, vocab_size, (sequence_length, batch_size),
                                                dtype=np.int32, ctx=ctx)
            token_types = mxnet.np.zeros((sequence_length, batch_size), dtype=np.int32, ctx=ctx)
            valid_length = mxnet.np.full((batch_size,), sequence_length,
                                         dtype=np.int32, ctx=ctx)
        else:
            raise NotImplementedError
        mxnet.npx.waitall()

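        # One MXNet forward pass; the call signature depends on the model family.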
        def run_forward():
            if 'roberta' in model_name or 'xlmr' in model_name:
                out = model(input_ids, valid_length)
            elif 'bart' in model_name:
                out = model(input_ids, valid_length, input_ids, valid_length)
            else:
                out = model(input_ids, token_types, valid_length)
            if isinstance(out, (list, tuple)):
                for ele in out:
                    ele.wait_to_read()
            else:
                out.wait_to_read()

        if self._use_tvm:
            tvm = try_import_tvm()
            run_forward()
            if self._use_gpu:
                ctx = tvm.gpu()
            else:
                ctx = tvm.cpu()
            rt = compile_tvm_graph_runtime(model=model, model_name=model_name,
                                           layout=self._layout, compute_layout=self._compute_layout,
                                           batch_size=batch_size, seq_length=sequence_length,
                                           instance_type=self._instance_type,
                                           dtype='float32' if not self._use_fp16 else 'float16')
            tvm_input_ids = tvm.nd.array(input_ids.asnumpy(), ctx=ctx)
            tvm_token_types = tvm.nd.array(token_types.asnumpy(), ctx=ctx)
            tvm_valid_length = tvm.nd.array(valid_length.asnumpy(), ctx=ctx)

            def run_tvm_forward():
                if 'roberta' in model_name or 'xlmr' in model_name:
                    rt.set_input(data0=tvm_input_ids, data1=tvm_valid_length)
                elif 'bart' in model_name:
                    rt.set_input(data0=tvm_input_ids, data1=tvm_valid_length,
                                 data2=tvm_input_ids, data3=tvm_valid_length)
                else:
                    rt.set_input(data0=tvm_input_ids, data1=tvm_token_types,
                                 data2=tvm_valid_length)
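                # Run the compiled graph and read back every output.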
                rt.run()
                for i in range(rt.get_num_outputs()):
                    out = rt.get_output(i)
            # Warmup
            timeit.repeat(run_tvm_forward, repeat=1, number=2)
            runtimes = timeit.repeat(run_tvm_forward, repeat=self._repeat, number=3)
        else:
            timeit.repeat(run_forward, repeat=1, number=3)
            runtimes = timeit.repeat(run_forward, repeat=self._repeat, number=3)
            mxnet.npx.waitall()
        # Profile memory
        if self._use_gpu:
            nvml.nvmlInit()
            run_forward()
            mxnet.npx.waitall()
            handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
            meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
            max_bytes_in_use = meminfo.used
            memory = Memory(max_bytes_in_use)
            # shutdown nvml
            nvml.nvmlShutdown()
        else:
            # cpu
            memory_bytes = measure_peak_memory_cpu(run_forward)
            memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
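        # Each timeit measurement executed the forward pass three times (number=3), so report per-run latency.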
        return float(np.min(runtimes) / 3.0), memory
def tvm_enabled():
    try:
        try_import_tvm()
        return True
    except ImportError:
        return False
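
A typical use of tvm_enabled() is as a test guard. Below is a minimal sketch with pytest; the marker name and test body are illustrative assumptions, not part of the snippet above.

import pytest

# Skip TVM-dependent tests when TVM cannot be imported in the current environment.
requires_tvm = pytest.mark.skipif(not tvm_enabled(), reason='TVM is not installed')

@requires_tvm
def test_tvm_roundtrip():
    ...
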
Example #4
def compile_tvm_graph_runtime(model, model_name, layout, compute_layout,
                              batch_size, seq_length, dtype, instance_type):
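    # Compiled runtimes are cached per configuration so repeated benchmark runs avoid recompiling.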
    key = (model_name, layout, compute_layout, batch_size, seq_length, dtype, instance_type)
    if key in _TVM_RT_CACHE:
        return _TVM_RT_CACHE[key]
    flags = get_ec2_tvm_flags()[instance_type]
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    token_ids_shape = (batch_size, seq_length) if layout == 'NT' else (seq_length, batch_size)
    valid_length_shape = (batch_size,)
    if 'bart' in model_name:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': valid_length_shape,
            'data2': token_ids_shape,
            'data3': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
            'data2': 'int32',
            'data3': 'int32',
        }
    elif 'roberta' in model_name or 'xlmr' in model_name:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
        }
    else:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': token_ids_shape,
            'data2': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
            'data2': 'int32'
        }
    sym = model._cached_graph[1]
    params = {}
    for v in model.collect_params().values():
        params[v._var_name] = tvm.nd.array(v.data().asnumpy())
    mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
    target = flags['target']
    use_gpu = flags['use_gpu']
    opt_level = flags['opt_level']
    required_pass = flags['required_pass']
    with tvm.transform.PassContext(opt_level=opt_level, required_pass=required_pass):
        lib = relay.build(mod, target, params=params)
    if use_gpu:
        ctx = tvm.gpu()
    else:
        ctx = tvm.cpu()
    rt = graph_runtime.GraphModule(lib["default"](ctx))
    _TVM_RT_CACHE[key] = rt
    return rt
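
For reference, a minimal sketch of driving the runtime returned above, assuming `model` is a backbone that has already been hybridized and run once (as in the functions earlier in this file); the model name, vocabulary size, shapes, and instance type below are placeholder assumptions.

import numpy as np

rt = compile_tvm_graph_runtime(model=model, model_name='google_albert_base_v2',
                               layout='NT', compute_layout='NT',
                               batch_size=1, seq_length=128,
                               dtype='float32', instance_type='c4')
token_ids = np.random.randint(0, 30000, (1, 128), dtype=np.int32)   # (batch, seq) for 'NT'
token_types = np.zeros((1, 128), dtype=np.int32)
valid_length = np.full((1,), 128, dtype=np.int32)
# Input names follow the data0/data1/... convention produced by relay.frontend.from_mxnet.
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
outputs = [rt.get_output(i).asnumpy() for i in range(rt.get_num_outputs())]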