Exemple #1
0
def test_predict_functions_honor_device(dataset, func):
    """Check that a prediction method actually uses its ``device`` argument.

    The method named ``func`` is called with a nonexistent device name and
    must raise ``RuntimeError``; an implementation that silently ignored the
    argument would not raise. ``dataset`` and ``func`` are presumably
    supplied by fixtures / ``pytest.mark.parametrize`` not visible here.
    """
    X_train, X_test, y_train, _, vocab = dataset
    model = TorchTreeNN(vocab, max_iter=2)
    model.fit(X_train, y_train)
    predict_method = getattr(model, func)
    bogus_device = "FAKE_DEVICE"
    with pytest.raises(RuntimeError):
        predict_method(X_test, device=bogus_device)
Exemple #2
0
def test_model_graph_dimensions(dataset, mod_attr, dim, graph_attr):
    """Check that an estimator attribute agrees with a layer's weight shape.

    After a one-iteration fit, attribute ``mod_attr`` of the estimator must
    equal ``weight.shape[dim]`` of submodule ``graph_attr`` of ``mod.model``.
    The three parameters are presumably supplied by a
    ``pytest.mark.parametrize`` decorator not visible here.
    """
    X_train, _, y_train, _, vocab = dataset
    model = TorchTreeNN(vocab, max_iter=1)
    model.fit(X_train, y_train)
    declared_size = getattr(model, mod_attr)
    layer_weight = getattr(model.model, graph_attr).weight
    assert declared_size == layer_weight.shape[dim]
Exemple #3
0
def test_build_tree_rep(tree, subtree_indices, emb_indices, n):
    """Check ``TorchTreeNN._build_tree_rep`` on a bracketed tree string.

    ``tree`` is parsed with ``Tree.fromstring``. Items 0, 1 and 2 of the
    resulting representation must equal ``subtree_indices``,
    ``emb_indices`` and ``n`` respectively. The vocabulary is the digit
    strings "0" through "6" followed by "$UNK".
    """
    parsed = Tree.fromstring(tree)
    vocab = [str(digit) for digit in range(7)] + ["$UNK"]
    model = TorchTreeNN(vocab)
    rep = model._build_tree_rep(parsed)
    wanted = (subtree_indices, emb_indices, n)
    # Compare position by position, stopping at the first mismatch.
    for position, expected_part in enumerate(wanted):
        assert rep[position] == expected_part
Exemple #4
0
def test_parameter_setting(param, expected):
    """Check that ``set_params`` stores ``expected`` on attribute ``param``.

    ``param`` and ``expected`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator not visible here.
    """
    model = TorchTreeNN([])
    model.set_params(**{param: expected})
    stored = getattr(model, param)
    if param == "embedding":
        # Presumably an array, where ``==`` is elementwise; compare whole
        # arrays instead.
        matches = np.array_equal(stored, expected)
    else:
        matches = stored == expected
    assert matches
Exemple #5
0
def test_predict_restores_device(dataset, func):
    """Check that a per-call ``device`` override does not stick.

    After predicting with ``device="cpu:0"`` through the method named
    ``func``, the estimator's ``device`` attribute must be what it was
    before the call. The precondition asserts the model was not already on
    ``cpu:0``; otherwise a failure to restore could not be detected.
    """
    X_train, X_test, y_train, _, vocab = dataset
    model = TorchTreeNN(vocab, max_iter=2)
    model.fit(X_train, y_train)
    device_before = model.device
    override = "cpu:0"
    assert device_before != torch.device(override)
    predict_method = getattr(model, func)
    predict_method(X_test, device=override)
    assert model.device == device_before
Exemple #6
0
def test_build_dataset(dataset, with_y, expected):
    """Check the length of the first item yielded by ``build_dataset``.

    When ``with_y`` is truthy, the training labels are passed along with the
    training trees; otherwise only the trees are passed. Either way, the
    first item yielded by the built dataset must have ``expected`` elements.
    """
    X_train, _, y_train, _, vocab = dataset
    model = TorchTreeNN(vocab)
    build_args = (X_train, y_train) if with_y else (X_train,)
    # Use a new name rather than rebinding the ``dataset`` fixture argument.
    torch_dataset = model.build_dataset(*build_args)
    first_item = next(iter(torch_dataset))
    assert len(first_item) == expected
Exemple #7
0
def test_pretrained_embedding(dataset):
    """Check that a frozen pretrained embedding survives training unchanged.

    An all-ones embedding of shape ``(len(vocab), 5)`` is supplied with
    ``freeze_embedding=True``. After one fit iteration, the model's
    embedding weights must still equal it exactly.
    """
    X_train, _, y_train, _, vocab = dataset
    dim = 5
    pretrained = np.ones((len(vocab), dim))
    model = TorchTreeNN(
        vocab,
        max_iter=1,
        embedding=pretrained,
        freeze_embedding=True,
    )
    model.fit(X_train, y_train)
    # Move the weights off the autograd graph and to host memory before
    # comparing them with NumPy.
    trained = model.model.embedding.weight.detach().cpu().numpy()
    assert np.array_equal(pretrained, trained)
Exemple #8
0
def test_torch_tree_nn_save_load(dataset):
    """Round-trip a fitted ``TorchTreeNN`` through ``to_pickle``/``from_pickle``.

    The model is fitted and used for prediction, then saved, reloaded, used
    for prediction again and refitted. Both prediction calls must return one
    prediction per test example.

    The pickle is written to a path inside a ``TemporaryDirectory``. The
    previous approach opened a ``NamedTemporaryFile`` and then had
    ``to_pickle`` use its name, which is presumably reopened by path inside
    ``to_pickle``. Reopening a still-open named temporary file by name is
    not supported on Windows, so that approach was not portable.
    """
    import os

    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab, embed_dim=50, max_iter=100, embedding=None)
    mod.fit(X_train, y_train)
    preds = mod.predict(X_test)
    assert len(preds) == len(X_test)
    with tempfile.TemporaryDirectory() as tmpdir:
        name = os.path.join(tmpdir, "torch_tree_nn.pkl")
        mod.to_pickle(name)
        mod2 = TorchTreeNN.from_pickle(name)
        reloaded_preds = mod2.predict(X_test)
        assert len(reloaded_preds) == len(X_test)
        # The reloaded model must also support further training.
        mod2.fit(X_test, y_test)
# In[41]:

# Predict dev-set labels with the `TreeNN` model (`tree_nn_glove`), which
# was built and fitted in an earlier cell not visible here.
tree_glove_dev_predictions = tree_nn_glove.predict(X_tree_dev)

# In[42]:

# Report dev-set performance of those predictions against the gold labels.
print(classification_report(y_tree_dev, tree_glove_dev_predictions))

# ### PyTorch TreeNN implementation

# In[43]:

# Same GloVe-initialized setup, now with the PyTorch implementation.
# NOTE(review): embed_dim=50 presumably matches the column count of
# glove_embedding; confirm.
torch_tree_nn_glove = TorchTreeNN(sst_glove_vocab,
                                  embedding=glove_embedding,
                                  embed_dim=50,
                                  max_iter=10,
                                  eta=0.05)

# In[44]:

# IPython `%time` magic around the fit. Only the trees are passed, so the
# model presumably reads labels from the trees themselves; see the note
# that follows.
get_ipython().run_line_magic('time',
                             '_ = torch_tree_nn_glove.fit(X_tree_train)')

# As with `TreeNN` above, you have the option of specifying the labels separately, and this is required if you are cross-validating the model using scikit-learn methods.

# In[45]:

# Dev-set predictions from the PyTorch model, for comparison with the
# `TreeNN` results above.
torch_tree_glove_dev_predictions = torch_tree_nn_glove.predict(X_tree_dev)

# In[46]:
Exemple #10
0
def test_simple_example_params(dataset, param, expected):
    """Smoke-test fitting and predicting with one constructor setting.

    ``TorchTreeNN`` is built with ``{param: expected}`` as its only extra
    keyword argument, fitted on the training split and used to predict the
    test split. ``param``/``expected`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator not visible here.
    """
    X_train, X_test, y_train, y_test, vocab = dataset
    model = TorchTreeNN(vocab, **{param: expected})
    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    # NOTE(review): this only checks that accuracy is nonzero (truthy);
    # it is not a quality threshold.
    assert accuracy_score(y_test, preds)