def test_d3_3_sample():
    """Sample a new town name from the saved Deliverable 3.2 model and sanity-check it."""
    global c2i, i2c

    model_fname = "deliverable_3.2.mod"

    # The trained model must already have been saved to disk by the training step.
    assert_true(os.path.exists(model_fname))
    trained_model = torch.load(model_fname)

    # Draw a fresh sample and check basic well-formedness.
    new_town = lm.sample(trained_model, c2i, i2c)
    # Sampling must terminate in a reasonable length.
    assert_less(len(new_town), 200)
    # The beginning-of-sequence symbol should not leak into the output.
    assert_false(new_town[0] == vocab.BOS_SYM)

    # A sampled name should be more probable, per character, than gibberish.
    gibberish_string = "asdflkjiopqutepoiuqrfm"
    p_gibberish = lm.compute_prob(trained_model, gibberish_string, c2i) / len(gibberish_string)
    p_good = lm.compute_prob(trained_model, new_town, c2i) / len(new_town)
    # compute_prob returns a negative log probability, so bigger means less likely.
    assert_greater(p_gibberish, p_good)
def test_d3_2_training():
    """Verify the saved Deliverable 3.2 model's architecture and that it
    assigns higher probability to a plausible town name than to gibberish."""
    # Fix: the original declared `global corpus, c2i`, but `corpus` was never
    # used in this body and `c2i` is only read (never assigned), so the
    # `global` statement was dead code and has been removed.

    model_fname = "deliverable_3.2.mod"

    # The trained model must already have been saved to disk by the student.
    assert_true(os.path.exists(model_fname))
    # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which
    # rejects fully-pickled model objects — confirm the course environment's
    # torch version if this line starts failing.
    trained_model = torch.load(model_fname)

    # Architecture as specified in the assignment instructions:
    # embedding over the full character vocabulary, dim 25; 1-layer LSTM, hidden 50.
    eq_(trained_model.input_lookup.num_embeddings, len(c2i))
    eq_(trained_model.input_lookup.embedding_dim, 25)
    eq_(trained_model.lstm.num_layers, 1)
    eq_(trained_model.lstm.hidden_size, 50)

    # Compare per-character negative log probabilities of gibberish vs. a
    # plausible fake town name; normalizing by length keeps the comparison fair.
    gibberish_string = "asdflkjiopqutepoiuqrfm"
    good_fake_string = "Little Brockton-upon-Thyme"
    p_gibberish = lm.compute_prob(trained_model, gibberish_string, c2i)
    p_good = lm.compute_prob(trained_model, good_fake_string, c2i)
    p_gibberish_norm = p_gibberish / len(gibberish_string)
    p_good_norm = p_good / len(good_fake_string)
    # Negative log space: a larger value means a less probable string.
    assert_greater(p_gibberish_norm, p_good_norm)