def test_torch_tree_nn_incremental(X_tree):
    """Fitting with a dev set records predictions at every `dev_iter` epoch.

    With `max_iter=100` and `dev_iter=20`, predictions should be stored
    at epochs 20, 40, 60, 80, and 100, one prediction per dev example.
    """
    examples, vocab = X_tree
    params = dict(embed_dim=50, hidden_dim=50, max_iter=100, embedding=None)
    model = torch_tree_nn.TorchTreeNN(vocab, **params)
    model.fit(examples, X_dev=examples, dev_iter=20)
    recorded = model.dev_predictions
    expected_epochs = [20, 40, 60, 80, 100]
    assert list(recorded.keys()) == expected_epochs
    for preds in recorded.values():
        assert len(preds) == len(examples)
def test_torch_tree_nn_save_load(X_tree):
    """A pickled model can be restored and still predict and keep training.

    Fix: the original wrote the pickle to the name of a still-open
    `NamedTemporaryFile`. That works on POSIX but fails on Windows, where
    a temp file held open cannot be opened a second time by name. Writing
    to a plain path inside a `TemporaryDirectory` is portable, and the
    directory (and pickle) are cleaned up automatically on exit.
    """
    import os

    X, vocab = X_tree
    mod = torch_tree_nn.TorchTreeNN(vocab,
                                    embed_dim=50,
                                    hidden_dim=50,
                                    max_iter=100,
                                    embedding=None)
    mod.fit(X)
    mod.predict(X)
    with tempfile.TemporaryDirectory() as dirname:
        name = os.path.join(dirname, "model.pickle")
        mod.to_pickle(name)
        mod2 = torch_tree_nn.TorchTreeNN.from_pickle(name)
        # The restored model must support both inference and further training.
        mod2.predict(X)
        mod2.fit(X)
# Exemplo n.º 3
# 0
     {
         'hidden_dim': 10,
         'eta': 1.0,
         'max_iter': 100,
         'l2_strength': 0.01,
         'embed_dim': 100,
         'bidirectional': False
     }
 ],
 [
     np_tree_nn.TreeNN(
         vocab=[], max_iter=10, hidden_dim=5, eta=0.1),
     {'embed_dim': 5, 'hidden_dim': 10, 'eta': 1.0, 'max_iter': 100}
 ],
 [
     torch_tree_nn.TorchTreeNN(
         vocab=[], max_iter=10, hidden_dim=5, eta=0.1),
     {
         'embed_dim': 5,
         'hidden_dim': 10,
         'hidden_activation': nn.ReLU(),
         'eta': 1.0,
         'max_iter': 100,
         'l2_strength': 0.01
     }
 ],
 [
     np_shallow_neural_classifier.ShallowNeuralClassifier(
         hidden_dim=5, max_iter=1, eta=1.0),
     {
         'hidden_dim': 10,
         # Reset to ReLU: