def build_parser(DROPOUT, LSTM_NUM_LAYERS, word_to_ix, pretrained_embeds):
    # Predef
    TEST_EMBEDDING_DIM = 4
    WORD_EMBEDDING_DIM = 64
    STACK_EMBEDDING_DIM = 100
    NUM_FEATURES = 3

    # Build Model
    feat_extractor = SimpleFeatureExtractor()
    # BiLSTM word embeddings will probably work best, but feel free to experiment with the others you developed
    word_embedding_lookup = BiLSTMWordEmbedding(word_to_ix, WORD_EMBEDDING_DIM, STACK_EMBEDDING_DIM,
                                                num_layers=LSTM_NUM_LAYERS, dropout=DROPOUT)
    initialize_with_pretrained(pretrained_embeds, word_embedding_lookup)
    action_chooser = LSTMActionChooser(STACK_EMBEDDING_DIM * NUM_FEATURES, LSTM_NUM_LAYERS, dropout=DROPOUT)
    combiner = LSTMCombiner(STACK_EMBEDDING_DIM, num_layers=LSTM_NUM_LAYERS, dropout=DROPOUT)
    parser = TransitionParser(feat_extractor, word_embedding_lookup, action_chooser, combiner)
    return parser
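# Hypothetical usage sketch (not part of the assignment code): how build_parser
# might be wired into a driver. The function name and the `vocab_index` /
# `pretrained` parameters are illustrative stand-ins for the vocabulary mapping
# and pretrained-vector dict built elsewhere in the notebook.
def make_default_parser(vocab_index, pretrained):
    torch.manual_seed(1)
    return build_parser(DROPOUT=0.0, LSTM_NUM_LAYERS=1,
                        word_to_ix=vocab_index, pretrained_embeds=pretrained)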
def test_bilstm_word_embeds_d4_1():
    global test_sent, word_to_ix, vocab
    torch.manual_seed(1)
    embedder = BiLSTMWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM, TEST_EMBEDDING_DIM, 1, 0.0)
    embeds = embedder(test_sent)
    assert len(embeds) == len(test_sent)
    assert isinstance(embeds, list)
    assert isinstance(embeds[0], ag.Variable)
    assert embeds[0].size() == (1, TEST_EMBEDDING_DIM)
    embeds_list = make_list(embeds)
    true = (
        [0.09079286456108093, 0.06577987223863602, 0.26242679357528687, -0.004267544485628605],
        [0.16868481040000916, 0.2032647728919983, 0.23663431406021118, -0.11785736680030823],
        [0.35757705569267273, 0.3805052936077118, -0.006295515224337578, 0.0010524550452828407],
        [0.26692214608192444, 0.3241712749004364, 0.13473030924797058, -0.026079852133989334],
        [0.23157459497451782, 0.13698695600032806, 0.04000323265790939, 0.1107199415564537],
        [0.22783540189266205, -0.02211562544107437, 0.06239837780594826, 0.08553065359592438],
        [0.24633683264255524, 0.09283821284770966, 0.0987505242228508, -0.07646450400352478],
        [0.05530695244669914, -0.4060348570346832, -0.060150448232889175, -0.003920700401067734],
        [0.2099054455757141, -0.304738312959671, -0.01663055270910263, -0.05987118184566498]
    )
    pairs = zip(embeds_list, true)
    check_tensor_correctness(pairs)
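# Hypothetical helper sketches: `make_list` and `check_tensor_correctness` are
# provided by the assignment's test utilities and are not reproduced here. From
# their usage above they only need to flatten each (1, dim) Variable into a plain
# Python list and compare predicted/expected pairs within a small tolerance.
def _make_list_sketch(embeds):
    return [e.data.view(-1).tolist() for e in embeds]

def _check_tensor_correctness_sketch(pairs, tol=1e-4):
    for predicted, expected in pairs:
        assert len(predicted) == len(expected)
        for p, t in zip(predicted, expected):
            assert abs(p - t) < tol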
def test_bilstm_embedding_d4_1():
    global test_doc
    torch.manual_seed(1)
    lstm = BiLSTMWordEmbedding(word_to_ix, TEST_EMBEDDING_DIM, LSTM_HIDDEN, LSTM_LAYERS, DROPOUT)
    pred_not = lstm(test_doc)[2].data.tolist()[0][:6]
    true_not = [
        0.11752596497535706, 0.042018793523311615, 0.06257987767457962,
        -0.057494595646858215, 0.06428981572389603, -0.16254858672618866
    ]
    list_assert(pred_not, true_not)
def test_pretrain_embeddings_d4_5():
    torch.manual_seed(1)
    word_to_ix = {"interest": 0, "rate": 1, "swap": 2}
    pretrained = {
        "interest": [6.1, 2.2, -3.5],
        "swap": [5.7, 1.6, 3.2],
        UNK_TOKEN: [8.5, -0.4, 2.0]
    }
    embedder = BiLSTMWordEmbedding(word_to_ix, 3, 2, 1, 0)
    initialize_with_pretrained(pretrained, embedder)
    embeddings = embedder.word_embeddings.weight.data
    pairs = []
    pairs.append((embeddings[word_to_ix["interest"]].tolist(), pretrained["interest"]))
    pairs.append((embeddings[word_to_ix["rate"]].tolist(), pretrained[UNK_TOKEN]))
    pairs.append((embeddings[word_to_ix["swap"]].tolist(), pretrained["swap"]))
    check_tensor_correctness(pairs)
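# Hypothetical sketch of the behavior test_pretrain_embeddings_d4_5 checks: each
# word with a pretrained vector gets that vector copied into the embedding table,
# and any word missing from `pretrained` falls back to the UNK_TOKEN vector. The
# attribute names (`word_to_ix`, `word_embeddings`) mirror their use in the test;
# this is only a sketch, not the assignment's reference implementation of
# initialize_with_pretrained.
def _initialize_with_pretrained_sketch(pretrained, embedder):
    for word, ix in embedder.word_to_ix.items():
        vec = pretrained.get(word, pretrained[UNK_TOKEN])
        embedder.word_embeddings.weight.data[ix] = torch.Tensor(vec)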