Esempio n. 1
0
def test_get_backbone(name, ctx):
    """Load the pretrained backbone ``name``, sanity-check its parameter
    count, then run one forward pass on a tiny random batch and export it.

    XLM-R backbones return right before the forward pass (see comment below).
    """
    with tempfile.TemporaryDirectory() as tmp_dir, ctx:
        backbone_cls, backbone_cfg, _tokenizer, params_file, _ = get_backbone(
            name, root=tmp_dir)
        backbone = backbone_cls.from_cfg(backbone_cfg)
        backbone.load_parameters(params_file)
        backbone.hybridize()
        total_params, _fixed_params = count_parameters(backbone.collect_params())
        assert total_params > 0

        # Model export + save check on a 1 x 4 random batch.
        n_batch, seq_len = 1, 4
        token_ids = mx.np.random.randint(0, 10, (n_batch, seq_len))
        segment_ids = mx.np.random.randint(0, 2, (n_batch, seq_len))
        lengths = mx.np.random.randint(1, seq_len, (n_batch, ))
        if 'xlmr' in name and 'roberta' not in name:
            # Skip for XLMR tests. It takes too much CPU memory.
            return
        if 'roberta' in name:
            backbone(token_ids, lengths)
        elif 'bart' in name:
            backbone(token_ids, lengths, token_ids, lengths)
        else:
            backbone(token_ids, segment_ids, lengths)
        mx.npx.waitall()
        backbone.export(os.path.join(tmp_dir, 'model'))
Esempio n. 2
0
def test_get_backbone(name, ctx):
    """Load the pretrained backbone ``name``, run one forward pass and export it.

    For GPT-2 backbones only loading and the parameter-count check run; the
    forward/export part is skipped via ``pytest.skip``.
    """
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, _, local_params_path, _ = get_backbone(
            name, root=root)
        net = model_cls.from_cfg(cfg)
        net.load_parameters(local_params_path)
        net.hybridize()
        num_params, _ = count_parameters(net.collect_params())
        assert num_params > 0

        # Test for model export + save
        if 'gpt2' in name:
            # pytest.skip raises, so no GPT-2 forward branch is needed below
            # (a GPT-2 branch after this point would be unreachable).
            pytest.skip('Skipping GPT-2 test')
        batch_size = 1
        sequence_length = 4
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(1, sequence_length, (batch_size, ))
        if 'roberta' in name or 'xlmr' in name:
            # These backbones are called without token-type ids.
            net(inputs, valid_length)
        elif 'bart' in name:
            # Encoder-decoder: the same tokens feed both sides.
            net(inputs, valid_length, inputs, valid_length)
        else:
            net(inputs, token_types, valid_length)
        mx.npx.waitall()
        net.export(os.path.join(root, 'model'))
Esempio n. 3
0
def get_network(model_name,
                ctx_l,
                dropout=0.1,
                checkpoint_path=None,
                backbone_path=None,
                dtype='float32'):
    """Get the network that fine-tunes a backbone for Question Answering.

    Parameters
    ----------
    model_name : str
        The model name of the backbone model.
    ctx_l : list
        Context list of training devices, like [mx.gpu(0), mx.gpu(1)].
    dropout : float
        Dropout probability of the task-specific layer.
    checkpoint_path : str or None
        Path to a fine-tuned checkpoint. When given, the whole QA network is
        loaded from it and the backbone parameters are not loaded separately.
    backbone_path : str or None
        Path to local backbone parameters to load into the QA network. When
        omitted, the parameters obtained through ``get_backbone`` are used.
    dtype : str
        Data type passed to the backbone's ``from_cfg``; loaded parameters are
        cast to their stored dtype with ``cast_dtype=True``.

    Returns
    -------
    cfg
        The backbone configuration.
    tokenizer
        The tokenizer returned by ``get_backbone``.
    qa_net
        The hybridized ``ModelForQAConditionalV1``.
    use_segmentation : bool
        False when ``model_name`` contains 'roberta' or 'xlmr', True otherwise.
    """
    # Create the network
    use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
    Model, cfg, tokenizer, download_params_path, _ = \
        get_backbone(model_name, load_backbone=not backbone_path)
    backbone = Model.from_cfg(cfg, use_pooler=False, dtype=dtype)
    # Load local backbone parameters if backbone_path provided.
    # Otherwise, download backbone parameters from gluon zoo.
    backbone_params_path = backbone_path if backbone_path else download_params_path
    if checkpoint_path is None:
        backbone.load_parameters(backbone_params_path, ignore_extra=True,
                                 ctx=ctx_l, cast_dtype=True)
        num_params, num_fixed_params = count_parameters(backbone.collect_params())
        logging.info(
            'Loading Backbone Model from %s, with total/fixed parameters=%s/%s',
            backbone_params_path, num_params, num_fixed_params)
    qa_net = ModelForQAConditionalV1(backbone=backbone,
                                     dropout_prob=dropout,
                                     use_segmentation=use_segmentation,
                                     weight_initializer=TruncNorm(stdev=0.02))
    if checkpoint_path is None:
        # Ignore the UserWarning during initialization,
        # There is no need to re-initialize the parameters of backbone
        qa_net.initialize(ctx=ctx_l)
    else:
        qa_net.load_parameters(checkpoint_path, ctx=ctx_l, cast_dtype=True)
    qa_net.hybridize()

    return cfg, tokenizer, qa_net, use_segmentation
Esempio n. 4
0
def main(args):
    """Preprocess the text files in ``args.input`` into pretraining features.

    The (optionally shuffled) input files are split into at most
    ``args.num_out_files`` groups. Each group is handed, together with its
    ``owt-pretrain-record-XXXX.npz`` output path under ``args.output``, to
    ``get_all_features`` in a worker process. Progress and an ETA are printed
    as groups finish.
    """
    num_process = min(multiprocessing.cpu_count(), args.num_process)
    _, _, tokenizer, _, _ = \
        get_backbone(args.model_name, load_backbone=False)

    fnames = [os.path.join(args.input, fname)
              for fname in sorted(os.listdir(args.input))]
    if args.shuffle:
        random.shuffle(fnames)
    num_files = len(fnames)
    if num_files == 0:
        # np.array_split and multiprocessing.Pool both reject a count of 0,
        # so stop here with a clear message instead of an obscure ValueError.
        print('No input files found in {}, nothing to do.'.format(args.input))
        return
    num_out_files = min(args.num_out_files, num_files)
    split_files = np.array_split(fnames, num_out_files)
    output_files = [
        os.path.join(args.output,
                     "owt-pretrain-record-{}.npz".format(str(i).zfill(4)))
        for i in range(num_out_files)
    ]
    print("All preprocessed features will be saved in {} npz files".format(
        num_out_files))
    os.makedirs(args.output, exist_ok=True)
    num_process = min(num_process, num_out_files)
    print('Start preprocessing {} text files with {} cores'.format(
        num_files, num_process))
    process_args = [(split_files[i], output_files[i], tokenizer,
                     args.max_seq_length, args.short_seq_prob)
                    for i in range(num_out_files)]
    start_time = time.time()
    with multiprocessing.Pool(num_process) as pool:
        results = pool.imap(get_all_features, process_args)
        fea_written = 0
        f_read = 0
        # imap yields in submission order, so index i matches split_files[i].
        for i, np_features in enumerate(results):
            elapsed = time.time() - start_time
            fea_written += len(np_features[0])
            f_read += len(split_files[i])
            print(
                "Processed {} files ({} features), Elapsed: {:.2f}s, ETA: {:.2f}s, ".format(
                    f_read, fea_written, elapsed,
                    (num_files - f_read) / (f_read / elapsed)))
    print("Done processing within {:.2f} seconds".format(time.time() - start_time))
Esempio n. 5
0
def get_network(model_name,
                ctx_l,
                checkpoint_path=None,
                backbone_path=None,
                task=None):
    """Get the network that fine-tunes a backbone for a text prediction task.

    Parameters
    ----------
    model_name : str
        The model name of the backbone model.
    ctx_l : list
        Context list of training devices.
    checkpoint_path : str or None
        Path to a fine-tuned checkpoint. When given, the whole network is
        loaded from it and the backbone parameters are not loaded separately.
    backbone_path : str or None
        Path to local backbone parameters. When omitted, the parameters
        obtained through ``get_backbone`` are used.
    task
        Task description; its ``class_num`` sets the output size of the
        ``TextPredictionNet``. Required despite the ``None`` default.

    Returns
    -------
    cfg, tokenizer, classify_net, use_segmentation
        ``use_segmentation`` is False when ``model_name`` contains 'roberta'
        or 'xlmr', True otherwise.

    Raises
    ------
    ValueError
        If ``task`` is None.
    """
    if task is None:
        # Fail before fetching any backbone files; task.class_num is needed below.
        raise ValueError('`task` must be provided; its `class_num` sets the output size.')

    use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
    Model, cfg, tokenizer, download_params_path, _ = \
        get_backbone(model_name, load_backbone=not backbone_path)
    backbone = Model.from_cfg(cfg)
    # Load local backbone parameters if backbone_path provided.
    # Otherwise, download backbone parameters from gluon zoo.
    backbone_params_path = backbone_path if backbone_path else download_params_path
    if checkpoint_path is None:
        backbone.load_parameters(backbone_params_path,
                                 ignore_extra=True,
                                 ctx=ctx_l,
                                 cast_dtype=True)
        num_params, num_fixed_params \
            = count_parameters(deduplicate_param_dict(backbone.collect_params()))
        logging.info(
            'Loading Backbone Model from %s, with total/fixed parameters=%s/%s',
            backbone_params_path, num_params, num_fixed_params)
    classify_net = TextPredictionNet(backbone, task.class_num)
    if checkpoint_path is None:
        # Ignore the UserWarning during initialization,
        # There is no need to re-initialize the parameters of backbone
        classify_net.initialize(ctx=ctx_l)
    else:
        classify_net.load_parameters(checkpoint_path,
                                     ctx=ctx_l,
                                     cast_dtype=True)
    classify_net.hybridize()

    return cfg, tokenizer, classify_net, use_segmentation
Esempio n. 6
0
def test_tvm_integration(model_name, batch_size, seq_length, layout, ctx):
    """Check that a backbone compiled with TVM reproduces the MXNet output.

    The hybridized MXNet graph is converted with ``relay.frontend.from_mxnet``,
    built with the EC2-recommended flags for the device type, executed with
    the TVM graph runtime, and each output is compared with the MXNet result.
    On CPU only ``google_albert_base_v2`` is exercised; other models return.
    """
    tvm = try_import_tvm()
    from tvm import relay
    from tvm.contrib import graph_runtime
    tvm_recommended_flags = get_ec2_tvm_flags()
    if ctx.device_type == 'gpu':
        flags = tvm_recommended_flags['g4']
    elif ctx.device_type == 'cpu':
        flags = tvm_recommended_flags['c4']
        if model_name != 'google_albert_base_v2':
            # Skip all other tests
            return
    else:
        raise NotImplementedError
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, _, backbone_param_path, _ = get_backbone(
            model_name, root=root)
        cfg.defrost()
        cfg.MODEL.layout = layout
        cfg.freeze()
        model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path)
        model.hybridize()
        # Token-level inputs are (batch, seq) for 'NT' and (seq, batch)
        # otherwise; valid_length is always (batch,).
        if layout == 'NT':
            token_shape = (batch_size, seq_length)
        else:
            token_shape = (seq_length, batch_size)
        token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, token_shape,
                                         dtype=np.int32)
        token_types = mx.np.random.randint(0, 2, token_shape, dtype=np.int32)
        valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size, ),
                                            dtype=np.int32)
        # One positional input list per architecture drives the MXNet call,
        # the Relay shape/dtype declarations and the TVM runtime inputs, so
        # they cannot disagree (previously XLM-R got the wrong runtime inputs).
        if 'bart' in model_name:
            inputs = [token_ids, valid_length, token_ids, valid_length]
        elif 'roberta' in model_name or 'xlmr' in model_name:
            inputs = [token_ids, valid_length]
        else:
            inputs = [token_ids, token_types, valid_length]
        mx_out = model(*inputs)
        input_names = ['data{}'.format(i) for i in range(len(inputs))]
        shape_dict = {key: arr.shape for key, arr in zip(input_names, inputs)}
        dtype_dict = {key: arr.dtype.name for key, arr in zip(input_names, inputs)}

        sym = model._cached_graph[1]
        params = {v._var_name: tvm.nd.array(v.data().asnumpy())
                  for _, v in model.collect_params().items()}
        mod, params = relay.frontend.from_mxnet(sym,
                                                shape=shape_dict,
                                                dtype=dtype_dict,
                                                arg_params=params)
        with tvm.transform.PassContext(opt_level=flags['opt_level'],
                                       required_pass=flags['required_pass']):
            lib = relay.build(mod, flags['target'], params=params)
        tvm_ctx = tvm.gpu() if flags['use_gpu'] else tvm.cpu()
        rt = graph_runtime.GraphModule(lib["default"](tvm_ctx))
        rt.set_input(**dict(zip(input_names, inputs)))
        rt.run()
        num_outputs = rt.get_num_outputs()
        for i in range(num_outputs):
            out = rt.get_output(i)
            # A single-output model returns one array rather than a list.
            mx_out_gt = mx_out.asnumpy() if num_outputs == 1 else mx_out[i].asnumpy()
            npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=1e-3, atol=1e-1)
Esempio n. 7
0
    def _train_speed_memory(self, model_name: str, batch_size: int, sequence_length: int)\
            -> Tuple[float, Memory]:
        """Benchmark one training step (forward + backward) of ``model_name``.

        The backbone is built with ``self._layout`` (and ``self._compute_layout``
        for non-BART models) and loaded with its pretrained parameters. A step
        runs the forward pass under ``autograd.record`` and back-propagates a
        synthetic loss whose head gradients are random normal tensors.

        Returns
        -------
        (float, Memory)
            The fastest measured time per step in seconds, and the memory
            figure: NVML "used" device memory after one extra step on GPU, or
            ``measure_peak_memory_cpu`` of one step on CPU.

        Raises
        ------
        NotImplementedError
            For an unknown layout or an unsupported backbone class.
        """
        if self._use_fp16:
            from mxnet import amp
            amp.init()

        if self._use_gpu:
            ctx = mxnet.gpu()
        else:
            ctx = mxnet.cpu()
        model_cls, cfg, _, backbone_param_path, _ = get_backbone(model_name)
        cfg.defrost()
        cfg.MODEL.layout = self._layout
        if model_cls.__name__ not in ['BartModel']:
            cfg.MODEL.compute_layout = self._compute_layout
        cfg.freeze()
        if model_cls.__name__ in ['BartModel']:
            model = model_cls.from_cfg(cfg, extract_feature=True)
        else:
            model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path, ctx=ctx)
        model.hybridize(static_alloc=True)
        vocab_size = cfg.MODEL.vocab_size
        if hasattr(cfg.MODEL, 'units'):
            out_units = cfg.MODEL.units
        else:
            out_units = cfg.MODEL.DECODER.units
        if self._layout == 'NT':
            token_shape = (batch_size, sequence_length)
        elif self._layout == 'TN':
            token_shape = (sequence_length, batch_size)
        else:
            raise NotImplementedError
        # Same creation order as before so the random stream is unchanged.
        input_ids = mxnet.np.random.randint(0, vocab_size, token_shape,
                                            dtype=np.int32, ctx=ctx)
        token_types = mxnet.np.zeros(token_shape, dtype=np.int32, ctx=ctx)
        valid_length = mxnet.np.full((batch_size,), sequence_length,
                                     dtype=np.int32, ctx=ctx)
        contextual_embedding_ograd = mxnet.np.random.normal(
            0, 1, token_shape + (out_units,), dtype=np.float32, ctx=ctx)
        pooled_out_ograd = mxnet.np.random.normal(
            0, 1, (batch_size, out_units), dtype=np.float32, ctx=ctx)

        if model_cls.__name__ in ['BertModel', 'AlbertModel', 'ElectraModel', 'MobileBertModel']:
            def forward():
                return model(input_ids, token_types, valid_length)
        elif model_cls.__name__ in ['BartModel']:
            def forward():
                return model(input_ids, valid_length, input_ids, valid_length)
        else:
            raise NotImplementedError

        def train_step():
            with mxnet.autograd.record():
                contextual_embedding, pooled_out = forward()
                # We'd like to set the head gradient of contextual_embedding to
                # contextual_embedding_ograd and that of pooled_out to
                # pooled_out_ograd, so we sum two Hadamard products.
                fake_loss = mxnet.np.sum(contextual_embedding * contextual_embedding_ograd)\
                    + mxnet.np.sum(pooled_out * pooled_out_ograd)
                fake_loss.backward()
            mxnet.npx.waitall()

        # Each timed run executes this many steps; the result is per step.
        steps_per_run = 3
        timeit.repeat(train_step, repeat=1, number=5)  # warm-up
        mxnet.npx.waitall()
        runtimes = timeit.repeat(train_step, repeat=self._repeat, number=steps_per_run)
        mxnet.npx.waitall()
        ctx.empty_cache()
        mxnet.npx.waitall()
        # Profile memory
        if self._use_gpu:
            nvml.nvmlInit()
            try:
                train_step()
                mxnet.npx.waitall()
                handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
                meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                memory = Memory(meminfo.used)
            finally:
                # Always release NVML, even if the profiled step fails.
                nvml.nvmlShutdown()
        else:
            # cpu
            memory_bytes = measure_peak_memory_cpu(train_step)
            memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
        return float(np.min(runtimes) / steps_per_run), memory
Esempio n. 8
0
    def _inference_speed_memory(self, model_name: str, batch_size: int, sequence_length: int)\
            -> Tuple[float, Memory]:
        """Benchmark inference of ``model_name`` with MXNet or, optionally, TVM.

        The backbone is built in float16 when ``self._use_fp16`` is set
        (float32 otherwise), loaded with its pretrained parameters and
        hybridized with static allocation/shape. When ``self._use_tvm`` is set,
        the forward pass is compiled with ``compile_tvm_graph_runtime`` and the
        TVM runtime is timed instead of MXNet.

        Returns
        -------
        (float, Memory)
            The fastest measured time per forward pass in seconds, and the
            memory figure. Memory is always measured on the MXNet forward pass:
            NVML "used" device memory on GPU, ``measure_peak_memory_cpu`` on CPU.

        Raises
        ------
        NotImplementedError
            For an unknown layout.
        """
        if self._use_fp16:
            # The key must be the bare variable name; the former
            # 'export MXNET_FC_TRUE_FP16' key was never read.
            os.environ['MXNET_FC_TRUE_FP16'] = '1'
            dtype = 'float16'
        else:
            dtype = 'float32'
        if self._use_gpu:
            ctx = mxnet.gpu()
        else:
            ctx = mxnet.cpu()
        model_cls, cfg, _, backbone_param_path, _ = get_backbone(model_name)
        cfg.defrost()
        cfg.MODEL.layout = self._layout
        if model_cls.__name__ not in ['BartModel']:
            cfg.MODEL.compute_layout = self._compute_layout
        cfg.freeze()
        if model_cls.__name__ in ['BartModel']:
            model = model_cls.from_cfg(cfg, extract_feature=True, dtype=dtype)
        else:
            model = model_cls.from_cfg(cfg, dtype=dtype)
        model.load_parameters(backbone_param_path, ctx=ctx, cast_dtype=True)
        model.cast(dtype)
        model.hybridize(static_alloc=True, static_shape=True)
        vocab_size = cfg.MODEL.vocab_size
        if self._layout == 'NT':
            token_shape = (batch_size, sequence_length)
        elif self._layout == 'TN':
            token_shape = (sequence_length, batch_size)
        else:
            raise NotImplementedError
        input_ids = mxnet.np.random.randint(0, vocab_size, token_shape,
                                            dtype=np.int32, ctx=ctx)
        token_types = mxnet.np.zeros(token_shape, dtype=np.int32, ctx=ctx)
        valid_length = mxnet.np.full((batch_size,), sequence_length,
                                     dtype=np.int32, ctx=ctx)
        mxnet.npx.waitall()

        def run_forward():
            if 'roberta' in model_name or 'xlmr' in model_name:
                out = model(input_ids, valid_length)
            elif 'bart' in model_name:
                out = model(input_ids, valid_length, input_ids, valid_length)
            else:
                out = model(input_ids, token_types, valid_length)
            if isinstance(out, list):
                for ele in out:
                    ele.wait_to_read()
            else:
                out.wait_to_read()

        # Each timed run executes this many passes; the result is per pass.
        passes_per_run = 3
        if self._use_tvm:
            tvm = try_import_tvm()
            # One MXNet forward first, so the hybridized graph exists before
            # TVM compilation (assumed to be read by compile_tvm_graph_runtime).
            run_forward()
            tvm_ctx = tvm.gpu() if self._use_gpu else tvm.cpu()
            rt = compile_tvm_graph_runtime(model=model, model_name=model_name,
                                           layout=self._layout, compute_layout=self._compute_layout,
                                           batch_size=batch_size, seq_length=sequence_length,
                                           instance_type=self._instance_type,
                                           dtype=dtype)
            tvm_input_ids = tvm.nd.array(input_ids.asnumpy(), ctx=tvm_ctx)
            tvm_token_types = tvm.nd.array(token_types.asnumpy(), ctx=tvm_ctx)
            tvm_valid_length = tvm.nd.array(valid_length.asnumpy(), ctx=tvm_ctx)
            # Runtime inputs mirror the positional arguments of run_forward;
            # BART takes the token ids / valid length for both encoder and decoder.
            if 'roberta' in model_name or 'xlmr' in model_name:
                tvm_inputs = dict(data0=tvm_input_ids, data1=tvm_valid_length)
            elif 'bart' in model_name:
                tvm_inputs = dict(data0=tvm_input_ids, data1=tvm_valid_length,
                                  data2=tvm_input_ids, data3=tvm_valid_length)
            else:
                tvm_inputs = dict(data0=tvm_input_ids, data1=tvm_token_types,
                                  data2=tvm_valid_length)

            def run_tvm_forward():
                rt.set_input(**tvm_inputs)
                rt.run()
                for i in range(rt.get_num_outputs()):
                    rt.get_output(i)
            # Warmup
            timeit.repeat(run_tvm_forward, repeat=1, number=2)
            runtimes = timeit.repeat(run_tvm_forward, repeat=self._repeat, number=passes_per_run)
        else:
            timeit.repeat(run_forward, repeat=1, number=3)  # warm-up
            runtimes = timeit.repeat(run_forward, repeat=self._repeat, number=passes_per_run)
            mxnet.npx.waitall()
        # Profile memory
        if self._use_gpu:
            nvml.nvmlInit()
            try:
                run_forward()
                mxnet.npx.waitall()
                handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
                meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                memory = Memory(meminfo.used)
            finally:
                # Always release NVML, even if the profiled pass fails.
                nvml.nvmlShutdown()
        else:
            # cpu
            memory_bytes = measure_peak_memory_cpu(run_forward)
            memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
        return float(np.min(runtimes) / passes_per_run), memory
Esempio n. 9
0
    def _inference_speed_memory(self, model_name: str, batch_size: int, sequence_length: int)\
            -> Tuple[float, Memory]:
        """Benchmark MXNet inference speed and memory of ``model_name``.

        The backbone is built with ``self._layout`` (and ``self._compute_layout``
        for non-BART models), loaded with its pretrained parameters and
        hybridized. Inputs are random token ids, all-zero token types and
        full-length valid lengths.

        Returns
        -------
        (float, Memory)
            The fastest measured time per forward pass in seconds, and the
            memory figure: NVML "used" device memory after one extra pass on
            GPU, or ``measure_peak_memory_cpu`` of one pass on CPU.

        Raises
        ------
        NotImplementedError
            For an unknown layout.
        """
        if self._use_gpu:
            ctx = mxnet.gpu()
        else:
            ctx = mxnet.cpu()
        model_cls, cfg, _, backbone_param_path, _ = get_backbone(
            model_name)
        # TODO Support fp16 profiling
        cfg.defrost()
        cfg.MODEL.layout = self._layout
        if model_cls.__name__ not in ['BartModel']:
            cfg.MODEL.compute_layout = self._compute_layout
        cfg.freeze()
        if model_cls.__name__ in ['BartModel']:
            model = model_cls.from_cfg(cfg, extract_feature=True)
        else:
            model = model_cls.from_cfg(cfg)
        model.load_parameters(backbone_param_path, ctx=ctx)
        model.hybridize()
        vocab_size = cfg.MODEL.vocab_size
        if self._layout == 'NT':
            token_shape = (batch_size, sequence_length)
        elif self._layout == 'TN':
            token_shape = (sequence_length, batch_size)
        else:
            raise NotImplementedError
        input_ids = mxnet.np.random.randint(0,
                                            vocab_size,
                                            token_shape,
                                            dtype=np.int32,
                                            ctx=ctx)
        token_types = mxnet.np.zeros(token_shape, dtype=np.int32, ctx=ctx)
        valid_length = mxnet.np.full((batch_size, ),
                                     sequence_length,
                                     dtype=np.int32,
                                     ctx=ctx)
        mxnet.npx.waitall()

        def run_forward():
            if 'roberta' in model_name or 'xlmr' in model_name:
                out = model(input_ids, valid_length)
            elif 'bart' in model_name:
                out = model(input_ids, valid_length, input_ids, valid_length)
            else:
                out = model(input_ids, token_types, valid_length)
            if isinstance(out, list):
                for ele in out:
                    ele.wait_to_read()
            else:
                out.wait_to_read()

        # Each timed run executes this many passes; the result is per pass.
        passes_per_run = 3
        timeit.repeat(run_forward, repeat=1, number=3)  # warm-up
        runtimes = timeit.repeat(run_forward, repeat=self._repeat, number=passes_per_run)
        mxnet.npx.waitall()
        # Profile memory
        if self._use_gpu:
            nvml.nvmlInit()
            try:
                run_forward()
                mxnet.npx.waitall()
                handle = nvml.nvmlDeviceGetHandleByIndex(self._device_idx)
                meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                memory = Memory(meminfo.used)
            finally:
                # Always release NVML, even if the profiled pass fails.
                nvml.nvmlShutdown()
        else:
            # cpu
            memory_bytes = measure_peak_memory_cpu(run_forward)
            memory = Memory(memory_bytes) if isinstance(memory_bytes,
                                                        int) else memory_bytes
        return float(np.min(runtimes) / passes_per_run), memory
import json
import mxnet as mx
from gluonnlp.models import list_backbone_names, get_backbone

# For every registered backbone, trace a randomly initialized network once,
# save its symbol graph to <name>.json and the set of operator names it uses
# to <name>_all_ops.json.
mx.npx.set_np()
batch_size = 1
sequence_length = 32
all_possible_ops = []
for name in list_backbone_names():
    # Only the architecture matters: the network is randomly initialized
    # below, so there is no need to fetch the pretrained weights.
    model_cls, cfg, tokenizer, local_params_path, others = get_backbone(
        model_name=name, load_backbone=False)
    net = model_cls.from_cfg(cfg)
    net.initialize()
    net.hybridize()
    print('Save the architecture of {} to {}.json'.format(name, name))
    inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
    token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
    valid_length = mx.np.random.randint(1, 10, (batch_size, ))
    # NOTE(review): decoder-only backbones with a state-based call signature
    # (e.g. GPT-2) are not handled here — confirm whether they are listed.
    if 'roberta' in name or 'xlmr' in name:
        out = net(inputs, valid_length)
    elif 'bart' in name:
        # Encoder-decoder: the same tokens feed both sides.
        out = net(inputs, valid_length, inputs, valid_length)
    else:
        out = net(inputs, token_types, valid_length)
    sym = net._cached_graph[1]
    sym.save('{}.json'.format(name), remove_amp_cast=True)
    all_ops = set()
    with open('{}.json'.format(name), 'r') as f:
        sym_info = json.load(f)
        for ele in sym_info['nodes']:
            all_ops.add(ele['op'])
    with open('{}_all_ops.json'.format(name), 'w') as f:
        # Sorted so the output file is deterministic across runs.
        json.dump(sorted(all_ops), f)