Example 1
import mxnet as mx

# get_text_model, get_cache_model and CacheCell are assumed to be helpers
# defined elsewhere in this test module.
def test_get_cache_model_noncache_models():
    language_models_params = {'awd_lstm_lm_1150': 'awd_lstm_lm_1150_wikitext-2-45d6df33.params',
                              'awd_lstm_lm_600': 'awd_lstm_lm_600_wikitext-2-7894a046.params',
                              'standard_lstm_lm_200': 'standard_lstm_lm_200_wikitext-2-700b532d.params',
                              'standard_lstm_lm_650': 'standard_lstm_lm_650_wikitext-2-14041667.params',
                              'standard_lstm_lm_1500': 'standard_lstm_lm_1500_wikitext-2-d572ce71.params'}
    datasets = ['wikitext-2']
    for name in language_models_params.keys():
        for dataset_name in datasets:
            _, vocab = get_text_model(name=name, dataset_name=dataset_name, pretrained=True)
            ntokens = len(vocab)

            # path 1: load a ready-made cache model from the model zoo
            cache_cell_0 = get_cache_model(name, dataset_name, window=1, theta=0.6,
                                           lambdas=0.2, root='tests/data/model/')
            print(cache_cell_0)

            # path 2: wrap the plain language model in a CacheCell by hand
            # and load the same pretrained weights
            model, _ = get_text_model(name=name, dataset_name=dataset_name, pretrained=True)
            cache_cell_1 = CacheCell(model, ntokens, window=1, theta=0.6, lambdas=0.2)
            cache_cell_1.load_params('tests/data/model/' + language_models_params.get(name))
            print(cache_cell_1)

            # one forward pass each; inputs and targets are (seq_len, batch) = (10, 1)
            outs0, word_history0, cache_history0, hidden0 = \
                cache_cell_0(mx.nd.arange(10).reshape(10, 1), mx.nd.arange(10).reshape(10, 1), None, None)
            outs1, word_history1, cache_history1, hidden1 = \
                cache_cell_1(mx.nd.arange(10).reshape(10, 1), mx.nd.arange(10).reshape(10, 1), None, None)

            assert outs0.shape == outs1.shape, outs0.shape
            assert len(word_history0) == len(word_history1), len(word_history0)
            assert len(cache_history0) == len(cache_history1), len(cache_history0)
            assert len(hidden0) == len(hidden1), len(hidden0)
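Both construction paths above should yield the same cache-augmented model: a CacheCell appears to implement the continuous-cache interpolation of Grave et al. (2017), where lambdas weights a cache distribution against the base language model and theta scales the hidden-state similarity scores inside the cache. A minimal, hypothetical sketch of that interpolation (cache_interpolation is an illustration, not the library's internals):

import mxnet as mx

def cache_interpolation(p_lm, p_cache, lambdas=0.2):
    # mix the base LM distribution with the cache distribution;
    # lambdas=0 recovers the plain language model
    return (1 - lambdas) * p_lm + lambdas * p_cache

# toy next-word distributions over a 4-token vocabulary
p_lm = mx.nd.array([0.7, 0.1, 0.1, 0.1])
p_cache = mx.nd.array([0.1, 0.6, 0.2, 0.1])
print(cache_interpolation(p_lm, p_cache))  # [0.58 0.2 0.12 0.1]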
Example 2
import mxnet as mx
import gluonnlp as nlp

# get_frequencies, get_text_model and eprint are assumed to be helpers
# defined elsewhere in this test module.
def test_text_models():
    val = nlp.data.WikiText2(segment='val', root='tests/data/wikitext-2')
    val_freq = get_frequencies(val)
    vocab = nlp.Vocab(val_freq)
    text_models = [
        'standard_lstm_lm_200', 'standard_lstm_lm_650',
        'standard_lstm_lm_1500', 'awd_lstm_lm_1150', 'awd_lstm_lm_600'
    ]
    pretrained_to_test = {
        'standard_lstm_lm_1500': 'wikitext-2',
        'standard_lstm_lm_650': 'wikitext-2',
        'standard_lstm_lm_200': 'wikitext-2',
        'awd_lstm_lm_1150': 'wikitext-2',
        'awd_lstm_lm_600': 'wikitext-2'
    }

    for model_name in text_models:
        eprint('testing forward for %s' % model_name)
        pretrained_dataset = pretrained_to_test.get(model_name)
        model, _ = get_text_model(model_name,
                                  vocab=vocab,
                                  dataset_name=pretrained_dataset,
                                  pretrained=pretrained_dataset is not None,
                                  root='tests/data/model/')

        print(model)
        if not pretrained_dataset:
            # without pretrained weights the parameters must be initialized
            model.collect_params().initialize()
        # dummy (seq_len=33, batch_size=10) batch; wait_to_read forces
        # MXNet's asynchronous execution to complete
        output, state = model(mx.nd.arange(330).reshape(33, 10))
        output.wait_to_read()
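The forward call above passes no initial hidden state, so these models presumably default to a zero state. To carry state across consecutive batches, the state returned by one call can be fed back into the next, as the big-RNN examples below do explicitly. A hypothetical sketch (batches is an assumed iterator of (33, 10) arrays):

# threading hidden state across batches; assumes the model's forward
# accepts an optional initial state, as the big-RNN examples below show
hidden = model.begin_state(batch_size=10, func=mx.nd.zeros)
for batch in batches:
    output, hidden = model(batch, hidden)
    output.wait_to_read()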
Example 3
import mxnet as mx

# The leading underscore keeps this test from being collected by default,
# presumably because it downloads the large pretrained GBW weights.
# get_text_model and eprint are assumed to come from the same test module.
def _test_pretrained_big_text_models():
    text_models = ['big_rnn_lm_2048_512']
    pretrained_to_test = {'big_rnn_lm_2048_512': 'gbw'}

    for model_name in text_models:
        eprint('testing forward for %s' % model_name)
        pretrained_dataset = pretrained_to_test.get(model_name)
        model, _ = get_text_model(model_name, dataset_name=pretrained_dataset,
                                  pretrained=True, root='tests/data/model/')

        print(model)
        batch_size = 10
        # big RNN models take an explicit initial hidden state
        hidden = model.begin_state(batch_size=batch_size, func=mx.nd.zeros)
        output, state = model(mx.nd.arange(330).reshape((33, 10)), hidden)
        output.wait_to_read()
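begin_state builds the initial recurrent state; the func argument appears to be the array factory applied to each state tensor, so mx.nd.zeros gives a deterministic start. A small sketch for inspecting what it returns (the exact nesting of the state is an assumption, so both flat and nested layouts are handled):

# inspect the initial state's structure and shapes
hidden = model.begin_state(batch_size=10, func=mx.nd.zeros)
for layer_state in hidden:
    if isinstance(layer_state, (list, tuple)):  # e.g. per-layer (h, c) pairs
        print([s.shape for s in layer_state])
    else:
        print(layer_state.shape)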
Example 4
import mxnet as mx
import gluonnlp as nlp

# get_frequencies, get_text_model and eprint are assumed to be helpers
# defined elsewhere in this test module.
def test_big_text_models():
    # use a small vocabulary for testing
    val = nlp.data.WikiText2(segment='val', root='tests/data/wikitext-2')
    val_freq = get_frequencies(val)
    vocab = nlp.Vocab(val_freq)
    text_models = ['big_rnn_lm_2048_512']

    for model_name in text_models:
        eprint('testing forward for %s' % model_name)
        model, _ = get_text_model(model_name, vocab=vocab, root='tests/data/model/')

        print(model)
        # no pretrained weights requested, so initialize from scratch
        model.collect_params().initialize()
        batch_size = 10
        hidden = model.begin_state(batch_size=batch_size, func=mx.nd.zeros)
        output, state = model(mx.nd.arange(330).reshape((33, 10)), hidden)
        output.wait_to_read()
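Together, Examples 3 and 4 show the two construction modes of get_text_model as it is used in these tests: pretrained weights tied to a dataset's own vocabulary, versus fresh parameters over a user-supplied vocabulary, which must then be initialized before any forward pass. A hypothetical side-by-side sketch, reusing the call patterns from the tests above:

# 1) pretrained weights tied to the dataset's own vocabulary
model_a, vocab_a = get_text_model('big_rnn_lm_2048_512', dataset_name='gbw',
                                  pretrained=True, root='tests/data/model/')

# 2) fresh parameters over a custom (here: small test) vocabulary
model_b, _ = get_text_model('big_rnn_lm_2048_512', vocab=vocab,
                            root='tests/data/model/')
model_b.collect_params().initialize()  # mandatory before the first forward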