示例#1
0
def test_conve_bce_combo():
    """Check which model/loss pairings are accepted at construction time.

    The test expects ``ConvE(loss='bce')`` and ``TransE(loss='nll')`` to
    construct without error, and the swapped pairings
    (``TransE(loss='bce')``, ``ConvE(loss='nll')``) to raise ``ValueError``.
    """
    # Valid pairings: construction alone must not raise. The instances are
    # not needed afterwards, so they are intentionally not bound to a name.
    ConvE(loss='bce')
    TransE(loss='nll')

    # Invalid combination: 'bce' is rejected by TransE.
    with pytest.raises(ValueError):
        TransE(loss='bce')

    # Invalid combination: 'nll' is rejected by ConvE.
    with pytest.raises(ValueError):
        ConvE(loss='nll')
示例#2
0
def test_conve_evaluation_protocol():
    """Fit ConvE on WN18 and check predictions survive a save/restore round trip.

    The model is trained for one epoch on ``X['train']``, scored on the first
    five test triples, saved to ``model.tmp``, restored, and scored again.
    The two score arrays must be element-wise equal.

    The temporary model file is removed even when the round trip or the
    assertion fails, so a failing run does not leave ``model.tmp`` behind.
    """
    X = load_wn18()
    model = ConvE(batches_count=200,
                  seed=22,
                  epochs=1,
                  k=10,
                  embedding_model_params={
                      'conv_filters': 16,
                      'conv_kernel_size': 3
                  },
                  optimizer='adam',
                  optimizer_params={'lr': 0.01},
                  loss='bce',
                  loss_params={},
                  # NOTE(review): regularizer is None while regularizer_params
                  # is populated; presumably the params are ignored - confirm.
                  regularizer=None,
                  regularizer_params={
                      'p': 2,
                      'lambda': 1e-5
                  },
                  verbose=True,
                  low_memory=True)

    model.fit(X['train'])

    test_triples = X['test'][:5]
    y1 = model.predict(test_triples)

    model_path = 'model.tmp'
    try:
        save_model(model, model_path)
        # Drop the trained instance so the second prediction can only come
        # from the restored model.
        del model
        model = restore_model(model_path)

        y2 = model.predict(test_triples)

        assert np.all(y1 == y2)
    finally:
        # Guard the removal: if save_model itself failed there is no file,
        # and an unconditional remove would mask the original error.
        if os.path.exists(model_path):
            os.remove(model_path)
示例#3
0
def test_conve_fit_predict_save_restore():
    """Fit ConvE on a tiny toy graph and check save/restore keeps predictions.

    Eight string triples sharing the single relation ``'y'`` are used for
    training; two unseen triples are scored before saving the model to
    ``model.tmp`` and again after restoring it. The two score arrays must be
    element-wise equal.

    The temporary model file is removed even when the round trip or the
    assertion fails, so a failing run does not leave ``model.tmp`` behind.
    """
    X = np.array([['a', 'y', 'b'], ['b', 'y', 'a'], ['a', 'y', 'c'],
                  ['c', 'y', 'a'], ['a', 'y', 'd'], ['c', 'y', 'd'],
                  ['b', 'y', 'c'], ['f', 'y', 'e']])

    X_test = np.array([['f', 'y', 'a'], ['f', 'y', 'b']])

    model = ConvE(batches_count=1,
                  seed=22,
                  epochs=1,
                  k=10,
                  embedding_model_params={
                      'conv_filters': 16,
                      'conv_kernel_size': 3
                  },
                  optimizer='adam',
                  optimizer_params={'lr': 0.01},
                  loss='bce',
                  loss_params={},
                  # NOTE(review): regularizer is None while regularizer_params
                  # is populated; presumably the params are ignored - confirm.
                  regularizer=None,
                  regularizer_params={
                      'p': 2,
                      'lambda': 1e-5
                  },
                  verbose=True,
                  low_memory=True)

    model.fit(X)

    y1 = model.predict(X_test)
    print(y1)

    model_path = 'model.tmp'
    try:
        save_model(model, model_path)
        # Drop the trained instance so the second prediction can only come
        # from the restored model.
        del model
        model = restore_model(model_path)

        y2 = model.predict(X_test)

        assert np.all(y1 == y2)
    finally:
        # Guard the removal: if save_model itself failed there is no file,
        # and an unconditional remove would mask the original error.
        if os.path.exists(model_path):
            os.remove(model_path)