Example #1
import tempfile

from torch_tree_nn import TorchTreeNN


def test_torch_tree_nn_save_load(dataset):
    # `dataset` is a pytest fixture supplying a small train/test split.
    X_train, X_test, y_train, y_test, vocab = dataset
    mod = TorchTreeNN(vocab, embed_dim=50, max_iter=100, embedding=None)
    mod.fit(X_train, y_train)
    mod.predict(X_test)
    with tempfile.NamedTemporaryFile(mode='wb') as f:
        name = f.name
        # Round-trip through pickle and check that the restored model
        # can still predict and be refit.
        mod.to_pickle(name)
        mod2 = TorchTreeNN.from_pickle(name)
        mod2.predict(X_test)
        mod2.fit(X_test, y_test)
# In[43]:

torch_tree_nn_glove = TorchTreeNN(sst_glove_vocab,
                                  embedding=glove_embedding,
                                  embed_dim=50,
                                  max_iter=10,
                                  eta=0.05)

# In[44]:

get_ipython().run_line_magic('time',
                             '_ = torch_tree_nn_glove.fit(X_tree_train)')

# As with `TreeNN` above, you have the option of specifying the labels separately, and this is required if you are cross-validating the model using scikit-learn methods.
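#
# Concretely, here is a sketch of both call patterns (not a cell from the notebook): `y_tree_train` is assumed to hold root-level labels aligned with `X_tree_train`, and the cross-validation call is purely illustrative.

from sklearn.model_selection import cross_validate

# Same effect as `fit(X_tree_train)`, but with the labels passed explicitly:
_ = torch_tree_nn_glove.fit(X_tree_train, y_tree_train)

# The explicit labels are what let scikit-learn utilities resplit the data:
cv_results = cross_validate(
    torch_tree_nn_glove, X_tree_train, y_tree_train,
    cv=3, scoring='accuracy')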

# In[45]:

torch_tree_glove_dev_predictions = torch_tree_nn_glove.predict(X_tree_dev)

# In[46]:

print(classification_report(y_tree_dev, torch_tree_glove_dev_predictions))

# ### Subtree supervision
#
# We've so far ignored one of the most exciting aspects of the SST: it has sentiment labels on every constituent from the root down to the lexical nodes.
#
# It is fairly easy to extend `TorchTreeNN` to learn from these additional labels. The key change is that the recursive interpretation function has to gather all of the node representations and their true labels and pass these to the loss function:
#
# <img src="fig/tree_nn_subtrees.png" width=600 />
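#
# To make the change concrete, here is a minimal self-contained sketch of subtree supervision. It is deliberately simplified and is not the `TorchTreeNN` implementation: the tuple encoding of trees, the class name, and the layer names are all illustrative assumptions. The essential move is that `interpret` records a (representation, label) pair for every constituent it visits, so the loss can be computed over all nodes rather than only the root.

import torch
import torch.nn as nn


class SubtreeSupervisionSketch(nn.Module):
    # Trees are nested tuples: (label, word) for leaves and
    # (label, left_subtree, right_subtree) for internal nodes.
    def __init__(self, vocab, embed_dim, n_classes):
        super().__init__()
        self.index = {w: i for i, w in enumerate(vocab)}
        self.embedding = nn.Embedding(len(vocab), embed_dim)
        self.combine = nn.Linear(2 * embed_dim, embed_dim)
        self.classify = nn.Linear(embed_dim, n_classes)

    def interpret(self, tree, reps, labels):
        # Return the vector for `tree`, appending every constituent's
        # representation and gold label to `reps` / `labels` as we go.
        label, *children = tree
        if len(children) == 1:   # leaf: (label, word)
            rep = self.embedding(torch.tensor(self.index[children[0]]))
        else:                    # internal node: (label, left, right)
            left = self.interpret(children[0], reps, labels)
            right = self.interpret(children[1], reps, labels)
            rep = torch.tanh(self.combine(torch.cat((left, right))))
        reps.append(rep)
        labels.append(label)
        return rep

    def forward(self, tree):
        reps, labels = [], []
        self.interpret(tree, reps, labels)
        # Score every collected node, not just the root:
        return self.classify(torch.stack(reps)), torch.tensor(labels)


# Usage: the loss now sums over all constituents of the tree.
model = SubtreeSupervisionSketch(
    ['not', 'very', 'good'], embed_dim=50, n_classes=3)
logits, y = model((1, (0, 'not'), (2, (2, 'very'), (2, 'good'))))
loss = nn.CrossEntropyLoss()(logits, y)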
Example #3
from sklearn.metrics import accuracy_score


def test_simple_example_params(dataset, param, expected):
    # Parametrized pytest test: `param` names a keyword argument and
    # `expected` is the value it should take at initialization.
    X_train, X_test, y_train, y_test, vocab = dataset
    model = TorchTreeNN(vocab, **{param: expected})
    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    # Non-zero accuracy is enough for this smoke test.
    assert accuracy_score(y_test, preds)