def test_tvm_integration(model_name, batch_size, seq_length, layout, ctx):
    """Compile a backbone model with TVM and check the TVM runtime outputs
    against the MXNet forward pass.

    Parameters
    ----------
    model_name : str
        Backbone name resolved through ``get_backbone``.
    batch_size : int
        Batch dimension of the random test inputs.
    seq_length : int
        Sequence dimension of the random test inputs.
    layout : str
        'NT' (batch-major) or anything else for time-major ('TN') inputs.
    ctx : mx.Context
        MXNet context; its device type selects the TVM compilation flags.

    Raises
    ------
    NotImplementedError
        If ``ctx.device_type`` is neither 'gpu' nor 'cpu'.
    """
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    tvm_recommended_flags = get_ec2_tvm_flags()
    if ctx.device_type == 'gpu':
        flags = tvm_recommended_flags['g4']
    elif ctx.device_type == 'cpu':
        flags = tvm_recommended_flags['c4']
        if model_name != 'google_albert_base_v2':
            # Skip all other tests on CPU to keep test runtime manageable.
            return
    else:
        raise NotImplementedError
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(
            model_name, root=root)
        cfg.defrost()
        cfg.MODEL.layout = layout
        cfg.freeze()
        model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path)
        model.hybridize()
        # Random token ids / segment ids / valid lengths in the requested layout.
        # Only the token tensor layout differs between 'NT' and 'TN';
        # valid_length is always (batch_size,).
        if layout == 'NT':
            token_shape = (batch_size, seq_length)
        else:
            token_shape = (seq_length, batch_size)
        token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, token_shape,
                                         dtype=np.int32)
        token_types = mx.np.random.randint(0, 2, token_shape, dtype=np.int32)
        valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                            (batch_size,), dtype=np.int32)
        # Each model family takes a different argument tuple; build it once and
        # derive the relay shape/dtype dicts and the runtime inputs from it so
        # the three places that enumerate inputs can never disagree.
        # (NOTE: previously 'xlmr' models matched the roberta branch when
        # building shape_dict but fell into the generic branch in set_input.)
        if 'bart' in model_name:
            inputs = [token_ids, valid_length, token_ids, valid_length]
        elif 'roberta' in model_name or 'xlmr' in model_name:
            inputs = [token_ids, valid_length]
        else:
            inputs = [token_ids, token_types, valid_length]
        mx_out = model(*inputs)
        shape_dict = {'data{}'.format(i): arr.shape
                      for i, arr in enumerate(inputs)}
        dtype_dict = {'data{}'.format(i): arr.dtype.name
                      for i, arr in enumerate(inputs)}
        # Export the cached symbolic graph + parameters into relay.
        sym = model._cached_graph[1]
        params = {}
        for k, v in model.collect_params().items():
            params[v._var_name] = tvm.nd.array(v.data().asnumpy())
        mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict,
                                                dtype=dtype_dict,
                                                arg_params=params)
        target = flags['target']
        use_gpu = flags['use_gpu']
        opt_level = flags['opt_level']
        required_pass = flags['required_pass']
        with tvm.transform.PassContext(opt_level=opt_level,
                                       required_pass=required_pass):
            lib = relay.build(mod, target, params=params)
        if use_gpu:
            ctx = tvm.gpu()
        else:
            ctx = tvm.cpu()
        rt = graph_runtime.GraphModule(lib["default"](ctx))
        rt.set_input(**{'data{}'.format(i): arr for i, arr in enumerate(inputs)})
        rt.run()
        # Compare every TVM output against the MXNet ground truth; a single
        # output comes back as a bare ndarray rather than a tuple.
        for i in range(rt.get_num_outputs()):
            out = rt.get_output(i)
            if rt.get_num_outputs() == 1:
                mx_out_gt = mx_out.asnumpy()
            else:
                mx_out_gt = mx_out[i].asnumpy()
            npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=1e-3, atol=1e-1)
def compile_tvm_graph_runtime(model, model_name, layout, compute_layout,
                              batch_size, seq_length, dtype, instance_type):
    """Compile *model* into a TVM graph runtime, memoizing the result.

    The compiled runtime is cached in ``_TVM_RT_CACHE`` keyed on every
    argument except ``model`` itself, so repeated calls with an identical
    configuration reuse the previously built runtime instead of recompiling.

    Parameters
    ----------
    model
        A hybridized MXNet backbone whose cached symbolic graph is exported.
    model_name : str
        Used both as part of the cache key and to pick the input signature
        (bart / roberta-xlmr / generic three-input).
    layout : str
        'NT' for (batch, seq) token tensors, otherwise (seq, batch).
    compute_layout, dtype
        Included in the cache key only.
    batch_size, seq_length : int
        Static input dimensions baked into the compiled graph.
    instance_type : str
        Key into ``get_ec2_tvm_flags()`` selecting target/opt flags.

    Returns
    -------
    The TVM ``GraphModule`` for this configuration.
    """
    cache_key = (model_name, layout, compute_layout, batch_size, seq_length,
                 dtype, instance_type)
    cached = _TVM_RT_CACHE.get(cache_key)
    if cached is not None:
        return cached
    flags = get_ec2_tvm_flags()[instance_type]
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    if layout == 'NT':
        token_ids_shape = (batch_size, seq_length)
    else:
        token_ids_shape = (seq_length, batch_size)
    valid_length_shape = (batch_size,)
    # Each model family exposes a different ordered input signature; list the
    # shapes once and derive the relay shape/dtype dicts from the list.
    if 'bart' in model_name:
        input_shapes = [token_ids_shape, valid_length_shape,
                        token_ids_shape, valid_length_shape]
    elif 'roberta' in model_name or 'xlmr' in model_name:
        input_shapes = [token_ids_shape, valid_length_shape]
    else:
        input_shapes = [token_ids_shape, token_ids_shape, valid_length_shape]
    shape_dict = {'data{}'.format(i): shp for i, shp in enumerate(input_shapes)}
    dtype_dict = {'data{}'.format(i): 'int32' for i in range(len(input_shapes))}
    # Export the cached symbolic graph and parameters into relay.
    sym = model._cached_graph[1]
    params = {v._var_name: tvm.nd.array(v.data().asnumpy())
              for v in model.collect_params().values()}
    mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict,
                                            dtype=dtype_dict,
                                            arg_params=params)
    with tvm.transform.PassContext(opt_level=flags['opt_level'],
                                   required_pass=flags['required_pass']):
        lib = relay.build(mod, flags['target'], params=params)
    ctx = tvm.gpu() if flags['use_gpu'] else tvm.cpu()
    rt = graph_runtime.GraphModule(lib["default"](ctx))
    _TVM_RT_CACHE[cache_key] = rt
    return rt