import numpy as np
import bayes_factor


if __name__=='__main__':

    iterations = 1e4

    N = 15
    m = 108
    errors = np.array([43, 59, 51, 38, 39, 53, 47, 50, 50, 59, 59, 45, 36, 46, 53])
    # errors = np.linspace(0,60,N).astype(np.int)
    print "errors =", errors

    B_10, p_errors_given_mu_sigma, p_errors_given_sigma, sigma0, mu1, sigma1 = bayes_factor.bayes_factor(errors, m, iterations)

    sigma1_upper = np.sqrt(mu1 * (1.0 - mu1))
    p_mu_sigma_given_errors = p_errors_given_mu_sigma *  1.0/sigma1_upper * 2.0
    integral_mc1 =  p_errors_given_mu_sigma.mean()
    p_mu_sigma_given_errors /= integral_mc1

    p_sigma_given_errors = p_errors_given_sigma * 2.0
    integral_mc0 = p_errors_given_sigma.mean()
    p_sigma_given_errors /= integral_mc0

    # Plotting dependencies are imported lazily, only when the script runs.
    # NOTE(review): mlab is not referenced in the code visible here; many
    # mlab helpers (e.g. normpdf) were removed in Matplotlib 3.1 -- check any
    # later use.
    import matplotlib.mlab as mlab
    import matplotlib.pyplot as plt

    # Evenly spaced grid of 500 sigma values covering [0, 0.5], presumably
    # the x-axis for plotting the sigma posteriors.
    size = 500
    sigma_i = np.linspace(0.0, 0.5, num=size)
    # --- Simulation sweep -----------------------------------------------
    # For every N in N_range, every m in m_range and M repetitions, draw
    # synthetic error counts and store the Bayes factor B_10 returned by
    # bayes_factor.bayes_factor.
    #
    # NOTE(review): m_range, N_range, mu_true, sigma_true, M, a and b are not
    # defined anywhere earlier in this script, so execution stops with a
    # NameError at the first print below. This section looks like it was
    # merged in from a separate simulation script; define these parameters
    # (a, b are presumably Beta parameters derived from mu_true/sigma_true)
    # before this point. TODO confirm the intended values.
    print "m_range =", m_range
    print "N_range =", N_range
    print "mu =", mu_true
    print "sigma =", sigma_true
    print "M =", M

    # BF[i, j, k] holds B_10 for repetition i, m_range[j] and N_range[k].
    BF = np.zeros((M, len(m_range), len(N_range)))
    # NOTE(review): these loops rebind N, m, errors and B_10, overwriting the
    # values from the fixed-dataset analysis at the top of the script; any
    # later code that expects those values sees the last simulated ones.
    for k, N in enumerate(N_range):
        for j, m in enumerate(m_range):
            for i in range(M):
                # N per-run error rates drawn from Beta(a, b), then one
                # Binomial(m, epsilon) error count per run.
                epsilon = np.random.beta(a=a, b=b, size=N)
                errors = np.random.binomial(n=m, p=epsilon)
                print i,') N =', N, ', m =', m
                print "\t errors:", errors
                # Only B_10, p1 and p0 are used in this loop; a0, a1, b1 sit
                # in the slots the first call names sigma0, mu1, sigma1.
                B_10, p1, p0, a0, a1, b1 = bayes_factor.bayes_factor(errors, m, iterations)
                BF[i,j,k] = B_10
                # The means of the likelihood arrays are reported as the
                # marginal likelihoods of each hypothesis.
                print "\t p(errors|H0) =", p0.mean()    
                print "\t p(errors|H1) =", p1.mean()
                print "\t B_10 =", B_10
                print "\t B_01 =", 1.0/B_10

    # Bayes-factor cut-offs. 1, 3, 20 and 150 match the boundaries of the
    # Kass & Raftery (1995) evidence grades; presumably used further down to
    # classify B_10 (confirm against the plotting code).
    # matplotlib.pyplot is already imported as plt earlier in this script,
    # so it is not re-imported here.
    BF_threshold = [1, 3, 20, 150]

    if mu_true==0.5:
        description = "p(Type I error)"
        ylabel = description
    else:
        description = "p(Type II error)"