Example #1
import tempfile

import mxnet as mx
import numpy as np
import numpy.testing as npt

# GluonNLP helpers used below; the module paths assume the numpy-based
# GluonNLP 1.x layout.
from gluonnlp.models import get_backbone
from gluonnlp.utils.lazy_imports import try_import_tvm
from gluonnlp.utils.misc import get_ec2_tvm_flags


def test_tvm_integration(model_name, batch_size, seq_length, layout, ctx):
    """Compile a GluonNLP backbone with TVM and check that the TVM runtime
    reproduces the MXNet forward pass."""
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    tvm_recommended_flags = get_ec2_tvm_flags()
    if ctx.device_type == 'gpu':
        flags = tvm_recommended_flags['g4']
    elif ctx.device_type == 'cpu':
        flags = tvm_recommended_flags['c4']
        if model_name != 'google_albert_base_v2':
            # On CPU, only the ALBERT backbone is exercised; all other
            # models are skipped.
            return
    else:
        raise NotImplementedError
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(
            model_name, root=root)
        cfg.defrost()
        cfg.MODEL.layout = layout
        cfg.freeze()
        model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path)
        model.hybridize()
        # Draw random inputs in the requested layout:
        # 'NT' = (batch, sequence), 'TN' = (sequence, batch).
        if layout == 'NT':
            token_shape = (batch_size, seq_length)
        else:
            token_shape = (seq_length, batch_size)
        token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, token_shape,
                                         dtype=np.int32)
        token_types = mx.np.random.randint(0, 2, token_shape, dtype=np.int32)
        valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                            (batch_size,), dtype=np.int32)
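        # Run the MXNet forward pass and record the input shape/dtype maps
        # for relay.frontend.from_mxnet; the traced graph's flattened inputs
        # are named data0, data1, ...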
        if 'bart' in model_name:
            mx_out = model(token_ids, valid_length, token_ids, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': valid_length.shape,
                'data2': token_ids.shape,
                'data3': valid_length.shape,
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': valid_length.dtype.name,
                'data2': token_ids.dtype.name,
                'data3': valid_length.dtype.name,
            }
        elif 'roberta' in model_name or 'xlmr' in model_name:
            mx_out = model(token_ids, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': valid_length.shape,
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': valid_length.dtype.name,
            }
        else:
            mx_out = model(token_ids, token_types, valid_length)
            shape_dict = {
                'data0': token_ids.shape,
                'data1': token_types.shape,
                'data2': valid_length.shape
            }
            dtype_dict = {
                'data0': token_ids.dtype.name,
                'data1': token_types.dtype.name,
                'data2': valid_length.dtype.name
            }
        # Extract the traced symbol and parameters for the Relay frontend.
        sym = model._cached_graph[1]
        params = {}
        for param in model.collect_params().values():
            params[param._var_name] = tvm.nd.array(param.data().asnumpy())
        mod, params = relay.frontend.from_mxnet(sym,
                                                shape=shape_dict,
                                                dtype=dtype_dict,
                                                arg_params=params)
        target = flags['target']
        use_gpu = flags['use_gpu']
        opt_level = flags['opt_level']
        required_pass = flags['required_pass']
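        # Compile the Relay module into a deployable library with the
        # instance-recommended optimization level and passes.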
        with tvm.transform.PassContext(opt_level=opt_level,
                                       required_pass=required_pass):
            lib = relay.build(mod, target, params=params)
        if use_gpu:
            tvm_ctx = tvm.gpu()
        else:
            tvm_ctx = tvm.cpu()
        rt = graph_runtime.GraphModule(lib["default"](tvm_ctx))
        if 'bart' in model_name:
            rt.set_input(data0=token_ids,
                         data1=valid_length,
                         data2=token_ids,
                         data3=valid_length)
        elif 'roberta' in model_name or 'xlmr' in model_name:
            rt.set_input(data0=token_ids, data1=valid_length)
        else:
            rt.set_input(data0=token_ids,
                         data1=token_types,
                         data2=valid_length)
        rt.run()
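        # Compare every TVM output against the corresponding MXNet output.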
        for i in range(rt.get_num_outputs()):
            out = rt.get_output(i)
            if rt.get_num_outputs() == 1:
                mx_out_gt = mx_out.asnumpy()
            else:
                mx_out_gt = mx_out[i].asnumpy()
            npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=1e-3, atol=1e-1)
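Both examples index the dictionary returned by get_ec2_tvm_flags() by EC2
instance type ('g4', 'c4', ...) and read the same four keys from each entry.
A minimal sketch of that shape, for orientation only; the concrete target
strings below are illustrative assumptions, not the library's actual values:

# Hypothetical sketch of the mapping returned by get_ec2_tvm_flags(),
# inferred from how the examples consume it; the target strings are assumed.
EXAMPLE_TVM_FLAGS = {
    'g4': {'target': 'cuda -libs=cublas,cudnn',  # GPU instance (assumed)
           'use_gpu': True,
           'opt_level': 3,
           'required_pass': ['FastMath']},
    'c4': {'target': 'llvm -mcpu=core-avx2',     # CPU instance (assumed)
           'use_gpu': False,
           'opt_level': 3,
           'required_pass': ['FastMath']},
}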
Example #2
# Helpers from GluonNLP; the module paths assume the numpy-based GluonNLP
# 1.x layout.
from gluonnlp.utils.lazy_imports import try_import_tvm
from gluonnlp.utils.misc import get_ec2_tvm_flags

# Module-level cache so repeated calls with the same configuration reuse the
# already-compiled runtime.
_TVM_RT_CACHE = dict()


def compile_tvm_graph_runtime(model, model_name, layout, compute_layout,
                              batch_size, seq_length, dtype, instance_type):
    """Compile a hybridized GluonNLP backbone into a cached TVM graph runtime.

    The model must have been hybridized and run at least once so that
    ``model._cached_graph`` is populated.
    """
    key = (model_name, layout, compute_layout, batch_size, seq_length,
           dtype, instance_type)
    if key in _TVM_RT_CACHE:
        return _TVM_RT_CACHE[key]
    flags = get_ec2_tvm_flags()[instance_type]
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    token_ids_shape = (batch_size, seq_length) if layout == 'NT' else (seq_length, batch_size)
    valid_length_shape = (batch_size,)
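    # Input shape/dtype maps for relay.frontend.from_mxnet; BART takes encoder
    # and decoder token ids plus valid lengths, while RoBERTa/XLM-R models
    # take no token types.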
    if 'bart' in model_name:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': valid_length_shape,
            'data2': token_ids_shape,
            'data3': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
            'data2': 'int32',
            'data3': 'int32',
        }
    elif 'roberta' in model_name or 'xlmr' in model_name:
        shape_dict = {
            'data0': token_ids_shape,
            'data1': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
        }
    else:
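        # Remaining backbones take (token_ids, token_types, valid_length);
        # token_types has the same shape as token_ids.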
        shape_dict = {
            'data0': token_ids_shape,
            'data1': token_ids_shape,
            'data2': valid_length_shape,
        }
        dtype_dict = {
            'data0': 'int32',
            'data1': 'int32',
            'data2': 'int32'
        }
    # Extract the traced symbol and parameters for the Relay frontend.
    sym = model._cached_graph[1]
    params = {}
    for param in model.collect_params().values():
        params[param._var_name] = tvm.nd.array(param.data().asnumpy())
    mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
    target = flags['target']
    use_gpu = flags['use_gpu']
    opt_level = flags['opt_level']
    required_pass = flags['required_pass']
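    # Compile the Relay module into a deployable library with the
    # instance-recommended optimization level and passes.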
    with tvm.transform.PassContext(opt_level=opt_level, required_pass=required_pass):
        lib = relay.build(mod, target, params=params)
    if use_gpu:
        ctx = tvm.gpu()
    else:
        ctx = tvm.cpu()
    rt = graph_runtime.GraphModule(lib["default"](ctx))
    _TVM_RT_CACHE[key] = rt
    return rt
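A minimal usage sketch for compile_tvm_graph_runtime, with hypothetical model
and instance choices ('google_albert_base_v2' on a 'c4' CPU instance): the
backbone has to be hybridized and run once first so that model._cached_graph
is populated before compilation.

import mxnet as mx
import numpy as np
from gluonnlp.models import get_backbone

batch_size, seq_length = 1, 128
model_cls, cfg, tokenizer, param_path, _ = get_backbone('google_albert_base_v2')
model = model_cls.from_cfg(cfg)
model.load_parameters(param_path)
model.hybridize()

# Trace the graph once so that model._cached_graph exists.
token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size,
                                 (batch_size, seq_length), dtype=np.int32)
token_types = mx.np.zeros((batch_size, seq_length), dtype=np.int32)
valid_length = mx.np.full((batch_size,), seq_length, dtype=np.int32)
model(token_ids, token_types, valid_length)

# compute_layout and dtype only distinguish cache entries in this function.
rt = compile_tvm_graph_runtime(model, 'google_albert_base_v2',
                               layout='NT', compute_layout='auto',
                               batch_size=batch_size, seq_length=seq_length,
                               dtype='float32', instance_type='c4')
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
tvm_out = rt.get_output(0)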