Example 1
def summarize_bd(D, n_blocks, zero_thresh=None, lap='sym'):

    assert lap in ['sym', 'un']
    comm_summary, Pi_comm = community_summary(D, zero_thresh=zero_thresh)

    print(comm_summary)

    plt.figure(figsize=(8, 4))
    if lap == 'sym':
        evals = eigh_Lsym_bp(D)[0]
    else:
        Lun = get_unnorm_laplacian_bp(D)
        evals = eigh_wrapper(Lun)[0]

    plt.subplot(1, 2, 1)
    plt.plot(evals, marker='.')
    plt.title('all evals of L_{}'.format(lap))
    plt.subplot(1, 2, 2)
    plt.plot(evals[-n_blocks:], marker='.')
    plt.title('smallest {} evals'.format(n_blocks))
    print('evals', evals)

    # print('found {} communities of sizes {}'.format(summary['n_communities'], summary['comm_shapes']))

    plt.figure()
    sns.heatmap(Pi_comm, cmap='Blues', square=True, cbar=False, vmin=0)
    plt.xlabel('View 2 clusters')
    plt.ylabel('View 1 clusters')
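A self-contained sketch of the spectral idea used above (an illustration, not the repo's eigh_Lsym_bp): for the bipartite graph induced by a weight matrix D, the multiplicity of the zero eigenvalue of the symmetric normalized Laplacian equals the number of diagonal blocks, which is what the "smallest evals" panel is inspecting.

import numpy as np
from scipy.linalg import block_diag, eigh

# toy weight matrix with 2 blocks, shape (5, 4)
D = block_diag(np.ones((2, 2)), np.ones((3, 2)))

# bipartite adjacency: rows and columns of D are the two node sets
n_rows, n_cols = D.shape
A = np.zeros((n_rows + n_cols, n_rows + n_cols))
A[:n_rows, n_rows:] = D
A[n_rows:, :n_rows] = D.T

# symmetric normalized Laplacian: L_sym = I - Deg^{-1/2} A Deg^{-1/2}
deg = A.sum(axis=1)
inv_sqrt_deg = np.diag(1 / np.sqrt(deg))
L_sym = np.eye(len(A)) - inv_sqrt_deg @ A @ inv_sqrt_deg

evals = eigh(L_sym, eigvals_only=True)
print('n_blocks =', int((evals < 1e-10).sum()))  # n_blocks = 2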
Example 2
def plot_pi_ests_bd_mvmm(fit_models, zero_thresh, inches=10):
    n_samples_tr_seq = np.sort(list(fit_models.keys()))

    n_blocks_seq = fit_models[n_samples_tr_seq[0]]['bd_mvmm'].param_grid_

    ncols = len(n_samples_tr_seq)
    nrows = len(n_blocks_seq)

    plt.figure(figsize=(ncols * inches, nrows * inches))
    grid = plt.GridSpec(nrows=nrows, ncols=ncols, wspace=0.2, hspace=0.2)
    for c, n_samples in enumerate(n_samples_tr_seq):

        mvmm_gs = fit_models[n_samples]['bd_mvmm']

        for r, est in enumerate(mvmm_gs.estimators_):

            Pi_est_blk = est.final_.bd_weights_
            summary, Pi_comm = community_summary(Pi_est_blk,
                                                 zero_thresh=zero_thresh)
            n_blocks = est.final_.n_blocks
            n_blocks_est = summary['n_communities']

            title = 'n_samples={}, n_blocks={} (est={})\n{}'.\
                format(n_samples, n_blocks, n_blocks_est,
                       summary['comm_shapes'])

            plt.subplot(grid[r, c])
            sns.heatmap(Pi_comm.T,
                        cmap='Blues',
                        square=True,
                        cbar=False,
                        vmin=0)
            plt.title(title)
            plt.xlabel("View 1 clusters")
            plt.xlabel("View 2 clusters")
Example 3
    def plot_pi_from_idx(idx):
        estimator = sp_mvmm.estimators_[idx]

        if idx == 0:
            Pi = estimator.weights_mat_
            title = 'MVMM'
            zero_thresh = 0

        else:
            Pi = estimator.bd_weights_
            title = 'Penalty = {:1.2f}'.format(estimator.eval_pen_base)
            zero_thresh = estimator.zero_thresh

        evals, _ = eigh_Lsym_bp(Pi)

        comm_summary, Pi_comm = community_summary(Pi, zero_thresh=zero_thresh)
        n_blocks = comm_summary['n_communities']

        plt.figure(figsize=(16, 5))

        plt.subplot(1, 3, 1)
        sns.heatmap(Pi, cmap='Blues', square=True, cbar=False, vmin=0)
        plt.title(title)

        plt.subplot(1, 3, 2)
        sns.heatmap(Pi_comm, cmap='Blues', square=True, cbar=False, vmin=0)
        plt.title('n_blocks = {}'.format(n_blocks))

        plt.subplot(1, 3, 3)
        plt.plot(evals, marker='.')
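Hypothetical usage from the enclosing scope: plot_pi_from_idx closes over a fitted spectral-penalty search sp_mvmm (its estimators_[0] is the unpenalized MVMM, per the idx == 0 branch above), so one figure per estimator can be drawn with:

    for idx in range(len(sp_mvmm.estimators_)):
        plot_pi_from_idx(idx)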
Example 4
def plot_pi_ests_log_pen(fit_models, zero_thresh, nrows=5, inches=10):

    select_metric = 'bic'  # TODO: support other selection metrics in the
    # code below

    n_samples_tr_seq = np.sort(list(fit_models.keys()))

    ncols = len(n_samples_tr_seq)

    plt.figure(figsize=(ncols * inches, nrows * inches))
    grid = plt.GridSpec(nrows=nrows, ncols=ncols, wspace=0.2, hspace=0.2)
    for c, n_samples in enumerate(n_samples_tr_seq):

        gs = fit_models[n_samples]['log_pen_mvmm']

        tune_info = [{
            'tune_idx': i,
            'n_comp': est.n_components,
            'bic': gs.model_sel_scores_.iloc[i][select_metric]
        } for i, est in enumerate(gs.estimators_)]

        tune_info = pd.DataFrame(tune_info)

        tune_idxs = []
        for _, df in tune_info.groupby('n_comp'):
            idx_best = df['bic'].idxmin()
            tune_idx = int(df.loc[idx_best]['tune_idx'])
            tune_idxs.append(tune_idx)

        r = 0
        for tune_idx in tune_idxs:
            est = gs.estimators_[tune_idx].final_

            n_comp = est.n_components

            Pi_est = est.weights_mat_
            summary, Pi_comm = community_summary(Pi_est,
                                                 zero_thresh=zero_thresh)
            n_blocks_est = summary['n_communities']

            title = 'n_samples={}, n_components={} (n_blocks={})\n{}'.\
                format(n_samples, n_comp, n_blocks_est, summary['comm_shapes'])

            plt.subplot(grid[r, c])
            sns.heatmap(Pi_comm.T,
                        cmap='Blues',
                        square=True,
                        cbar=False,
                        vmin=0)
            plt.title(title)
            plt.xlabel("View 1 clusters")
            plt.xlabel("View 2 clusters")
            # stay on the last row once it is reached
            r = min(r + 1, nrows - 1)
Example 5
def get_bd_summary_for_gs(sim_stub, gs_mvmm, zero_thresh=0):
    """
    Summary statistics of the community structure of Pi for each
    estimator in a grid search.
    """

    is_spect_pen_block = isinstance(gs_mvmm, SpectralPenSearchByBlockMVMM)

    rows = []
    for tune_idx, estimator in enumerate(gs_mvmm.estimators_):

        if isinstance(estimator, TwoStage):
            estimator = deepcopy(estimator.final_)

        res = deepcopy(sim_stub)
        res['tune_idx'] = tune_idx

        if is_spect_pen_block:
            res['tune__n_blocks'] = gs_mvmm.est_n_blocks_[tune_idx]

            if isinstance(estimator, BlockDiagMVMM):
                Pi = estimator.bd_weights_
            else:
                Pi = estimator.weights_mat_

        elif isinstance(estimator, BlockDiagMVMM):
            res['tune__n_blocks'] = estimator.n_blocks
            Pi = estimator.bd_weights_

        elif isinstance(estimator, LogPenMVMM):
            res['tune__pen'] = estimator.pen
            Pi = estimator.weights_mat_

        else:
            raise TypeError('Unexpected estimator type: '
                            '{}'.format(type(estimator).__name__))

        summary, _ = community_summary(Pi, zero_thresh=zero_thresh)
        res.update(summary)

        rows.append(res)

    # DataFrame.append was removed in pandas 2.0; build from a list of rows
    return pd.DataFrame(rows)
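A hypothetical call, assuming gs is one of the fitted grid searches from this repo (e.g. fit_models[n_samples]['bd_mvmm']) and sim_stub carries bookkeeping values to attach to every row:

sim_stub = {'mc_index': 0, 'n_samples': 500}  # hypothetical bookkeeping
bd_summary = get_bd_summary_for_gs(sim_stub, gs_mvmm=gs, zero_thresh=1e-3)
print(bd_summary[['tune_idx', 'n_communities']])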
Example 6
def plot_pi_ests_sp_mvmm(fit_models, zero_thresh, inches=10, nrows=5):
    n_samples_tr_seq = np.sort(list(fit_models.keys()))
    ncols = len(n_samples_tr_seq)

    plt.figure(figsize=(ncols * inches, nrows * inches))
    grid = plt.GridSpec(nrows=nrows, ncols=ncols, wspace=0.2, hspace=0.2)

    for c, n_samples in enumerate(n_samples_tr_seq):

        estimators = fit_models[n_samples]['sp_mvmm'].estimators_

        for r in range(nrows):

            if r >= len(estimators):
                break

            est = estimators[r]

            if type(est) is MVMM:
                Pi = est.weights_mat_
            else:
                Pi = est.bd_weights_

            summary, Pi_comm = community_summary(Pi, zero_thresh=zero_thresh)
            n_blocks_est = summary['n_communities']

            title = 'n_samples={}, est_n_blocks={}\n{}'.\
                format(n_samples, n_blocks_est, summary['comm_shapes'])

            plt.subplot(grid[r, c])
            sns.heatmap(Pi_comm.T,
                        cmap='Blues',
                        square=True,
                        cbar=False,
                        vmin=0)
            plt.title(title)
            plt.xlabel("View 1 clusters")
            plt.xlabel("View 2 clusters")
Example 7
def get_bd_results(sim_stub, mvmm_gs, zero_thresh=0):
    """
    Summary statistics of the community structure of Pi.
    """

    rows = []
    for tune_idx, estimator in enumerate(mvmm_gs.estimators_):

        if isinstance(estimator, TwoStage):
            estimator = deepcopy(estimator.final_)

        res = deepcopy(sim_stub)
        res['tune_idx'] = tune_idx
        res['n_blocks'] = estimator.n_blocks

        Pi = estimator.bd_weights_
        summary, _ = community_summary(Pi, zero_thresh=zero_thresh)
        res.update(summary)

        rows.append(res)

    # DataFrame.append was removed in pandas 2.0; build from a list of rows
    return pd.DataFrame(rows)
Example 8
def plot_log_pen_mvmm(mvmm, inches=8, save_dir=None):

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # info = get_log_pen_mvmm_info(mvmm)
    # if save_dir is not None:
    #     # TODO: save this
    #     save_dir
    # else:
    #     print(info)

    Pi = mvmm.weights_mat_
    zero_thresh = 1e-10  # not sure if we need this

    summary, Pi_comm = community_summary(Pi, zero_thresh=zero_thresh)
    Pi_symlap_spec = eigh_Lsym_bp(Pi)[0]

    if 'init_params' in mvmm.opt_data_['history']:
        # mirrors the init_params access in plot_bd_mvmm
        Pi_init = mvmm.opt_data_['history']['init_params']['weights'].\
            reshape(Pi.shape)
        Pi_init_symlap_spec = eigh_Lsym_bp(Pi_init)[0]
    else:
        Pi_init = None

    obs_nll = mvmm.opt_data_['history']['obs_nll']
    loss_vals = mvmm.opt_data_['history']['loss_val']

    ###################
    # Initial weights #
    ###################
    if Pi_init is not None:

        plt.figure(figsize=(inches, inches))
        plot_Pi(Pi_init)
        plt.title('weights initial value')

        if save_dir is not None:
            fpath = join(save_dir, 'weights_init.png')
            save_fig(fpath)

    #####################
    # Estimated weights #
    #####################
    plt.figure(figsize=(2 * inches, inches))
    plt.subplot(1, 2, 1)
    plot_Pi(Pi)
    plt.title('weights estimate, n_blocks={}'.format(summary['n_communities']))

    plt.subplot(1, 2, 2)
    plot_Pi(Pi, mask=Pi_comm < zero_thresh)
    plt.title('weights estimate, block diagonal perm')

    if save_dir is not None:
        fpath = join(save_dir, 'weights_est.png')
        save_fig(fpath)

    ##########################
    # Spectrum of BD weights #
    ##########################
    plt.figure(figsize=(inches, inches))
    idxs = np.arange(1, len(Pi_symlap_spec) + 1)
    plt.plot(idxs, Pi_symlap_spec, marker='.', label='Estimate')
    if Pi_init is not None:
        plt.plot(idxs, Pi_init_symlap_spec, marker='.', label="Initial")
    plt.title('weights estimate spectrum')
    plt.ylim(0)
    plt.legend()

    if save_dir is not None:
        fpath = join(save_dir, 'weights_spectrum.png')
        save_fig(fpath)

    ###########################
    # Obs NLL for entire path #
    ###########################
    plt.figure(figsize=[inches, inches])
    plot_loss_history(obs_nll, loss_name="Obs NLL")

    if save_dir is not None:
        fpath = join(save_dir, 'obs_nll.png')
        save_fig(fpath)

    plt.figure(figsize=[inches, inches])
    plot_loss_history(loss_vals, loss_name="log penalized obs nll")

    if save_dir is not None:
        fpath = join(save_dir, 'loss_vals.png')
        save_fig(fpath)
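Hypothetical usage, assuming mvmm is a fitted log-penalized MVMM exposing the weights_mat_ and opt_data_ attributes accessed above; when save_dir is given, each figure is written out via save_fig:

plot_log_pen_mvmm(mvmm, inches=8, save_dir='figs/log_pen_mvmm')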
Example 9
def plot_bd_mvmm(mvmm, inches=8, save_dir=None):
    """
    Initial BD weights, Estimated BD weights, spectrums of both
    Number of steps in each adaptive stage
    Evals of entire path
    Loss history for each segment
    """
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    info = get_bd_mvmm_info(mvmm)
    if save_dir is not None:
        pass  # TODO: save info under save_dir
    else:
        print(info)

    # BD weight estimate
    bd_weights = mvmm.bd_weights_
    zero_thresh = mvmm.zero_thresh
    summary, Pi_comm = community_summary(bd_weights, zero_thresh=zero_thresh)
    bd_weights_symlap_spec = eigh_Lsym_bp(bd_weights)[0]

    # initial BD weights
    bd_weights_init = mvmm.opt_data_['adpt_opt_data']['adapt_pen_history'][
        'opt_data'][0]['history']['init_params']['bd_weights']
    bd_weights_init_symlap_spec = eigh_Lsym_bp(bd_weights_init)[0]

    # optimization history
    adpt_history = mvmm.opt_data_['adpt_opt_data']['adapt_pen_history'][
        'opt_data']

    if 'raw_eval_sum' in adpt_history[0]['history']:
        n_steps = [
            len(adpt_history[i]['history']['raw_eval_sum'])
            for i in range(len(adpt_history))
        ]
        n_steps_cumsum = np.cumsum(n_steps)

        raw_eval_sum = \
            np.concatenate([adpt_history[i]['history']['raw_eval_sum']
                           for i in range(len(adpt_history))])

    else:
        raw_eval_sum = None
        n_steps = None
        n_steps_cumsum = None

    obs_nll = np.concatenate([
        adpt_history[i]['history']['obs_nll'] for i in range(len(adpt_history))
    ])

    if mvmm.opt_data_['ft_opt_data'] is not None:
        fine_tune_obs_nll = mvmm.opt_data_['ft_opt_data']['history']['obs_nll']
    else:
        fine_tune_obs_nll = None

    ######################
    # Initial BD weights #
    ######################
    plt.figure(figsize=(inches, inches))
    plot_Pi(bd_weights_init)
    plt.title('BD weights initial value')

    if save_dir is not None:
        fpath = join(save_dir, 'BD_weights_init.png')
        save_fig(fpath)

    ########################
    # Estimated BD weights #
    ########################
    plt.figure(figsize=(2 * inches, inches))
    plt.subplot(1, 2, 1)
    plot_Pi(bd_weights)
    plt.title('BD weights estimate, n_blocks={}'.format(
        summary['n_communities']))

    plt.subplot(1, 2, 2)
    plot_Pi(bd_weights, mask=Pi_comm < zero_thresh)
    plt.title('BD weights estimate, block diagonal perm')

    if save_dir is not None:
        fpath = join(save_dir, 'BD_weights_est.png')
        save_fig(fpath)

    ##########################
    # Spectrum of BD weights #
    ##########################
    plt.figure(figsize=(inches, inches))
    idxs = np.arange(1, len(bd_weights_symlap_spec) + 1)
    plt.plot(idxs, bd_weights_symlap_spec, marker='.', label='Estimate')
    plt.plot(idxs, bd_weights_init_symlap_spec, marker='.', label="Initial")
    plt.title('BD weights estimate spectrum')
    plt.ylim(0)
    plt.legend()

    if save_dir is not None:
        fpath = join(save_dir, 'BD_weights_spectrum.png')
        save_fig(fpath)

    ##################################
    # Number of steps for each stage #
    ##################################

    if n_steps is not None:
        plt.figure(figsize=(inches, inches))
        idxs = np.arange(1, len(n_steps) + 1)
        plt.plot(idxs, n_steps, marker='.')
        plt.ylim(0)
        plt.ylabel("Number of steps")
        plt.xlabel("Adaptive stage")

        if save_dir is not None:
            fpath = join(save_dir, 'n_steps.png')
            save_fig(fpath)

    ###########################
    # Obs NLL for entire path #
    ###########################
    plt.figure(figsize=[inches, inches])
    plot_loss_history(obs_nll, loss_name="Obs NLL (entire path)")

    if save_dir is not None:
        fpath = join(save_dir, 'path_obs_nll.png')
        save_fig(fpath)

    #########################
    # Evals for entire path #
    #########################
    if raw_eval_sum is not None:
        plt.figure(figsize=[inches, inches])
        plt.plot(np.log10(raw_eval_sum), marker='.')
        plt.ylabel('log10(sum smallest evals)')
        plt.xlabel('step')
        plt.title('Eigenvalue history (entire path)')
        for s in n_steps_cumsum:
            plt.axvline(s - 1, color='grey')

        if save_dir is not None:
            fpath = join(save_dir, 'path_evals.png')
            save_fig(fpath)

    ###########################
    # Losses for each segment #
    ###########################
    if save_dir is not None:
        segment_dir = join(save_dir, 'segments')
        os.makedirs(segment_dir, exist_ok=True)

    for i in range(len(adpt_history)):
        loss_vals = adpt_history[i]['history']['loss_val']
        plot_loss_history(loss_vals,
                          'loss val, adapt segment {}'.format(i + 1))

        if save_dir is not None:
            fpath = join(segment_dir, 'loss_history_{}.png'.format(i + 1))
            save_fig(fpath)

    ##########################
    # fine tune loss history #
    ##########################
    if fine_tune_obs_nll is not None:
        plot_loss_history(fine_tune_obs_nll, 'fine tune obs NLL')

        if save_dir is not None:
            fpath = join(segment_dir, 'fine_tune_loss_history.png')
            save_fig(fpath)
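For orientation, a hypothetical skeleton of the nested mvmm.opt_data_ structure this function assumes, reconstructed from the accesses above (keys not touched above are not guaranteed to exist):

opt_data_skeleton = {
    'adpt_opt_data': {
        'adapt_pen_history': {
            'opt_data': [  # one entry per adaptive penalty stage
                {'history': {'obs_nll': [...],
                             'loss_val': [...],
                             'raw_eval_sum': [...],   # optional
                             # only the first stage holds the initialization
                             'init_params': {'bd_weights': ...}}},
            ],
        },
    },
    'ft_opt_data': None,  # or {'history': {'obs_nll': [...]}}
}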
Example 10
    'n_jobs_tune': n_jobs_tune
}

#################################
# sample data and create models #
#################################

data_dist, Pi_true, view_params = \
    get_data_dist(clust_param_config=clust_param_config,
                  grid_means=args.grid_means,
                  pi_dist=pi_config['dist'],
                  pi_config=pi_config['config'])

n_view_components = Pi_true.shape
n_comp_true = (Pi_true > 0).sum()
n_blocks_true = community_summary(Pi_true)[0]['n_communities']

n_samples_tr = 100 * n_comp_true

X_tr, Y_tr = data_dist(n_samples=n_samples_tr, random_state=sample_seed(rng))

X_tst, Y_tst = data_dist(n_samples=n_samples_tst,
                         random_state=sample_seed(rng))

res_writer.write('\n\n\n data parameters')
res_writer.write(clust_param_config['custom_view_kws'])
res_writer.write('n_samples_tr {}'.format(n_samples_tr))
res_writer.write('Pi_dist {}'.format(args.pi_name))

# update config with true values
# cat gmm, view gmms and block diag mvmm are all fit at true values
Example 11
def get_mvmm_interpret_data(model,
                            view_data,
                            super_data=None,
                            vars2compare=None,
                            survival_df=None,
                            stub=None,
                            dataset_names=None,
                            n_top_samples=5,
                            clust_size_min=0):

    n_views = len(view_data)
    n_samples = view_data[0].shape[0]
    # dims = [view_data[v].shape[1] for v in range(n_views)]

    for v in range(n_views):
        assert isinstance(view_data[v], pd.DataFrame)

    sample_names = view_data[0].index.values
    view_feat_names = [view_data[v].columns.values for v in range(n_views)]

    view_data = [view_data[v].values for v in range(n_views)]

    if vars2compare is not None:
        vars2compare = vars2compare.loc[sample_names, :]
    if survival_df is not None:
        survival_df = survival_df.loc[sample_names, :]

    if super_data is None:
        super_data = [None] * n_views

    out = {}

    n_joint_comp = model.n_components
    n_view_comp = model.n_view_components

    if dataset_names is None:
        dataset_names = ['view_{}'.format(v + 1) for v in range(n_views)]

    if stub is None:
        stub = 'joint'

    # create labels for clusters
    # joint_cluster_labels = np.array(['{}__{}'.format(stub, k + 1)
    #                                  for k in range(n_joint_comp)])

    # view_cluster_labels = [np.array(['{}__{}'.format(dataset_names[v], k + 1)
    #                                  for k in range(n_view_comp[v])])
    #                        for v in range(n_views)]
    joint_cluster_labels = (1 + np.arange(n_joint_comp)).astype(str)

    view_cluster_labels = [(1 + np.arange(n_view_comp[v])).astype(str)
                           for v in range(n_views)]
    #######################
    # cluster predictions #
    #######################

    # map joint clusters to the view marginal clusters
    view_cl_idxs = np.array(
        [model._get_view_clust_idx(k) for k in range(model.n_components)])

    view_cl_labels = view_cl_idxs.astype(str)
    for k, v in product(range(n_joint_comp), range(n_views)):
        view_cl_labels[k, v] = view_cluster_labels[v][view_cl_idxs[k, v]]
    view_cl_labels = pd.DataFrame(view_cl_labels,
                                  index=joint_cluster_labels,
                                  columns=dataset_names)

    top_col_names = ['top_' + str(k + 1) for k in range(n_top_samples)]

    ################
    # joint labels #
    ################

    # predictions
    y_pred_joint = model.predict(view_data)
    joint_summary = summarize_mm_clusters(y_pred=y_pred_joint,
                                          weights=model.weights_,
                                          cluster_labels=joint_cluster_labels)
    joint_summary = joint_summary.join(view_cl_labels)
    out['joint_summary'] = joint_summary

    # convert to cluster labels
    y_pred_joint = [joint_cluster_labels[y] for y in y_pred_joint]
    y_pred_joint = pd.Series(y_pred_joint,
                             index=sample_names,
                             name='joint_clusters')
    out['y_pred_joint'] = y_pred_joint

    # ignore small classes for below comparisons!
    y_pred_joint = drop_small_classes(y=y_pred_joint,
                                      clust_size_min=clust_size_min)

    # top samples
    cl_prob_joint = model.predict_proba(view_data)
    joint_clust_best_samples = [
        argmax(cl_prob_joint[:, k], n=n_top_samples)
        for k in range(n_joint_comp)
    ]
    joint_clust_best_samples = \
        np.array([sample_names[best_idxs]
                 for best_idxs in joint_clust_best_samples])

    joint_clust_best_samples = pd.DataFrame(joint_clust_best_samples,
                                            index=joint_cluster_labels,
                                            columns=top_col_names)
    out['joint_clust_best_samples'] = joint_clust_best_samples

    # metadata comparisons
    if vars2compare is not None:
        vars2compare = vars2compare.loc[y_pred_joint.index, :]

        compare_kws = {
            'alpha': 0.05,
            'cat_test': 'auc',
            'corr': 'pearson',
            'multi_cat': 'ovo',
            'multi_test': 'fdr_bh',
            'nan_how': 'drop'
        }

        joint_comparison = BlockBlock(**compare_kws)
        joint_comparison.fit(y_pred_joint, vars2compare).correct_multi_tests()

        out['joint_comparison'] = joint_comparison

    # survival
    if survival_df is not None:
        _survival_df = survival_df.loc[y_pred_joint.index, :]

        out['joint_survival'] = get_survival(survival_df=_survival_df,
                                             y=y_pred_joint)

    ##############
    # view level #
    ##############

    # predictions
    y_pred_view = model.predict_view_labels(view_data)

    view_summaries = {}
    for v in range(n_views):
        name = dataset_names[v]

        c = model.view_models_[v].covariances_
        w = model.view_models_[v].weights_
        y = y_pred_view[:, v]

        view_summaries[name] = \
            summarize_mm_clusters(y_pred=y, weights=w, covs=c,
                                  cluster_labels=view_cluster_labels[v])

    out['view_summaries'] = view_summaries

    # convert to cluster labels
    # y_pred_view = _y_pred_view.copy()  # .astype(str)
    # for i, v in product(range(n_samples), range(n_views)):
    #     y_pred_view[i, v] = view_cluster_labels[v][_y_pred_view[i, v]]
    y_pred_view = pd.DataFrame(
        y_pred_view + 1,  # convert to 1 indexing
        index=sample_names,
        columns=dataset_names).astype(str)

    out['y_pred_view'] = y_pred_view

    # top samples
    view_cl_prob = model.predict_view_marginal_probas(view_data)
    view_clust_best_samples = {}
    for v in range(n_views):
        name = dataset_names[v]

        vc_best_samples = [
            argmax(view_cl_prob[v][:, k], n=n_top_samples)
            for k in range(n_view_comp[v])
        ]

        vc_best_samples = np.array(
            [sample_names[best_idxs] for best_idxs in vc_best_samples])

        view_clust_best_samples[name] = \
            pd.DataFrame(vc_best_samples,
                         index=view_cluster_labels[v],
                         columns=top_col_names)

    out['view_clust_best_samples'] = view_clust_best_samples

    # metadata comparisons
    if vars2compare is not None:
        view_comparisons = {}
        for v in range(n_views):
            name = dataset_names[v]

            y = y_pred_view.iloc[:, v]
            # ignore small classes
            y = drop_small_classes(y=y, clust_size_min=clust_size_min)

            _vars2compare = vars2compare.loc[y.index, :]

            c = BlockBlock(**compare_kws)
            c.fit(y, _vars2compare).correct_multi_tests()
            view_comparisons[name] = c

        out['view_comparisons'] = view_comparisons

    # survival
    if survival_df is not None:
        view_survival = {}
        for v in range(n_views):
            name = dataset_names[v]

            y = y_pred_view.iloc[:, v]
            # ignore small classes
            y = drop_small_classes(y=y, clust_size_min=clust_size_min)

            _survival_df = survival_df.loc[y.index, :]

            view_survival[name] = get_survival(survival_df=_survival_df, y=y)

        out['view_survival'] = view_survival

    # get cluster means for each view
    out['view_cl_means'] = {}
    out['view_stand_cl_means'] = {}
    out['view_cl_super_means'] = {}
    out['view_stand_cl_super_means'] = {}
    for v in range(n_views):
        name = dataset_names[v]

        cl_means = model.view_models_[v].means_
        processor = StandardScaler(with_mean=True,
                                   with_std=True).fit(view_data[v])
        stand_cl_means = processor.transform(cl_means)

        cl_means = pd.DataFrame(cl_means,
                                index=view_cluster_labels[v],
                                columns=view_feat_names[v])

        stand_cl_means = pd.DataFrame(stand_cl_means,
                                      index=view_cluster_labels[v],
                                      columns=view_feat_names[v])

        out['view_cl_means'][name] = cl_means
        out['view_stand_cl_means'][name] = stand_cl_means

        if super_data[v] is not None:
            super_feat_names = super_data[v].columns.values
            resp = view_cl_prob[v]

            cl_super_means = get_super_means(resp=resp,
                                             super_data=super_data[v],
                                             stand=False)

            stand_cl_super_means = get_super_means(resp=resp,
                                                   super_data=super_data[v],
                                                   stand=True)

            cl_super_means = pd.DataFrame(cl_super_means,
                                          index=view_cluster_labels[v],
                                          columns=super_feat_names)

            stand_cl_super_means = pd.DataFrame(stand_cl_super_means,
                                                index=view_cluster_labels[v],
                                                columns=super_feat_names)

            out['view_cl_super_means'][name] = cl_super_means
            out['view_stand_cl_super_means'][name] = stand_cl_super_means

    ###############
    # Block level #
    ###############
    if isinstance(model, BlockDiagMVMM):
        Pi = model.bd_weights_
        zero_thresh = model.zero_thresh

    elif isinstance(model, LogPenMVMM):
        Pi = model.weights_mat_
        zero_thresh = 0

    else:
        Pi = model.weights_mat_
        zero_thresh = 0

    # get blocks of the matrix
    block_mat = get_block_mat(Pi > zero_thresh)
    block_summary, Pi_block = community_summary(Pi, zero_thresh=zero_thresh)
    n_blocks_est = block_summary['n_communities']
    # block_labels = ['block_{}'.format(b + 1) for b in range(n_blocks_est)]
    block_labels = (1 + np.arange(n_blocks_est)).astype(str)
    out['block_mat'] = block_mat

    if n_views > 2:
        raise NotImplementedError

    # TODO: check this
    # add cluster names to Pi
    row_names = view_cluster_labels[0]
    col_names = view_cluster_labels[1]
    Pi = pd.DataFrame(Pi, index=row_names, columns=col_names)
    Pi.index.name = dataset_names[0]
    Pi.columns.name = dataset_names[1]
    out['Pi'] = Pi

    # get map from blocks to view clusters
    block2view_clusts = {block_labels[b]: {} for b in range(n_blocks_est)}
    for b in range(n_blocks_est):
        row_mask = block_summary['row_memberships'] == b
        row_view_clust_labels = Pi.index.values[row_mask]
        # row_view_clust_labels = [l.split('_')[-1]  # just get the number
        #                          for l in row_view_clust_labels]

        block2view_clusts[block_labels[b]][dataset_names[0]] = \
            row_view_clust_labels

        col_mask = block_summary['col_memberships'] == b
        col_view_clust_labels = Pi.columns.values[col_mask]
        # col_view_clust_labels = [l.split('_')[-1]  # just get the number
        #                          for l in col_view_clust_labels]
        block2view_clusts[block_labels[b]][dataset_names[1]] = \
            col_view_clust_labels

    # add cluster names to block permuted Pi
    row_names_perm = row_names[np.argsort(block_summary['row_memberships'])]
    col_names_perm = col_names[np.argsort(block_summary['col_memberships'])]
    Pi_block_perm = pd.DataFrame(Pi_block,
                                 index=row_names_perm,
                                 columns=col_names_perm)
    Pi_block_perm.index.name = dataset_names[0]
    Pi_block_perm.columns.name = dataset_names[1]
    out['Pi_block_perm'] = Pi_block_perm
    out['Pi_block_perm_zero_mask'] = Pi_block_perm.values < zero_thresh

    # add block labels to summary data frame
    block_label_info = []
    for k, joint_label in enumerate(joint_cluster_labels):
        view_idxs = model._get_view_clust_idx(k)
        b = block_mat[view_idxs[0], view_idxs[1]]
        if np.isnan(b):
            block_lab = ''
        else:
            block_lab = block_labels[int(b)]
        block_label_info.append({'cluster': joint_label, 'block': block_lab})

    block_label_info = pd.DataFrame(block_label_info).set_index('cluster')
    out['joint_summary']['block'] = block_label_info['block']

    if n_blocks_est > 1:

        # block weights
        block_weights = []
        for b in range(n_blocks_est):
            mask = out['block_mat'] == b
            block_weights.append(model.weights_mat_[mask].sum())

        block_weights = pd.Series(block_weights, index=block_labels)
        out['block_weights'] = block_weights

        # get block level predictions
        y_block_pred_no_restr, n_out_of_comm = \
            get_y_comm_pred_out_comm(model, view_data, block_mat)
        y_block_pred_restr = \
            get_y_comm_pred_restrict_comm(model, view_data, block_mat)
        print('n_out_of_comm', n_out_of_comm)
        # print(cluster_report(y_block_pred_restr, y_block_pred)['ars'])
        y_pred_block = y_block_pred_restr

        out['block_summary'] = \
            summarize_mm_clusters(y_pred=y_pred_block,
                                  weights=block_weights,
                                  cluster_labels=block_labels)
        # add block2view cluster labels
        for dn in dataset_names:
            out['block_summary'][dn + '__clusters'] = ''
            for b in range(n_blocks_est):
                bl = block_labels[b]
                x = ' '.join(block2view_clusts[bl][dn])
                out['block_summary'].loc[bl, dn + '__clusters'] = x

        # format for pandas
        y_pred_block = [block_labels[y] for y in y_pred_block]
        y_pred_block = pd.Series(y_pred_block,
                                 index=sample_names,
                                 name='block')

        out['y_pred_block'] = y_pred_block

        # metadata comparisons
        if vars2compare is not None:
            _vars2compare = vars2compare.loc[y_pred_block.index, :]
            block_comparison = BlockBlock(**compare_kws)
            block_comparison.fit(y_pred_block,
                                 _vars2compare).correct_multi_tests()

            out['block_comparisons'] = block_comparison

        # survival
        if survival_df is not None:
            out['block_survival'] = get_survival(survival_df=survival_df,
                                                 y=y_pred_block)

    ##########################
    # process data to return #
    ##########################
    info = {}
    info['n_samples'] = n_samples
    info['n_views'] = n_views
    info['n_joint_comp'] = n_joint_comp
    info['n_view_comp'] = n_view_comp
    info['n_blocks'] = n_blocks_est
    info['joint_cluster_labels'] = joint_cluster_labels
    info['view_cluster_labels'] = view_cluster_labels
    info['view_cl_idxs'] = view_cl_idxs
    info['view_cl_labels'] = view_cl_labels
    info['block_labels'] = block_labels
    info['zero_thresh'] = zero_thresh
    info['block_summary'] = block_summary
    out['info'] = info

    return out
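A hypothetical two-view call, assuming model is a fitted BlockDiagMVMM and both views are DataFrames indexed by the same samples (the names below are made up):

out = get_mvmm_interpret_data(model,
                              view_data=[view_1_df, view_2_df],
                              dataset_names=['rna', 'protein'])
print(out['info']['n_blocks'])
print(out['joint_summary'].head())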
Example 12
def load_results(sim_name, select_metric='bic'):
    """
    Loads the
    Parameters
    -----------
    sim_name: str
        Name of the simulation.

    select_metric: str
        Used to pick the best model if there are ties for log_pen_mvmm_at_truth.
    """

    results = load(os.path.join(Paths().out_data_dir, sim_name,
                   'simulation_results'))

    extra_data = load(os.path.join(Paths().out_data_dir, sim_name,
                      'extra_data_mc_0'))

    # clustering results
    clust_results = results['clust_results']
    clust_results = clust_results.reset_index(drop=True)

    # TODO: drop this
    # clust_results = clust_results.\
    #     rename(columns={'est_n_communities': 'n_blocks_est',
    #                     'true_n_communities': 'n_blocks_true'})

    int_cols = ['mc_index', 'best_tuning_idx', 'n_comp_est',
                'n_comp_resid', 'n_comp_tot_est', 'n_samples', 'tune_idx']
    clust_results[int_cols] = clust_results[int_cols].astype(int)

    # need to do these below because of NaNs
    bd_int_cols = ['n_blocks_est', 'n_blocks_true']

    # block diagonal summary
    bd_summary = results['bd_summary']
    int_cols = ['n_nonzero_entries', 'n_communities', 'tune_idx',
                'n_samples', 'mc_index']
    # 'n_connected_rows', 'n_connected_cols'

    bd_summary[int_cols] = bd_summary[int_cols].astype(int)

    log_pen_bd_summary = bd_summary.\
        query("model_name == 'log_pen_mvmm'").\
        set_index(['n_samples', 'mc_index', 'tune_idx'])

    bd_mvmm_bd_summary = bd_summary.\
        query("model_name == 'bd_mvmm'").\
        set_index(['n_samples', 'mc_index', 'tune_idx'])

    sp_mvmm_bd_summary = bd_summary.\
        query("model_name == 'sp_mvmm'").\
        set_index(['n_samples', 'mc_index', 'tune_idx'])

    n_samples_tr_seq = results['sim_metadata']['n_samples_tr_seq']

    # # TODO: get this from results when implemented
    # zero_thresh = .1 / (Pi_true.shape[0] * Pi_true.shape[1])
    zero_thresh = results['metadata'][0]['zero_thresh']

    # true number of components
    n_comp_tot_true = results['metadata'][0]['n_comp_tot']
    n_comp_views_true = results['metadata'][0]['n_view_components']

    fit_models = extra_data['fit_models']

    data = {}
    data['X_tr'] = extra_data['X_tr']
    data['Y_tr'] = extra_data['Y_tr']
    data['view_params'] = extra_data['view_params']
    data['Pi_true'] = extra_data['Pi']

    Pi_true = extra_data['Pi']
    true_summary, _ = community_summary(Pi_true)
    n_blocks_true = true_summary['n_communities']

    pi_true_summary = {'n_comp_tot_true': n_comp_tot_true,
                       'n_comp_views_true': n_comp_views_true,
                       'n_blocks_true': n_blocks_true,
                       'Pi': Pi_true,
                       'summary': true_summary}

    # sim_summary = get_sim_summary(results)

    models2exclude = []

    ###########
    # Log pen #
    ###########

    log_pen_mvmm_df = clust_results.\
        query("model == 'log_pen_mvmm' & dataset == 'full' & view == 'both'")

    if log_pen_mvmm_df.shape[0] == 0:
        models2exclude.append('log_pen_mvmm')

    if 'log_pen_mvmm' not in models2exclude:

        log_pen_mvmm_df = add_model_selection_by_measures(log_pen_mvmm_df)

        log_pen_mvmm_df[bd_int_cols] = log_pen_mvmm_df[bd_int_cols].astype(int)

        log_pen_mvmm_df = log_pen_mvmm_df.set_index(['n_samples',
                                                     'mc_index', 'tune_idx'])

        log_pen_mvmm_df = pd.concat([log_pen_mvmm_df,
                                     log_pen_bd_summary], axis=1).reset_index()

        vals, param_name = extract_tuning_param_vals(log_pen_mvmm_df)
        log_pen_mvmm_df['tune__' + param_name] = vals

        # log_pen_mvmm_df_all = log_pen_mvmm_df.copy()

        # TODO: do we want this?
        # for lambd values which give the same n_est_comp,
        # keep only the best one
        # log_pen_mvmm_df = get_best_tune_expers(log_pen_mvmm_df,
        #                                        by='n_comp_est',
        #                                        measure=select_metric,
        #                                        min_good=True)

    ###########
    # BD MVMM #
    ###########

    bd_mvmm_df = clust_results.\
        query("model == 'bd_mvmm' & dataset == 'full' & view == 'both'")

    if bd_mvmm_df.shape[0] == 0:
        models2exclude.append('bd_mvmm')

    if 'bd_mvmm' not in models2exclude:

        bd_mvmm_df = add_model_selection_by_measures(bd_mvmm_df)

        bd_mvmm_df[bd_int_cols] = bd_mvmm_df[bd_int_cols].astype(int)

        bd_mvmm_df['n_blocks_req'] = \
            bd_mvmm_df.loc[:, 'tuning_param_values'].\
            apply(lambda x: x['n_blocks'])

        bd_mvmm_df['n_blocks_req'] = bd_mvmm_df['n_blocks_req'].astype(int)

        bd_mvmm_df = bd_mvmm_df.set_index(['n_samples', 'mc_index',
                                           'tune_idx'])

        bd_mvmm_df = pd.concat([bd_mvmm_df,
                                bd_mvmm_bd_summary], axis=1).reset_index()

        # TODO: this is a hack -- come up with a better solution
        # bd_mvmm_df['n_comp_est'] = bd_mvmm_df['n_nonzero_entries']

        vals, param_name = extract_tuning_param_vals(bd_mvmm_df)
        bd_mvmm_df['tune__' + param_name] = vals

    #####################
    # spectral pen MVMM #
    #####################
    sp_mvmm_df = clust_results.\
        query("model == 'sp_mvmm' & dataset == 'full' & view == 'both'")

    if sp_mvmm_df.shape[0] == 0:
        models2exclude.append('sp_mvmm')

    if 'sp_mvmm' not in models2exclude:

        sp_mvmm_df = add_model_selection_by_measures(sp_mvmm_df)

        sp_mvmm_df[bd_int_cols] = sp_mvmm_df[bd_int_cols].astype(int)

        sp_mvmm_df = sp_mvmm_df.set_index(['n_samples', 'mc_index',
                                           'tune_idx'])

        sp_mvmm_df = pd.concat([sp_mvmm_df,
                                sp_mvmm_bd_summary], axis=1).reset_index()

        # sp_mvmm_df['n_comp_est'] = sp_mvmm_df['n_nonzero_entries'].astype(int)

        vals, param_name = extract_tuning_param_vals(sp_mvmm_df)
        sp_mvmm_df['tune__' + param_name] = vals

        # sp_mvmm_df_all = sp_mvmm_df.copy()

        # TODO: I don't think we want this
        # sp_mvmm_df = get_best_tune_expers(sp_mvmm_df, by='n_blocks_est',
        #                                   measure=select_metric,
        #                                   min_good=True)

    ################
    # other models #
    ################
    full_df = clust_results.\
        query("model == 'full_mvmm' & dataset == 'full' & view == 'both'")

    full_df = add_model_selection_by_measures(full_df)

    cat_gmm_df = clust_results.\
        query("model == 'gmm_cat' & dataset == 'full' & view == 'both'")

    cat_gmm_df = add_model_selection_by_measures(cat_gmm_df)

    view_0_gmm_df = clust_results.\
        query("model == 'marginal_view_0' & dataset == 'view' & view == 0")

    view_1_gmm_df = clust_results.\
        query("model == 'marginal_view_1' & dataset == 'view' & view == 1")

    log_pen_view_0_df = clust_results.\
        query("model == 'log_pen_mvmm' & dataset == 'view' & view == 0")

    log_pen_view_1_df = clust_results.\
        query("model == 'log_pen_mvmm' & dataset == 'view' & view == 1")

    bd_mvmm_view_0_df = clust_results.\
        query("model == 'bd_mvmm' & dataset == 'view' & view == 0")

    bd_mvmm_view_1_df = clust_results.\
        query("model == 'bd_mvmm' & dataset == 'view' & view == 1")

    sp_mvmm_view_0_df = clust_results.\
        query("model == 'sp_mvmm' & dataset == 'view' & view == 0")

    sp_mvmm_view_1_df = clust_results.\
        query("model == 'sp_mvmm' & dataset == 'view' & view == 1")

    ######################################
    # results at true parameter settings #
    ######################################
    cat_gmm_df_at_truth = cat_gmm_df. \
        query("n_comp_tot_est == {}".format(n_comp_tot_true))

    log_pen_mvmm_df_at_truth = log_pen_mvmm_df.\
        query("n_comp_est == {}".format(n_comp_tot_true))

    log_pen_mvmm_df_at_truth = get_best_tune_expers(log_pen_mvmm_df_at_truth,
                                                    by='n_comp_est',
                                                    measure=select_metric)

    bd_mvmm_df_at_truth = bd_mvmm_df.\
        query("n_blocks_est == {}".format(n_blocks_true))

    sp_mvmm_df_at_truth = get_best_expers_at_truthish(sp_mvmm_df,
                                                      group_var='n_blocks_est',
                                                      true_val=n_blocks_true)

    # break ties with BIC best
    # this is not needed for sp by block
    # sp_mvmm_df_at_truth = get_best_tune_expers(sp_mvmm_df_at_truth,
    #                                            by='n_blocks_est',
    #                                            measure=select_metric)

    # TODO: subset out to best BIC of these
    # TODO: what if no one gives truth?
    # sp_mvmm_df_at_truth = sp_mvmm_df.\
    #     query("n_blocks_est == {}".format(n_blocks_true))

    model_dfs = {'cat_gmm': cat_gmm_df,
                 'full_mvmm': full_df,
                 'log_pen_mvmm': log_pen_mvmm_df,
                 'bd_mvmm': bd_mvmm_df,
                 'sp_mvmm': sp_mvmm_df,
                 'view_0_gmm': view_0_gmm_df,
                 'view_1_gmm': view_1_gmm_df,
                 'log_pen_view_0': log_pen_view_0_df,
                 'log_pen_view_1': log_pen_view_1_df,
                 'bd_mvmm_view_0': bd_mvmm_view_0_df,
                 'bd_mvmm_view_1': bd_mvmm_view_1_df,
                 'sp_mvmm_view_0_df': sp_mvmm_view_0_df,
                 'sp_mvmm_view_1_df': sp_mvmm_view_1_df}

    # TODO: get rid of this
    # model_dfs_all_tune_vals = {'log_pen_mvmm': log_pen_mvmm_df_all,
    #                            'sp_mvmm': sp_mvmm_df_all}

    model_dfs_at_truth = {'full_mvmm': full_df,
                          'cat_gmm': cat_gmm_df_at_truth,
                          'log_pen_mvmm': log_pen_mvmm_df_at_truth,
                          'bd_mvmm': bd_mvmm_df_at_truth,
                          'sp_mvmm': sp_mvmm_df_at_truth}  # TODO

    return results, model_dfs, model_dfs_at_truth, \
        fit_models, pi_true_summary, \
        n_samples_tr_seq, zero_thresh, data
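Hypothetical usage; sim_name must name a directory under Paths().out_data_dir containing the 'simulation_results' and 'extra_data_mc_0' files saved by the simulation scripts:

(results, model_dfs, model_dfs_at_truth, fit_models,
 pi_true_summary, n_samples_tr_seq, zero_thresh, data) = \
    load_results(sim_name='my_sim', select_metric='bic')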
Example 13
def run_sim_from_configs(clust_param_config,
                         grid_means,
                         pi_dist,
                         pi_config,
                         single_view_config,
                         mvmm_config,
                         n_samples_tr,
                         n_samples_tst,
                         gmm_pm,
                         n_blocks_pm,
                         reg_covar_mult=1e-2,
                         to_exclude=None,
                         data_seed=None,
                         mc_index=None,
                         args=None,
                         save_fpath=None):

    input_config = locals()

    print('Simulation starting at',
          datetime.now().strftime("%d/%m/%Y %H:%M:%S"))

    print(args)

    start_time = time()

    # cluster parameters are different for each MC iteration
    data_dist, Pi, view_params = \
        get_data_dist(clust_param_config=clust_param_config,
                      grid_means=grid_means,
                      pi_dist=pi_dist, pi_config=pi_config)

    n_comp_tot, n_view_components = get_n_comp(Pi)
    n_blocks_true = community_summary(Pi)[0]['n_communities']

    # cat and view GMMs sequences to search over
    single_view_config['cat_n_comp'] = get_n_comp_seq(n_comp_tot, gmm_pm)

    single_view_config['view_n_comp'] = \
        [get_n_comp_seq(n_view_components[v], gmm_pm) for v in range(2)]

    # n block sequence to search over
    lbd = max(1, n_blocks_true - n_blocks_pm)
    ubd = min(n_blocks_true + n_blocks_pm, min(n_view_components))
    nb_seq = np.arange(lbd, ubd + 1)

    mvmm_config['n_blocks'] = nb_seq  # 'default'

    # get models
    # models = models_from_config(n_view_components=n_view_components,
    #                             n_comp_tot=n_comp_tot,
    #                             n_blocks=n_blocks_true,
    #                             **model_config)

    models = {
        **get_single_view_models(**single_view_config),
        **get_mvmms(n_view_components=n_view_components, **mvmm_config),
        'clf': LinearDiscriminantAnalysis()
    }

    # set oracle model
    _view_params = format_view_params(view_params, covariance_type='full')

    models['oracle'] = set_mvmm_from_params(view_params=_view_params,
                                            Pi=Pi,
                                            covariance_type='full')
    zero_thresh = .01 / (n_view_components[0] * n_view_components[1])

    # log_dir = os.path.dirname(save_fpath)
    # log_dir = os.path.join(save_fpath, 'log')
    # os.makedirs(log_dir, exist_ok=True)
    log_dir = os.path.join(os.path.dirname(save_fpath), 'log')
    os.makedirs(log_dir, exist_ok=True)
    log_fname = os.path.basename(save_fpath) + '_simulation_progress.txt'
    log_fpath = os.path.join(log_dir, log_fname)

    # run simulation
    clust_results, clf_results, fit_models,\
        bd_summary, Pi_empirical, tr_data, runtimes = \
        run_sim(models=models,
                data_dist=data_dist, Pi=Pi, view_params=view_params,
                n_samples_tr=n_samples_tr, data_seed=data_seed,
                mc_index=mc_index, n_samples_tst=n_samples_tst,
                zero_thresh=zero_thresh,
                reg_covar_mult=reg_covar_mult,
                to_exclude=to_exclude,
                log_fpath=log_fpath)

    log_pen_param_grid = fit_models['log_pen_mvmm'].param_grid_
    bd_param_grid = fit_models['bd_mvmm'].param_grid_
    sp_param_grid = fit_models['sp_mvmm'].param_grid_

    metadata = {
        'n_samples_tr': n_samples_tr,
        'n_comp_tot': n_comp_tot,
        'mc_index': mc_index,
        'n_view_components': n_view_components,
        'config': input_config,
        'args': args,
        'log_pen_param_grid': log_pen_param_grid,
        'bd_param_grid': bd_param_grid,
        'sp_param_grid': sp_param_grid,
        'zero_thresh': zero_thresh,
        'tot_runtime': time() - start_time,
        'fit_runtimes': runtimes
    }

    print('Simulation finished at {} and took {} seconds'.format(
        datetime.now().strftime("%d/%m/%Y %H:%M:%S"), metadata['tot_runtime']))

    if save_fpath is not None:
        print('saving file at {}'.format(save_fpath))

        dump(
            {
                'clust_results': clust_results,
                'clf_results': clf_results,
                'bd_summary': bd_summary,
                'metadata': metadata,
                'config': input_config
            }, save_fpath)

        # save some extra data for one MC repetition
        if mc_index == 0:
            save_dir = os.path.dirname(save_fpath)
            fpath = os.path.join(
                save_dir, 'extra_data_mc_0__n_samples_{}'.format(n_samples_tr))
            dump(
                {
                    'Pi': Pi,
                    'view_params': view_params,
                    'tr_data': tr_data,
                    'fit_models': fit_models,
                    'Pi_empirical': Pi_empirical
                }, fpath)

    return {
        'clust_results': clust_results,
        'clf_results': clf_results,
        'bd_summary': bd_summary,
        'metadata': metadata,
        'fit_models': fit_models,
        'Pi': Pi,
        'view_params': view_params
    }
Example 14
sns.heatmap(
    Pi_true.T,  # transpose so the first view is on the rows
    annot=True,
    cmap='Blues',
    vmin=0,
    linewidths=.2,
    cbar=False,
    mask=Pi_true.T == 0)
plt.xlabel("First view clusters")
plt.ylabel("Second view clusters")
plt.title("True Pi matrix")
plt.xticks(
    np.arange(n_view_components[0]) + .5, np.arange(1,
                                                    n_view_components[0] + 1))
plt.yticks(
    np.arange(n_view_components[1]) + .5, np.arange(1,
                                                    n_view_components[1] + 1))
plt.savefig('obs_data_and_true_pi.png', dpi=200, bbox_inches='tight')

D_est = mvmm.final_.bd_weights_
bd_summary, D_est_bd_perm = community_summary(
    D_est, zero_thresh=mvmm.final_.zero_thresh)
# Estimated D matrix
plt.figure(figsize=(8, 8))
sns.heatmap(
    D_est_bd_perm.T,  # transpose so the first view is on the rows
    annot=True,
    cmap='Blues',
    vmin=0,
    linewidths=.2,
    cbar=False,
    mask=D_est_bd_perm.T == 0)
plt.xlabel("First view clusters")
plt.ylabel("Second view clusters")
plt.title("Estimated D matrix (permuted to reveal block diagonal structure)")
plt.xticks(
    np.arange(n_view_components[0]) + .5, np.arange(1,
                                                    n_view_components[0] + 1))
Example 15
    def _em_adaptive_pen(self, X, random_state=None):
        """
        Run the EM algorithm to convergence. Increase spectral penalty
        value until we reach the requested number of blocks.
        """

        adapt_pen_history = {'n_blocks_est': [],
                             'opt_data': []}
        # TODO-warning: if random_state is None then we won't initialize
        # parameters to the same place when alpha is decreased

        self.__mode = 'lap_pen'

        # set lap pen to base
        self._initialize_eval_pen(X=X)
        eval_pen_init = deepcopy(self.eval_pen_)

        if self.n_blocks is None or self.n_blocks == 1:
            _n_pen_tries = 1
        else:
            _n_pen_tries = self.n_pen_tries

        for t in range(_n_pen_tries):

            if self.verbosity >= 1:
                time = datetime.now().strftime("%H:%M:%S")
                print('Trying eigenvalue penalty ({}/{}) at {}'.
                      format(t + 1, self.n_pen_tries, time))

            # initial_params = deepcopy(self._get_parameters())

            # run EM loop
            adpt_params, adpt_opt_data = self._em_loop(X=X)

            # check if we have found enough blocks
            comm_summary = community_summary(adpt_params['bd_weights'],
                                             zero_thresh=self.zero_thresh)[0]

            n_blocks_est = comm_summary['n_communities']

            adpt_opt_data['n_blocks_est'] = n_blocks_est
            adpt_opt_data['eval_pen'] = deepcopy(self.eval_pen_)
            adapt_pen_history['n_blocks_est'].append(n_blocks_est)

            if self.history_tracking >= 1:
                adapt_pen_history['opt_data'].append(deepcopy(adpt_opt_data))

            # TODO: what to do about initialization
            if self.n_blocks is not None and n_blocks_est < self.n_blocks:
                # too few blocks, increase eval penalty
                self.update_eval_pen(increase=True)

            elif self.n_blocks is not None and n_blocks_est > self.n_blocks:
                # too many blocks, decrease eval penalty
                self.update_eval_pen(increase=False)

            else:
                break

        adpt_opt_data['adapt_pen_history'] = adapt_pen_history

        # check if we succeeded in getting the requested number of blocks
        if self.n_blocks is not None:
            success = n_blocks_est == self.n_blocks
        else:
            success = True

        adpt_opt_data['eval_pen_init'] = eval_pen_init
        adpt_opt_data['success'] = success
        adpt_opt_data['n_blocks_est'] = n_blocks_est

        ######################################
        # fine tune block diagonal structure #
        ######################################

        # opt history
        opt_data = {'adpt_opt_data': adpt_opt_data,
                    'success': success,
                    'n_blocks_est': n_blocks_est}

        # fine tune with fixed block diagonal structure
        if self.fine_tune_n_steps is not None and success \
                and self.n_blocks != 1:

            # re-set parameters
            max_n_steps = deepcopy(self.max_n_steps)
            self.max_n_steps = deepcopy(self.fine_tune_n_steps)
            self.__mode = 'fine_tune_bd'

            # True/False array of zero elements
            self.zero_mask_bd_ = \
                ~get_nonzero_block_mask(self.bd_weights_,
                                        tol=self.zero_thresh)[0]

            params, ft_opt_data = self._em_loop(X=X)

            # save data
            opt_data['adpt_params'] = adpt_params
            opt_data['ft_opt_data'] = ft_opt_data
            opt_data['loss_val'] = ft_opt_data['loss_val']

            # put back original parameters
            self.__mode = 'lap_pen'
            self.max_n_steps = max_n_steps

        else:
            # loss value should be negative log lik
            opt_data['loss_val'] = adpt_opt_data['history']['obs_nll'][-1]
            opt_data['adpt_params'] = None  # no need to save these again
            opt_data['ft_opt_data'] = None

            params = adpt_params

        return params, opt_data
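A self-contained toy version of the control flow above (not the estimator's actual EM): raise the penalty when too few blocks are found, lower it when too many, and stop once the requested count is reached. fit_and_count stands in for one EM run followed by community_summary.

def adapt_penalty(fit_and_count, n_blocks_target, pen=1.0,
                  factor=2.0, n_tries=10):
    n_blocks_est = None
    for _ in range(n_tries):
        n_blocks_est = fit_and_count(pen)
        if n_blocks_est < n_blocks_target:
            pen *= factor  # too few blocks: penalize more
        elif n_blocks_est > n_blocks_target:
            pen /= factor  # too many blocks: penalize less
        else:
            break
    return pen, n_blocks_est

# monotone toy response: a larger penalty yields more blocks (capped at 5)
pen, n_est = adapt_penalty(lambda p: min(int(p), 5), n_blocks_target=4)
print(pen, n_est)  # 4.0 4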
Example 16
def run_sim_from_configs(clust_param_config,
                         pi_dist,
                         pi_config,
                         n_samples_tr,
                         n_samples_tst,
                         view_gmm_config,
                         cat_gmm_config,
                         gmm_plus_minus,
                         mvmm_model,
                         full_mvmm_config,
                         base_gmm_config,
                         tune_config,
                         start_config,
                         final_config,
                         two_stage_config,
                         n_jobs=None,
                         data_seed=None,
                         mc_index=None,
                         save_fpath=None):

    input_config = locals()

    print('Simulation starting at',
          datetime.now().strftime("%d/%m/%Y %H:%M:%S"))

    print(input_config)

    start_time = time()

    # cluster parameters are different for each MC iteration
    data_dist, Pi, view_params = \
        get_data_dist(clust_param_config=clust_param_config,
                      pi_dist=pi_dist, pi_config=pi_config)

    n_comp_tot, n_view_components = get_n_comp(Pi)
    n_blocks_true = community_summary(Pi)[0]['n_communities']

    ########
    # MVMM #
    ########

    # number of view components for MVMM
    mvmm_n_view_components = tune_config['n_view_components']
    if mvmm_n_view_components == 'true':
        mvmm_n_view_components = n_view_components

    # Full MVMM
    full_mvmm = get_full_mvmm(mvmm_n_view_components,
                              gmm_config=base_gmm_config,
                              config=full_mvmm_config)

    # Two stage estimators
    if mvmm_model == 'log_pen':

        ts_gs_mvmm = \
            get_mvmm_log_pen_gs(n_view_components=mvmm_n_view_components,
                                gmm_config=base_gmm_config,
                                full_config=start_config,
                                log_pen_config=final_config,
                                two_stage_config=two_stage_config,
                                mult_values=tune_config['mult_values'],
                                n_jobs=n_jobs)

    elif mvmm_model == 'block_diag':
        pm = int(tune_config['n_blocks_pm'])

        n_blocks_tune = np.arange(max(2, n_blocks_true - pm),
                                  n_blocks_true + pm + 1)
        ts_gs_mvmm = \
            get_mvmm_block_diag_gs(n_view_components=mvmm_n_view_components,
                                   gmm_config=base_gmm_config,
                                   full_config=start_config,
                                   bd_config=final_config,
                                   two_stage_config=two_stage_config,
                                   n_blocks=n_blocks_tune,
                                   n_jobs=n_jobs)
    else:
        raise ValueError('unknown mvmm_model: {}'.format(mvmm_model))

    ###########
    # cat GMM #
    ###########
    n_components_seq = get_n_comp_seq(n_comp_tot, gmm_plus_minus)
    cat_gmm = get_gmm_gs(n_components_seq,
                         gmm_config=cat_gmm_config,
                         n_jobs=n_jobs)

    #############
    # view GMMs #
    #############
    view_gmms = []
    for v in range(len(n_view_components)):
        n_components_seq = get_n_comp_seq(n_view_components[v], gmm_plus_minus)
        view_gmms.append(
            get_gmm_gs(n_components_seq,
                       gmm_config=view_gmm_config,
                       n_jobs=n_jobs))

    # classifier
    clf = LinearDiscriminantAnalysis()

    zero_thresh = .01 / (n_view_components[0] * n_view_components[1])

    # run simulation
    clust_results, clf_results, fit_models,\
        bd_results, Pi_empirical, runtimes = \
        run_sim(full_mvmm=full_mvmm, ts_gs_mvmm=ts_gs_mvmm, cat_gmm=cat_gmm,
                view_gmms=view_gmms, clf=clf,
                data_dist=data_dist, Pi=Pi, view_params=view_params,
                n_samples_tr=n_samples_tr, data_seed=data_seed,
                mc_index=mc_index, n_samples_tst=n_samples_tst,
                zero_thresh=zero_thresh)

    mvmm_param_grid = fit_models['ts_gs_mvmm'].param_grid_

    metadata = {
        'n_samples_tr': n_samples_tr,
        'n_comp_tot': n_comp_tot,
        'n_view_components': n_view_components,
        'config': input_config,
        'mvmm_param_grid': mvmm_param_grid,
        'zero_thresh': zero_thresh,
        'tot_runtime': time() - start_time,
        'fit_runtimes': runtimes
    }

    print('Simulation finished at {} and took {} seconds'.format(
        datetime.now().strftime("%d/%m/%Y %H:%M:%S"), metadata['tot_runtime']))

    if save_fpath is not None:
        print('saving file at {}'.format(save_fpath))

        dump(
            {
                'clust_results': clust_results,
                'clf_results': clf_results,
                'bd_results': bd_results,
                'metadata': metadata,
                'config': input_config
            }, save_fpath)

        # save some extra data for one MC repetition
        if mc_index == 0:
            save_dir = os.path.dirname(save_fpath)
            fpath = os.path.join(
                save_dir, 'extra_data_mc_0__n_samples_{}'.format(n_samples_tr))
            dump(
                {
                    'Pi': Pi,
                    'fit_models': fit_models,
                    'Pi_empirical': Pi_empirical
                }, fpath)

    return {
        'clust_results': clust_results,
        'clf_results': clf_results,
        'bd_results': bd_results,
        'metadata': metadata,
        'fit_models': fit_models,
        'Pi': Pi,
        'view_params': view_params
    }
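Hypothetical sketch of consuming the returned bundle (keys taken from the return statement above; the config arguments are elided):

out = run_sim_from_configs(...)  # pass the config objects described above
print(out['metadata']['tot_runtime'])
print(out['clust_results'].head())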