def test_search():
    """Smoke test: run a tiny line search over K and check nothing raises.

    Builds a 10 x 9 matrix ``R`` (all 2s except ``R[0, 0] = 1``) and an
    all-ones matrix ``M``. It then constructs a ``LineSearch`` over
    ``values_K`` with fixed scalar priors, ``'exp'`` initialisation and a
    single iteration, and calls ``search()``.

    No results are inspected: the test passes as long as construction and
    the search complete without raising an exception.
    """
    # Check whether we get no exceptions...
    I, J = 10, 9
    values_K = [1, 2, 4, 5]

    # R is constant except for one entry; presumably this avoids a fully
    # degenerate (all-identical) data matrix -- NOTE(review): confirm intent.
    R = 2 * numpy.ones((I, J))
    R[0, 0] = 1

    # Presumably the mask of observed entries (all ones = fully observed);
    # verify against LineSearch / the classifier.
    M = numpy.ones((I, J))

    # Single scalar priors: one value each for lambdaU / lambdaV, shared
    # across every K being tried.
    priors = {'alpha': 3, 'beta': 4, 'lambdaU': 5, 'lambdaV': 6}
    initUV = 'exp'
    # Only one iteration: this is a smoke test, not a convergence check.
    iterations = 1

    linesearch = LineSearch(classifier, values_K, R, M, priors, initUV, iterations)
    linesearch.search()
# Generate data (_, _, _, _, R) = generate_dataset(I, J, true_K, lambdaU, lambdaV, tau) M = numpy.ones((I, J)) #M = try_generate_M(I,J,fraction_unknown,attempts_M) # Run the line search. The priors lambdaU and lambdaV need to be a single value (recall K is unknown) priors = { 'alpha': alpha, 'beta': beta, 'lambdaU': lambdaU[0, 0] / 10, 'lambdaV': lambdaV[0, 0] / 10 } line_search = LineSearch(classifier, values_K, R, M, priors, initUV, iterations, restarts) line_search.search(burn_in, thinning) # Plot the performances of all three metrics - but MSE separately metrics = ['loglikelihood', 'BIC', 'AIC', 'MSE'] for metric in metrics: plt.figure() plt.plot(values_K, line_search.all_values(metric), label=metric) plt.legend(loc=3) # Also print out all values in a dictionary all_values = {} for metric in metrics: all_values[metric] = line_search.all_values(metric) print "all_values = %s" % all_values '''
# Generate data (_, _, _, _, R) = generate_dataset(I, J, true_K, lambdaU, lambdaV, tau) M = numpy.ones((I, J)) #M = try_generate_M(I,J,fraction_unknown,attempts_M) # Run the line search. The priors lambdaU and lambdaV need to be a single value (recall K is unknown) priors = { 'alpha': alpha, 'beta': beta, 'lambdaU': lambdaU[0, 0] / 10, 'lambdaV': lambdaV[0, 0] / 10 } line_search = LineSearch(classifier, values_K, R, M, priors, initUV, iterations, restarts) line_search.search() # Plot the performances of all three metrics metrics = ['loglikelihood', 'BIC', 'AIC', 'MSE', 'ELBO'] for metric in metrics: plt.figure() plt.plot(values_K, line_search.all_values(metric), label=metric) plt.legend(loc=3) # Also print out all values in a dictionary all_values = {} for metric in metrics: all_values[metric] = line_search.all_values(metric) print "all_values = %s" % all_values '''