import mxnet as mx
import gluonnlp as nlp

# NOTE: get_text_model, get_cache_model, CacheCell, get_frequencies and
# eprint are assumed to be provided by this test suite's shared helpers;
# they are used below exactly as in the original tests.


def test_get_cache_model_noncache_models():
    language_models_params = {
        'awd_lstm_lm_1150': 'awd_lstm_lm_1150_wikitext-2-45d6df33.params',
        'awd_lstm_lm_600': 'awd_lstm_lm_600_wikitext-2-7894a046.params',
        'standard_lstm_lm_200': 'standard_lstm_lm_200_wikitext-2-700b532d.params',
        'standard_lstm_lm_650': 'standard_lstm_lm_650_wikitext-2-14041667.params',
        'standard_lstm_lm_1500': 'standard_lstm_lm_1500_wikitext-2-d572ce71.params'}
    datasets = ['wikitext-2']
    for name in language_models_params.keys():
        for dataset_name in datasets:
            _, vocab = get_text_model(name=name, dataset_name=dataset_name,
                                      pretrained=True)
            ntokens = len(vocab)

            # Cache model built through the convenience constructor.
            cache_cell_0 = get_cache_model(name, dataset_name, window=1, theta=0.6,
                                           lambdas=0.2, root='tests/data/model/')
            print(cache_cell_0)

            # Equivalent cache model assembled manually from the base LM.
            model, _ = get_text_model(name=name, dataset_name=dataset_name,
                                      pretrained=True)
            cache_cell_1 = CacheCell(model, ntokens, window=1, theta=0.6, lambdas=0.2)
            cache_cell_1.load_params('tests/data/model/'
                                     + language_models_params.get(name))
            print(cache_cell_1)

            outs0, word_history0, cache_history0, hidden0 = \
                cache_cell_0(mx.nd.arange(10).reshape(10, 1),
                             mx.nd.arange(10).reshape(10, 1), None, None)
            outs1, word_history1, cache_history1, hidden1 = \
                cache_cell_1(mx.nd.arange(10).reshape(10, 1),
                             mx.nd.arange(10).reshape(10, 1), None, None)

            # Both construction paths should yield matching output shapes
            # and history/state lengths.
            assert outs0.shape == outs1.shape, outs0.shape
            assert len(word_history0) == len(word_history1), len(word_history0)
            assert len(cache_history0) == len(cache_history1), len(cache_history0)
            assert len(hidden0) == len(hidden1), len(hidden0)

def test_text_models():
    val = nlp.data.WikiText2(segment='val', root='tests/data/wikitext-2')
    val_freq = get_frequencies(val)
    vocab = nlp.Vocab(val_freq)
    text_models = ['standard_lstm_lm_200', 'standard_lstm_lm_650',
                   'standard_lstm_lm_1500', 'awd_lstm_lm_1150', 'awd_lstm_lm_600']
    pretrained_to_test = {'standard_lstm_lm_1500': 'wikitext-2',
                          'standard_lstm_lm_650': 'wikitext-2',
                          'standard_lstm_lm_200': 'wikitext-2',
                          'awd_lstm_lm_1150': 'wikitext-2',
                          'awd_lstm_lm_600': 'wikitext-2'}
    for model_name in text_models:
        eprint('testing forward for %s' % model_name)
        pretrained_dataset = pretrained_to_test.get(model_name)
        model, _ = get_text_model(model_name, vocab=vocab,
                                  dataset_name=pretrained_dataset,
                                  pretrained=pretrained_dataset is not None,
                                  root='tests/data/model/')
        print(model)
        if not pretrained_dataset:
            model.collect_params().initialize()
        output, state = model(mx.nd.arange(330).reshape(33, 10))
        output.wait_to_read()

def _test_pretrained_big_text_models():
    # The leading underscore keeps this out of default test collection:
    # it requires downloading the large pretrained GBW checkpoint.
    text_models = ['big_rnn_lm_2048_512']
    pretrained_to_test = {'big_rnn_lm_2048_512': 'gbw'}
    for model_name in text_models:
        eprint('testing forward for %s' % model_name)
        pretrained_dataset = pretrained_to_test.get(model_name)
        model, _ = get_text_model(model_name, dataset_name=pretrained_dataset,
                                  pretrained=True, root='tests/data/model/')
        print(model)
        batch_size = 10
        hidden = model.begin_state(batch_size=batch_size, func=mx.nd.zeros)
        output, state = model(mx.nd.arange(330).reshape((33, 10)), hidden)
        output.wait_to_read()

def test_big_text_models():
    # use a small vocabulary for testing
    val = nlp.data.WikiText2(segment='val', root='tests/data/wikitext-2')
    val_freq = get_frequencies(val)
    vocab = nlp.Vocab(val_freq)
    text_models = ['big_rnn_lm_2048_512']
    for model_name in text_models:
        eprint('testing forward for %s' % model_name)
        model, _ = get_text_model(model_name, vocab=vocab, root='tests/data/model/')
        print(model)
        model.collect_params().initialize()
        batch_size = 10
        hidden = model.begin_state(batch_size=batch_size, func=mx.nd.zeros)
        output, state = model(mx.nd.arange(330).reshape((33, 10)), hidden)
        output.wait_to_read()

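# A minimal sketch for running these checks directly. It assumes the tests
# are normally collected by a runner such as pytest or nose; the disabled
# big-model test is left out since it needs the large pretrained GBW
# checkpoint.
if __name__ == '__main__':
    test_get_cache_model_noncache_models()
    test_text_models()
    test_big_text_models()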