Example #1
import os

import torch
from nose.tools import (assert_true, assert_false, assert_less,
                        assert_greater)

import lm     # project module under test (defines sample and compute_prob)
import vocab  # project module defining BOS_SYM


def test_d3_3_sample():
    # c2i / i2c (char-to-index and index-to-char maps) are built by an
    # earlier fixture in this test module
    global c2i, i2c

    model_fname = "deliverable_3.2.mod"

    # does the file exist?
    assert_true(os.path.exists(model_fname))

    # try loading saved model
    trained_model = torch.load(model_fname)

    # try sampling a new town name:
    new_town = lm.sample(trained_model, c2i, i2c)

    # verify that it's not too long:
    assert_less(len(new_town), 200)

    # make sure that it doesn't start with BOS_SYM
    assert_false(new_town[0] == vocab.BOS_SYM)

    # verify that it is more probable than gibberish
    # (i.e., a lower per-character negative log prob):
    gibberish_string = "asdflkjiopqutepoiuqrfm"
    p_gibberish = lm.compute_prob(trained_model, gibberish_string,
                                  c2i) / len(gibberish_string)
    p_good = lm.compute_prob(trained_model, new_town, c2i) / len(new_town)

    # these are per-character negative log probs, so a larger value
    # means a less probable string
    assert_greater(p_gibberish, p_good)
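
Note: the lm module exercised here is not shown on this page. As a rough sketch only, a sample helper compatible with the call above might walk the model one character at a time as below. The output-projection attribute (model.output) and the EOS symbol are assumptions; input_lookup and lstm are the attribute names the second test checks.

import torch

import vocab  # assumed to define BOS_SYM and EOS_SYM


def sample(model, c2i, i2c, max_len=200):
    """Sample one string character by character until EOS (or length cap)."""
    chars = []
    hidden = None
    idx = torch.tensor([c2i[vocab.BOS_SYM]])  # start from the BOS symbol
    with torch.no_grad():
        # stop one short of max_len so the result satisfies assert_less above
        for _ in range(max_len - 1):
            emb = model.input_lookup(idx).view(1, 1, -1)  # (seq, batch, emb)
            out, hidden = model.lstm(emb, hidden)
            logits = model.output(out.view(1, -1))        # `output` is assumed
            probs = torch.softmax(logits, dim=1)
            idx = torch.multinomial(probs, 1).view(1)     # draw the next char
            ch = i2c[idx.item()]
            if ch == vocab.EOS_SYM:
                break
            chars.append(ch)
    return "".join(chars)
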
Example #2
import os

import torch
from nose.tools import assert_true, assert_greater, eq_

import lm  # project module under test (defines compute_prob)


def test_d3_2_training():
    # corpus and c2i are built by an earlier fixture in this test module
    global corpus, c2i

    model_fname = "deliverable_3.2.mod"

    # does the file exist?
    assert_true(os.path.exists(model_fname))

    # try loading saved model
    trained_model = torch.load(model_fname)

    # does its architecture match the assignment spec
    # (embedding and LSTM sizes)?
    eq_(trained_model.input_lookup.num_embeddings, len(c2i))
    eq_(trained_model.input_lookup.embedding_dim, 25)
    eq_(trained_model.lstm.num_layers, 1)
    eq_(trained_model.lstm.hidden_size, 50)

    # score a gibberish string against a plausible (fake) town name;
    # gibberish should come out less probable
    gibberish_string = "asdflkjiopqutepoiuqrfm"
    good_fake_string = "Little Brockton-upon-Thyme"

    p_gibberish = lm.compute_prob(trained_model, gibberish_string, c2i)
    p_good = lm.compute_prob(trained_model, good_fake_string, c2i)

    p_gibberish_norm = p_gibberish / len(gibberish_string)
    p_good_norm = p_good / len(good_fake_string)

    # p_gibberish_norm should be larger than p_good_norm, since these
    # are negative log probs
    assert_greater(p_gibberish_norm, p_good_norm)
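
For context, a model/scorer pair consistent with both tests might look like the sketch below. The CharLSTM class name and the output projection are hypothetical; input_lookup, lstm, and the sizes (25-dim embeddings, one 50-unit LSTM layer) are exactly what this test asserts, and compute_prob returns a summed negative log probability, matching the "bigger is less probable" comparisons above.

import torch
import torch.nn as nn
import torch.nn.functional as F

import vocab  # assumed to define BOS_SYM and EOS_SYM


class CharLSTM(nn.Module):
    # Hypothetical class name; the attribute names and sizes below are
    # the ones eq_() checks in test_d3_2_training.
    def __init__(self, vocab_size, emb_dim=25, hidden_size=50):
        super().__init__()
        self.input_lookup = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_size, num_layers=1)
        self.output = nn.Linear(hidden_size, vocab_size)  # name assumed


def compute_prob(model, s, c2i):
    """Total negative log probability of s (lower value = more probable)."""
    with torch.no_grad():
        # Bracket the string with BOS/EOS, then predict each next character.
        symbols = [vocab.BOS_SYM] + list(s) + [vocab.EOS_SYM]
        idx = torch.tensor([c2i[c] for c in symbols])
        emb = model.input_lookup(idx[:-1])  # inputs: every symbol but EOS
        out, _ = model.lstm(emb.view(len(symbols) - 1, 1, -1))
        logits = model.output(out.view(len(symbols) - 1, -1))
        # Sum of per-character negative log likelihoods of the true next char.
        return F.cross_entropy(logits, idx[1:], reduction="sum").item()
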