Exemplo n.º 1
0
def test_search():
    """Smoke test: a one-iteration line search over several K runs without raising."""
    n_rows, n_cols = 10, 9
    # Matrix of 2s with a single 1 in the top-left corner; every entry observed.
    data = numpy.full((n_rows, n_cols), 2.0)
    data[0, 0] = 1
    mask = numpy.ones((n_rows, n_cols))
    search_priors = dict(alpha=3, beta=4, lambdaU=5, lambdaV=6)

    searcher = LineSearch(classifier, [1, 2, 4, 5], data, mask, search_priors,
                          'exp', 1)
    searcher.search()
Exemplo n.º 2
0
def test_init():
    """The constructor stores its arguments and starts with empty metric histories."""
    n_rows, n_cols = 10, 9
    K_candidates = [1, 2, 4, 5]
    data = numpy.full((n_rows, n_cols), 2.0)
    mask = numpy.ones((n_rows, n_cols))
    search_priors = dict(alpha=3, beta=4, lambdaU=5, lambdaV=6)
    init_method = 'exp'
    n_iterations = 11

    search = LineSearch(classifier, K_candidates, data, mask, search_priors,
                        init_method, n_iterations)

    # The matrix dimensions are exposed as I (rows) and J (columns).
    assert (search.I, search.J) == (n_rows, n_cols)

    # Array-like arguments are compared element-wise.
    array_checks = (
        (search.values_K, K_candidates),
        (search.R, data),
        (search.M, mask),
    )
    for stored, expected in array_checks:
        assert numpy.array_equal(stored, expected)

    assert search.priors == search_priors
    assert search.iterations == n_iterations
    assert search.initUV == init_method

    # No search has run yet, so every metric's history is an empty list.
    tracked_metrics = ('BIC', 'AIC', 'loglikelihood', 'MSE', 'ELBO')
    assert search.all_performances == {name: [] for name in tracked_metrics}
Exemplo n.º 3
0
def test_best_value():
    """best_value returns the expected K per metric and rejects unknown metrics.

    The recorded performances are injected directly, so no search is run.
    """
    I, J = 10, 9
    values_K = [1, 2, 4, 5]
    R = 2 * numpy.ones((I, J))
    M = numpy.ones((I, J))
    priors = {'alpha': 3, 'beta': 4, 'lambdaU': 5, 'lambdaV': 6}
    initUV = 'exp'
    iterations = 11

    linesearch = LineSearch(classifier, values_K, R, M, priors, initUV,
                            iterations)
    # One entry per candidate in values_K (same length, same order).
    linesearch.all_performances = {
        'BIC': [10, 9, 8, 7],
        'AIC': [11, 13, 12, 14],
        'loglikelihood': [16, 15, 18, 17],
        'MSE': [16, 20, 18, 17]
    }
    # In every case the expected K sits at the position of the smallest entry
    # of that metric's list (including 'loglikelihood').
    expected_best_K = {'BIC': 5, 'AIC': 1, 'loglikelihood': 2, 'MSE': 1}
    for metric, best_K in expected_best_K.items():
        assert linesearch.best_value(metric) == best_K, metric

    # The unknown-metric path was previously only checked through all_values.
    # Check best_value itself (the method under test) as well, and keep the
    # all_values check so no coverage is lost.
    for lookup in (linesearch.best_value, linesearch.all_values):
        with pytest.raises(AssertionError) as error:
            lookup('FAIL')
        assert str(error.value) == "Unrecognised metric name: FAIL."
Exemplo n.º 4
0
# Experiment script: run a LineSearch over the candidate ranks in values_K on
# data returned by load_gdsc, then plot each metric against K.
# NOTE(review): standardised, I, J, no_folds, alpha, beta, lambdaU, lambdaV,
# values_K, initUV, iterations and restarts are not defined in this span;
# presumably they are set earlier in the script -- confirm.
classifier = bnmf_vb_optimised

# Load in data. Of the seven values returned by load_gdsc, only the second
# (X_min, used as the data matrix) and the third (M, used as the mask) are kept.
(_, X_min, M, _, _, _, _) = load_gdsc(standardised=standardised)

# Presumably splits the entries of M into no_folds test masks, with
# compute_Ms deriving the matching training masks; the first pair is kept.
# NOTE(review): M_train and M_test are never used below -- LineSearch is run
# on the full mask M. Either this is dead code or M_train was intended;
# confirm. Before deleting it, check whether compute_folds consumes random
# state that the subsequent search depends on.
folds_test = compute_folds(I, J, no_folds, M)
folds_training = compute_Ms(folds_test)
(M_train, M_test) = (folds_training[0], folds_test[0])

# Run the line search
priors = {'alpha': alpha, 'beta': beta, 'lambdaU': lambdaU, 'lambdaV': lambdaV}
line_search = LineSearch(classifier,
                         values_K,
                         X_min,
                         M,
                         priors,
                         initUV,
                         iterations,
                         restarts=restarts)
line_search.search()

# Plot the performances of all four metrics: one figure per metric, with the
# metric's recorded values plotted against the candidate values of K.
metrics = ['loglikelihood', 'BIC', 'AIC', 'MSE']
for metric in metrics:
    plt.figure()
    plt.plot(values_K, line_search.all_values(metric), label=metric)
    plt.legend(loc=3)

# Also print out all values in a dictionary
all_values = {}
for metric in metrics: