Пример #1
0

def test_list_backbone_names():
    """The backbone registry must expose at least one name."""
    num_backbones = len(list_backbone_names())
    assert num_backbones > 0


def tvm_enabled():
    """Report whether TVM can be imported in the current environment.

    Returns
    -------
    bool
        ``True`` when ``try_import_tvm()`` completes without raising,
        ``False`` when it raises any ``Exception``.
    """
    try:
        try_import_tvm()
    except Exception:
        # Kept broad on purpose: any failure while probing for TVM means
        # "not available". Unlike a bare ``except:``, this does not swallow
        # KeyboardInterrupt or SystemExit.
        return False
    return True


@pytest.mark.slow
@pytest.mark.parametrize('name', list_backbone_names())
def test_get_backbone(name, ctx):
    """Build the backbone called ``name``, load its parameters and count them.

    Parameters
    ----------
    name
        Backbone name; the test is parametrized over every entry returned by
        ``list_backbone_names()``.
    ctx
        Object entered as a context manager around the whole test body
        (presumably a device-context fixture -- TODO confirm in conftest).
    """
    # ``root`` is a temporary directory handed to ``get_backbone``; it is
    # removed automatically when the ``with`` block exits.
    with tempfile.TemporaryDirectory() as root, ctx:
        model_cls, cfg, tokenizer, local_params_path, _ = get_backbone(
            name, root=root)
        # Instantiate the network from its config, then load the parameter
        # file whose path was returned by ``get_backbone``.
        net = model_cls.from_cfg(cfg)
        net.load_parameters(local_params_path)
        net.hybridize()
        num_params, num_fixed_params = count_parameters(net.collect_params())
        assert num_params > 0

        # Test for model export + save
        # Names containing 'gpt2' are skipped from this point on.
        if 'gpt2' in name:
            pytest.skip('Skipping GPT-2 test')
        # NOTE(review): the export/save portion is cut off in this excerpt;
        # the two values below are not used in the visible lines.
        batch_size = 1
        sequence_length = 4
Пример #2
0
def test_list_backbone_names():
    """The backbone registry must expose at least one name."""
    num_backbones = len(list_backbone_names())
    assert num_backbones > 0
Пример #3
0
import json
import mxnet as mx
from gluonnlp.models import list_backbone_names, get_backbone

# Export the symbolic graph of every registered backbone and record which
# operators each graph uses.
#
# For each backbone name ``<name>`` two files are written to the current
# working directory:
#   * ``<name>.json``         -- the symbol saved from the network's cached graph
#   * ``<name>_all_ops.json`` -- sorted list of the distinct ``op`` values
#                                found in that graph's ``nodes``
mx.npx.set_np()
batch_size = 1
sequence_length = 32
# NOTE(review): never populated in the visible script; kept in case other
# code relies on the name.
all_possible_ops = []
for name in list_backbone_names():
    model_cls, cfg, tokenizer, local_params_path, others = get_backbone(
        model_name=name)
    # Randomly initialized weights are used here; ``local_params_path`` is
    # not loaded in this script.
    net = model_cls.from_cfg(cfg)
    net.initialize()
    net.hybridize()
    print('Save the architecture of {} to {}.json'.format(name, name))
    inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
    token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
    valid_length = mx.np.random.randint(1, 10, (batch_size, ))
    # Run one forward pass before reading ``net._cached_graph`` (presumably
    # the hybridized graph is only built on the first call -- TODO confirm).
    # Names containing 'roberta' or 'xlmr' are called without token types.
    if 'roberta' in name or 'xlmr' in name:
        out = net(inputs, valid_length)
    else:
        out = net(inputs, token_types, valid_length)
    sym = net._cached_graph[1]
    sym.save('{}.json'.format(name), remove_amp_cast=True)
    # Read the saved graph back and collect the distinct operator names.
    with open('{}.json'.format(name), 'r') as f:
        sym_info = json.load(f)
    all_ops = {ele['op'] for ele in sym_info['nodes']}
    # Sort before dumping: iterating a set of strings has no stable order
    # across interpreter runs, which would make the output file differ from
    # run to run for the same graph.
    with open('{}_all_ops.json'.format(name), 'w') as f:
        json.dump(sorted(all_ops), f)