def test_latent_distance_network_sampler(N, N_samples=10000):
    """
    Generate a bunch of latent distance networks, run the sampler
    on them to see how well we mix over latent locations.

    :param N: Number of neurons in the network
    """
    true_model_type = 'latent_distance'
    if true_model_type == 'erdos_renyi':
        true_model = make_model('sparse_weighted_model', N)
    elif true_model_type == 'latent_distance':
        true_model = make_model('distance_weighted_model', N)

    distmodel = make_model('distance_weighted_model', N)
    D = distmodel['network']['graph']['N_dims']
    trials = 1
    for t in range(trials):
        # Generate a true random network
        popn_true, x_true, A_true = sample_network_from_prior(true_model)
        dist_popn, x_inf, _ = sample_network_from_prior(distmodel)

        # Seed the inference population with the true network
        x_inf['net']['graph']['A'] = A_true

        # Create a location sampler
        print "Initializing latent location sampler"
        loc_sampler = LatentLocationUpdate()
        loc_sampler.preprocess(dist_popn)

        # Run the sampler
        N_samples = 1000
        smpls = fit_latent_network_given_A(x_inf, loc_sampler, N_samples=N_samples)

        if true_model_type == 'latent_distance':
            # Evaluate the state
            L_true = x_true['net']['graph']['L'].reshape((N,D))
            L_smpls = [x['net']['graph']['L'].reshape((N,D)) for x in smpls]

            # Visualize the results
            plot_latent_distance_samples(L_true, L_smpls, A_true=A_true)

            # Plot errors in relative distance over time
            compute_diff_of_dists(L_true, L_smpls)

        # Compute marginal likelihood of erdos renyi with the same sparsity
        nnz_A = float(A_true.sum())
        N_conns = A_true.size
        # Ignore the diagonal
        nnz_A -= N
        N_conns -= N
        # Now compute density
        er_rho = nnz_A / N_conns
        true_er_marg_lkhd = nnz_A * np.log(er_rho) + (N_conns-nnz_A)*np.log(1-er_rho)
        print "True ER Marg Lkhd: ", true_er_marg_lkhd

        # DEBUG: Make sure AIS gives the same answer as what we just computed
        # er_model = make_model('sparse_weighted_model', N)
        # er_model['network']['graph']['rho'] = er_rho
        # er_popn, x_inf, _ = sample_network_from_prior(er_model)
        # # Make a dummy update for the ER model
        # er_sampler = MetropolisHastingsUpdate()
        # er_x0 = er_popn.sample()
        # er_x0['net']['graph']['A'] = A_true
        # er_marg_lkhd = ais_latent_network_given_A(er_x0,
        #                                           er_popn.network.graph,
        #                                           er_sampler
        #                                           )
        #
        # print "AIS ER Marg Lkhd: ", er_marg_lkhd



        # Approximate the marginal log likelihood of the distance mode
        dist_x0 = dist_popn.sample()
        dist_x0['net']['graph']['A'] = A_true
        dist_marg_lkhd = ais_latent_network_given_A(dist_x0,
                                                    dist_popn.network.graph,
                                                    loc_sampler
                                                    )
        print "Dist Marg Lkhd: ", dist_marg_lkhd
def fit_latent_network_to_mle():
    """ Run a test with synthetic data and MCMC inference
    """
    options, popn, data, popn_true, x_true = initialize_test_harness()

    import pdb; pdb.set_trace()
    # Load MLE parameters from command line
    mle_x = None
    if options.x0_file is not None:
        with open(options.x0_file, 'r') as f:
            print "Initializing with state from: %s" % options.x0_file
            mle_x = cPickle.load(f)

            mle_model = make_model('standard_glm', N=data['N'])
            mle_popn = Population(mle_model)
            mle_popn.set_data(data)

    # Create a location sampler
    print "Initializing latent location sampler"
    loc_sampler = LatentLocationUpdate()
    loc_sampler.preprocess(popn)

    # Convert the mle results into a weighted adjacency matrix
    x_aw = popn.sample(None)
    x_aw = convert_model(mle_popn, mle_model, mle_x, popn, popn.model, x_aw)

    # Get rid of unnecessary keys
    del x_aw['glms']

    # Fit the latent distance network to a thresholded adjacency matrix
    ws = np.sort(np.abs(x_aw['net']['weights']['W']))

    wperm = np.argsort(np.abs(x_aw['net']['weights']['W']))
    nthrsh = 20
    threshs = np.arange(ws.size, step=ws.size/nthrsh)

    res = []

    N = popn.N
    for th in threshs:
        print "Fitting network for threshold: %.3f" % th
        A = np.zeros_like(ws, dtype=np.int8)
        A[wperm[th:]] = 1
        A = A.reshape((N,N))
        # A = (np.abs(x_aw['net']['weights']['W']) >= th).astype(np.int8).reshape((N,N))

        # Make sure the diag is still all 1s
        A[np.diag_indices(N)] = 1

        x = copy.deepcopy(x_aw)
        x['net']['graph']['A'] = A
        smpls = fit_latent_network_given_A(x, loc_sampler)

        # Index the results by the overall sparsity of A
        key = (np.sum(A)-N) / (np.float(np.size(A))-N)
        res.append((key, smpls))

    # Save results
    results_file = os.path.join(options.resultsDir, 'fit_latent_network_results.pkl')
    print "Saving results to %s" % results_file
    with open(results_file, 'w') as f:
        cPickle.dump(res, f)