示例#1
0
def compare_median_graphs(configs, threshold=.5, how=None):
    """Draw the true graph next to the median graph of every sampler run.

    Panel 0 shows the true graph unpickled from
    ``data/graph_{true_graph}_{n}_{n_obs}.pkl`` (taken from the first
    configuration).  Panel ``i + 1`` shows the graph built by
    ``str_list_to_median_graph`` from configuration ``i``'s pickled sampler
    samples, titled with that configuration's varying key/value.

    Args:
        configs: configuration spec expanded by ``get_config_l``.
        threshold: forwarded to ``str_list_to_median_graph``.
        how: ``'circle'`` draws every panel with ``Graph(n).GetCirclePos()``
            as the layout; any other value uses ``Draw``'s default layout.
    """
    config_l = get_config_l(configs)
    varying_k, varying_v = _get_varying(configs)
    sampler_paths = [config_to_path(conf) for conf in config_l]
    n_panels = len(sampler_paths) + 1

    fig, axs = plt.subplots(1, n_panels, figsize=(10 * n_panels, 10))

    first = config_l[0]
    n = first['n']
    n_obs = first['n_obs']
    true_g = first['true_graph']
    circle_pos = Graph(n).GetCirclePos()
    # One shared layout (when requested) keeps node positions comparable.
    draw_kwargs = {'pos': circle_pos} if how == 'circle' else {}

    with open(f"data/graph_{true_g}_{n}_{n_obs}.pkl", 'rb') as handle:
        g = pickle.load(handle)
    g.Draw(ax=axs[0], **draw_kwargs)
    axs[0].set_title('true_graph', fontsize=20)

    for idx, conf in enumerate(config_l):
        with open(config_to_path(conf), 'rb') as handle:
            sampler = pickle.load(handle)
        adjm = str_list_to_median_graph(n, sampler.res['SAMPLES'], threshold=threshold)
        median_g = Graph(n)
        median_g.SetFromAdjM(adjm)
        panel = axs[idx + 1]
        median_g.Draw(ax=panel, **draw_kwargs)
        panel.set_title(f"{varying_k}: {varying_v[idx]}", fontsize=20)

    plt.show()
示例#2
0
def compare_traces_short(configs, log=False, burnin=0):
    """Plot posterior, size and basis-count traces from ``.short`` summaries.

    For every configuration the ``<sampler path minus 4 chars>_burnin-0.short``
    pickle (as written by ``run``) is loaded.  One column is drawn per
    configuration, with three rows:
    ``posteriors``, ``sizes`` and ``bases`` of the loaded summary.

    Args:
        configs: configuration spec expanded by ``get_config_l``.
        log: unused; kept for signature compatibility with ``compare_traces``.
        burnin: number of leading trace entries dropped before plotting.
    """
    config_l = get_config_l(configs)
    varying_k, varying_v = _get_varying(configs)
    cols = len(config_l)

    post_traces = []
    size_traces = []
    basis_traces = []
    for c in config_l:
        # Always read the burnin-0 file; ``burnin`` is applied to the traces below.
        with open(config_to_path(c)[:-4] + "_burnin-0.short", 'rb') as handle:
            sampler = pickle.load(handle)
        post_traces.append(sampler.posteriors)
        size_traces.append(sampler.sizes)
        basis_traces.append(sampler.bases)

    # squeeze=False keeps ``axs`` 2-D even for a single configuration, so the
    # ``axs[row, col]`` indexing below is always valid.
    fig, axs = plt.subplots(3, cols, figsize=(10 * cols, 10 * 3), squeeze=False)

    for i in range(cols):
        axs[0, i].plot(post_traces[i][burnin:])
        axs[1, i].plot(size_traces[i][burnin:])
        axs[2, i].plot(basis_traces[i][burnin:])
        axs[0, i].set_title(f"{varying_k}: {varying_v[i]}", fontsize=20)

    ylabs = ["MCMC posterior", "sizes", "n_basis"]
    for row, ylab in enumerate(ylabs):
        axs[row, 0].set_ylabel(ylab, rotation=90, fontsize=20)

    plt.show()
示例#3
0
def compare_traces(configs, log=False, burnin=0):
    """Overlay posterior summaries of several runs and plot their traces.

    Column 0 overlays, for every configuration:
    - row 0: ``plot_true_posterior``;
    - row 1: ``plot_true_posterior_edge_marginalized``;
    - row 2: ``plot_true_posterior_cb_marginalized``, only for runs with
      ``cob_freq is None`` and a non-edge basis.

    Column ``i + 1`` shows configuration ``i``'s traces:
    - row 0: LIK + PRIOR;
    - row 1: sample size;
    - row 2: basis count.

    Every posterior dict is built over the union of states visited by all runs.

    Args:
        configs: configuration spec expanded by ``get_config_l``.
        log: forwarded to the ``plot_true_posterior*`` helpers.
        burnin: number of leading trace entries dropped before plotting.

    Returns:
        list: the posterior dict of each configuration, in ``config_l`` order.
    """
    config_l = get_config_l(configs)
    varying_k, varying_v = _get_varying(configs)
    paths = [config_to_path(c) for c in config_l]
    cols = len(paths)

    # Shared support for all posterior dicts.
    all_visited_states = set()
    for path in paths:
        with open(path, 'rb') as handle:
            sampler = pickle.load(handle)
        all_visited_states.update(np.unique(sampler.res['SAMPLES']))

    posts = []
    post_traces = []
    size_traces = []
    basis_traces = []
    init_bases = []
    for c, path in zip(config_l, paths):
        with open(path, 'rb') as handle:
            sampler = pickle.load(handle)
        post = sampler_to_post_dict(sampler, list(all_visited_states))
        if c['basis'] != 'edge':
            post = get_post_dict_cb_only(c['n'], post)
        posts.append(post)
        post_traces.append(np.array(sampler.res['LIK']) + np.array(sampler.res['PRIOR']))
        size_traces.append([np.sum(_str_to_int_list(s)) for s in sampler.res['SAMPLES']])
        basis_traces.append(_get_basis_ct(sampler))
        init_bases.append(sampler.last_params._basis)

    # squeeze=False keeps ``axs`` 2-D even when there are no configurations.
    fig, axs = plt.subplots(3, cols + 1, figsize=(10 * (cols + 1), 10 * 3), squeeze=False)

    for i, conf in enumerate(config_l):
        label = f"{varying_k}: {varying_v[i]}"
        plot_true_posterior(posts[i], log, ax=axs[0, 0], label=label)
        plot_true_posterior_edge_marginalized(posts[i], log, ax=axs[1, 0], label=label)

        if conf['cob_freq'] is None and conf['basis'] != 'edge':
            # Load THIS configuration's sampler; the loop variable ``c`` from
            # the loops above would always point at the last configuration.
            with open(paths[i], 'rb') as handle:
                sampler = pickle.load(handle)
            plot_true_posterior_cb_marginalized(posts[i], init_bases[i], log, ax=axs[2, 0], sampler=sampler)

        axs[0, i + 1].plot(post_traces[i][burnin:])
        axs[1, i + 1].plot(size_traces[i][burnin:])
        axs[2, i + 1].plot(basis_traces[i][burnin:])

    axs[0, 0].legend()
    axs[1, 0].legend()

    for i in range(cols):
        axs[0, i + 1].set_title(f"{varying_k}: {varying_v[i]}", fontsize=20)

    ylabs = ["MCMC posterior", "sizes", "n_basis"]
    for row, ylab in enumerate(ylabs):
        axs[row, 0].set_ylabel(ylab, rotation=90, fontsize=20)

    plt.show()
    return posts
def get_accuracies(config):
    """Count TP/TN/FP/FN edges of the sampler's median graph vs the true graph.

    Loads the sampler pickled at ``config_to_path(config)`` and the true graph
    pickled at ``data/graph_{true_graph}_{n}_{n_obs}.pkl``.  The median graph
    keeps every entry of ``str_list_to_adjm(...)`` strictly greater than 0.5.
    The comparison runs over the strict upper triangle of the adjacency matrix.

    Returns:
        tuple: ``(TP, TN, FP, FN)``.
    """
    with open(config_to_path(config), 'rb') as handle:
        sampler = pickle.load(handle)
    true_path = f"data/graph_{config['true_graph']}_{config['n']}_{config['n_obs']}.pkl"
    with open(true_path, 'rb') as handle:
        g = pickle.load(handle)

    n_nodes = len(g)
    median_g = (str_list_to_adjm(n_nodes, sampler.res['SAMPLES']) > .5).astype(int)

    # NOTE(review): assumes g.GetBinaryL() lists edges in the same order as
    # np.triu_indices(n, 1), i.e. row-major upper triangle -- TODO confirm.
    truth = np.array(g.GetBinaryL(), dtype=bool)
    upper = np.triu_indices(n_nodes, 1)
    predicted = np.array(median_g[upper], dtype=bool)

    TP = (truth & predicted).astype(int).sum()
    TN = (~truth & ~predicted).astype(int).sum()
    FP = (~truth & predicted).astype(int).sum()
    FN = (truth & ~predicted).astype(int).sum()

    # Internal consistency checks on the four counts.
    assert TP + TN + FP + FN == len(truth)
    assert TP + FP == predicted.astype(int).sum()
    assert TN + FN == (~predicted).astype(int).sum()

    return TP, TN, FP, FN
def run(conf):
    """Run the sampler for ``conf`` and pickle summaries at three burn-ins.

    Loads the comma-separated data file ``data/{true_graph}_{n}_{n_obs}.dat``
    and hands it to ``run_config``.  It also unpickles the true graph.  For
    burn-ins ``0``, ``int(.1 * sampler.iter)`` and ``int(.25 * sampler.iter)``
    it writes ``sampler.get_summary(g, burnin, thin=100)``.  The target is the
    sampler path without its last 4 characters plus ``_burnin-{burnin}.short``.
    """
    n = conf['n']
    n_obs = conf['n_obs']
    name = conf['true_graph']
    sampler = run_config(np.loadtxt(f"data/{name}_{n}_{n_obs}.dat", delimiter=','), conf)

    with open(f"data/graph_{name}_{n}_{n_obs}.pkl", 'rb') as handle:
        g = pickle.load(handle)

    prefix = config_to_path(conf)[:-4]
    burnins = [0] + [int(frac * sampler.iter) for frac in (.1, .25)]
    for burnin in burnins:
        out_path = prefix + f"_burnin-{burnin}.short"
        print(f"saving to {prefix}_burnin-{burnin}.short")
        with open(out_path, 'wb') as handle:
            pickle.dump(sampler.get_summary(g, burnin, thin=100), handle)
def get_summary(config, b=0):
    """Return chain diagnostics for the sampler pickled at ``config_to_path(config)``.

    Args:
        config: run configuration.
        b: burn-in.  The first ``b`` entries are dropped from the posterior
            (LIK + PRIOR), size and basis-count traces, and from the samples
            used for ``states_visited``.

    Returns:
        dict with the following keys:
        - ``IAT_posterior``, ``IAT_sizes``, ``IAT_bases``: ``IAC_time`` of
          each burned-in trace.
        - ``accept_rate``: mean of ``ACCEPT_INDEX`` over the whole chain.
        - ``tree_accept_ct``: number of accepted steps whose proposed
          ``TREE_ID`` differs from the previous step's, over the whole chain.
        - ``max_posterior``.
        - ``states_visited``: distinct samples after burn-in.
        - ``time``: ``sampler.time``.
    """
    with open(config_to_path(config), 'rb') as handle:
        sampler = pickle.load(handle)

    def _str_to_int_list(s):
        # Digit string such as "0110" -> integer array.
        return np.array(list(s), dtype=int)

    def _get_basis_ct(smp):
        # Basis size at every step.  A rejected step repeats the previous
        # size; a rejected step 0 starts from the initial basis.
        accepted = smp.res['ACCEPT_INDEX']
        proposals = smp.res['PARAMS_PROPS']
        start_id = smp.init['BASIS_ID'] if accepted[0] == 0 else proposals[0]['BASIS_ID']
        counts = [np.sum(_str_to_int_list(start_id))]
        for step in range(1, len(accepted)):
            if not accepted[step]:
                counts.append(counts[-1])
                continue
            counts.append(np.sum(_str_to_int_list(proposals[step]['BASIS_ID'])))
        return counts

    res = sampler.res
    log_post = np.array(res['LIK'], dtype=float)[b:] + np.array(res['PRIOR'], dtype=float)[b:]
    sizes = [np.sum(_str_to_int_list(s)) for s in res['SAMPLES']][b:]
    n_bases = _get_basis_ct(sampler)[b:]
    trees = [pp['TREE_ID'] for pp in res['PARAMS_PROPS']]
    # Step indices i at which the proposed tree differs from step i - 1.
    change_tree = np.where([prev != cur for prev, cur in zip(trees[:-1], trees[1:])])[0] + 1

    return {
        'IAT_posterior': IAC_time(log_post),
        'IAT_sizes': IAC_time(sizes),
        'IAT_bases': IAC_time(n_bases),
        'accept_rate': np.sum(res['ACCEPT_INDEX']) / len(res['ACCEPT_INDEX']),
        'tree_accept_ct': len(set(change_tree) & set(np.where(res['ACCEPT_INDEX'])[0])),
        'max_posterior': np.max(log_post),
        'states_visited': len(np.unique(res['SAMPLES'][b:])),
        'time': sampler.time,
    }
示例#7
0
def plot_end(configs,
             basis_list=('edge', 'hub', 'uniform'),
             burnin=0,
             thin=1,
             plot=False,
             temper=None):
    """Histogram end-of-run distance distributions for several bases.

    Draws one row per entry of ``configs['true_graph']`` and three columns
    (jaccard, hamming, sizes).  For each basis in ``basis_list`` the
    ``*_distances_end`` values from ``get_summary`` are overlaid as density
    histograms, with bins shared across bases within a row.  The figure is
    saved to ``as_end_distr_n-{n}_n_obs-{n_obs}.pdf``.

    Note: ``configs['basis']`` is overwritten with each basis in turn and is
    left set to the last one.

    Args:
        configs: configuration spec expanded by ``get_config_l``.
        basis_list: bases to compare (any iterable of basis names).
        burnin, thin, temper: forwarded to ``get_summary``.
        plot: if true, call ``plt.show()`` after saving.

    Returns:
        The matplotlib figure.
    """
    n, n_obs = configs['n'], configs['n_obs']
    n_graphs = len(configs['true_graph'])
    # squeeze=False keeps ``axs`` 2-D even with a single true graph, so the
    # ``axs[i, j]`` indexing below is always valid.
    fig, axs = plt.subplots(n_graphs,
                            3,
                            figsize=(3 * 10, n_graphs * 10),
                            squeeze=False)
    plt.rc('xtick', labelsize=30)
    plt.rc('ytick', labelsize=30)

    # Setting (shared) x and y labels
    names = [BETTER_NAMES[s] for s in configs['true_graph']]
    for i, name in enumerate(names):
        axs[i, 0].set_ylabel(name, size=50)

    axs[0, 0].set_title('jaccard', size=50)
    axs[0, 1].set_title('hamming', size=50)
    axs[0, 2].set_title('sizes', size=50)

    keys = ('jaccard_distances_end', 'hamming_distances_end', 'size_distances_end')

    # Load every summary once; it is reused for the range and plotting passes.
    loaded = []
    for basis in basis_list:
        configs['basis'] = basis
        config_l = get_config_l(configs)
        summaries = [get_summary(c, burnin, thin, temper) for c in config_l]
        loaded.append((basis, config_l, summaries))

    # Getting Ranges (empty distributions count as [0] here only; the
    # summaries themselves are not modified, so nothing spurious is plotted).
    jacc_max = [.0] * n_graphs
    hamm_max = [.0] * n_graphs
    size_max = [.0] * n_graphs
    for _, config_l, summaries in loaded:
        for i, summary in enumerate(summaries):
            maxima = []
            for key in keys:
                values = summary[key]
                if len(values) == 0:
                    print(config_to_path(config_l[i]))
                    values = [0]
                maxima.append(np.max(values))
            jacc_max[i] = max(jacc_max[i], maxima[0] * 100)
            hamm_max[i] = max(hamm_max[i], maxima[1])
            size_max[i] = max(size_max[i], maxima[2])

    # Plotting
    for basis, _, summaries in loaded:
        label = BETTER_NAMES[basis]
        for i, summary in enumerate(summaries):
            row_bins = (np.arange(jacc_max[i] + 1) / 100,
                        np.arange(hamm_max[i] + 1),
                        np.arange(size_max[i] + 1))
            for j, (key, bins) in enumerate(zip(keys, row_bins)):
                axs[i, j].hist(summary[key],
                               bins=bins,
                               label=label,
                               alpha=.5,
                               density=True)
            for j in range(3):
                axs[i, j].legend(fontsize=30)

    fig.savefig(f"as_end_distr_n-{n}_n_obs-{n_obs}.pdf")

    if plot:
        plt.show()

    return fig
示例#8
0
def compare_with_true_posterior(config, burnin=0, log=False):
    """Compare one MCMC run's posterior with the "MC" and "LA" posteriors.

    The posteriors come from three sources:
    - "MC": ``MC_to_post_dict(n, name)``;
    - "LA": unpickled from ``results/true_posterior_{name}_{n}_{n_obs}.pkl``;
    - "MCMC": built from the sampler at ``config_to_path(config)`` over the
      MC keys.

    For a non-edge basis all three are reduced with ``get_post_dict_cb_only``.

    Layout:
    - Left column: the three posteriors overlaid (raw and edge-marginalized,
      plus cb-marginalized when ``cob_freq is None`` and the basis is not
      ``'edge'``).
    - Right column: the sampler's traces (LIK + PRIOR, sizes, and basis
      count in the cb case), each with the first ``burnin`` entries dropped.

    Returns:
        tuple: ``(MC_post, LA_post, MCMC_post)`` after any cb-only reduction.
    """
    n = config['n']
    n_obs = config['n_obs']
    name = config['true_graph']

    MC_post = MC_to_post_dict(n, name)
    with open(f"results/true_posterior_{name}_{n}_{n_obs}.pkl", 'rb') as handle:
        LA_post = pickle.load(handle)
    with open(config_to_path(config), 'rb') as handle:
        sampler = pickle.load(handle)
    MCMC_post = sampler_to_post_dict(sampler, MC_post.keys())

    if config['basis'] != 'edge':
        MC_post, LA_post, MCMC_post = (get_post_dict_cb_only(n, post)
                                       for post in (MC_post, LA_post, MCMC_post))

    # A fixed, non-edge basis gets an extra (third) row of plots.
    fixed_basis = config['cob_freq'] is None and config['basis'] != 'edge'
    n_rows = 3 if fixed_basis else 2
    fig, axs = plt.subplots(n_rows, 2, figsize=(10 * 2, 10 * n_rows))

    labelled = (("MC", MC_post), ("LA", LA_post), ("MCMC", MCMC_post))

    for label, post in labelled:
        plot_true_posterior(post, log, ax=axs[0, 0], label=label)
    axs[0, 0].legend()

    for label, post in labelled:
        plot_true_posterior_edge_marginalized(post, log, ax=axs[1, 0], label=label)
    axs[1, 0].legend()

    if fixed_basis:
        basis = sampler.last_params._basis
        for label, post in labelled:
            plot_true_posterior_cb_marginalized(post, basis, log, ax=axs[2, 0], label=label)
        axs[2, 0].legend()

    posterior = np.array(sampler.res['LIK']) + np.array(sampler.res['PRIOR'])
    sizes = [np.sum(_str_to_int_list(s)) for s in sampler.res['SAMPLES']]
    n_bases = _get_basis_ct(sampler)

    axs[0, 1].plot(posterior[burnin:])
    axs[1, 1].plot(sizes[burnin:])
    if fixed_basis:
        axs[2, 1].plot(n_bases[burnin:])

    axs[0, 0].set_title("compare_w_true", fontsize=20)
    axs[0, 1].set_title("traces", fontsize=20)

    for row, ylab in enumerate(["graph_index", "edges"]):
        axs[row, 0].set_ylabel(ylab, rotation=90, fontsize=20)

    for row, ylab in enumerate(["log posterior", "sizes"]):
        axs[row, 1].set_ylabel(ylab, rotation=90, fontsize=20)

    if fixed_basis:
        axs[2, 0].set_ylabel("basis", rotation=90, fontsize=20)
        axs[2, 1].set_ylabel("n_basis", rotation=90, fontsize=20)

    plt.show()
    return MC_post, LA_post, MCMC_post