# Example 1
    def variance_strategy(iteration_data):
        """Greedily pick interventions for this batch using variance scores.

        Bootstrap-samples DAGs from the current data via the R GIES code,
        scores each candidate intervention by the shrinkage of the target's
        parent probabilities, and greedily selects up to
        ``iteration_data.max_interventions`` interventions.

        Returns:
            Set of selected intervention indices (as plain ints).

        Raises:
            ValueError: if the samples cannot be split evenly across
                batches and interventions.
            RuntimeError: if the R code did not save ``n_boot`` DAGs.
        """
        # Each intervention in each batch must get a whole number of samples.
        samples_per_iv = iteration_data.n_samples / (
            iteration_data.n_batches * iteration_data.max_interventions)
        if samples_per_iv != int(samples_per_iv):
            raise ValueError(
                'n_samples / (n_batches * max interventions) must be an integer'
            )

        # Temporary files used to hand the current data off to the R code.
        folder = iteration_data.batch_folder
        samples_path = os.path.join(folder, 'samples.csv')
        interventions_path = os.path.join(folder, 'interventions.csv')
        dags_path = os.path.join(folder, 'TEMP_DAGS/')

        # Write the data out, bootstrap-sample DAGs with GIES, read them back.
        graph_utils._write_data(iteration_data.current_data, samples_path,
                                interventions_path)
        graph_utils.run_gies_boot(n_boot, samples_path, interventions_path,
                                  dags_path, delete=True)
        amats, dags = graph_utils._load_dags(dags_path, delete=True)
        dag_target_parents = [dag.parents[target] for dag in dags]
        if len(dags) != n_boot:
            raise RuntimeError(
                'Correct number of DAGs not saved, check R code')

        # Keep the sampled adjacency matrices around for later inspection.
        for ix, amat in enumerate(amats):
            np.save(os.path.join(folder, 'dag%d.npy' % ix), amat)

        # Shrinkage score per node, from how often it parents the target
        # across the bootstrap DAGs.
        num_dags = len(dags)
        parent_counts = dict.fromkeys(dags[0].nodes, 0)
        for parents in dag_target_parents:
            for parent in parents:
                parent_counts[parent] += 1
        parent_shrinkage_scores = {
            node: graph_utils.probability_shrinkage(count / num_dags)
            for node, count in parent_counts.items()
        }

        # Build the scoring function for candidate interventions.
        var_score_fn = create_var_score_fn(parent_shrinkage_scores, target,
                                           amats, node_vars, iv_strengths)

        # Any node except the target may be intervened on.
        num_nodes = amats[0].shape[0]
        iv_family = set(range(num_nodes)) - {target}

        # Greedily select the interventions and normalize to plain ints.
        chosen = greedy_iv(var_score_fn, iv_family,
                           iteration_data.max_interventions)
        return {int(iv) for iv in chosen}
# Example 2
def collect_dags(batch_folder, current_data, n_boot, save_dags=False):
    """Bootstrap-sample ``n_boot`` DAGs from ``current_data`` via the R GIES code.

    Args:
        batch_folder: directory used for the temporary data files (and for
            the saved adjacency matrices if ``save_dags``).
        current_data: dataset to pass to the R code.
        n_boot: number of bootstrap DAGs to sample.
        save_dags: if True, save each adjacency matrix as ``dag<i>.npy``.

    Returns:
        The list of sampled DAGs.

    Raises:
        RuntimeError: if the R code did not produce ``n_boot`` DAGs.
    """
    # Temporary files used to pass the data to the R script.
    samples_path = os.path.join(batch_folder, 'samples.csv')
    interventions_path = os.path.join(batch_folder, 'interventions.csv')
    dags_path = os.path.join(batch_folder, 'TEMP_DAGS/')

    # Dump the data, run the GIES bootstrap, then read the sampled DAGs back.
    graph_utils._write_data(current_data, samples_path, interventions_path)
    graph_utils.run_gies_boot(n_boot, samples_path, interventions_path,
                              dags_path, delete=True)
    amats, dags = graph_utils._load_dags(dags_path, delete=True)
    if len(dags) != n_boot:
        raise RuntimeError('Correct number of DAGs not saved, check R code')

    # Optionally keep the adjacency matrices for future reference.
    if save_dags:
        for ix, amat in enumerate(amats):
            np.save(os.path.join(batch_folder, 'dag%d.npy' % ix), amat)

    return dags
# Example 3
def simulate(strategy, simulator_config, gdag, strategy_folder, num_bootstrap_dags_final=100, save_gies=True):
    """Simulate an active intervention-selection strategy on a ground-truth DAG.

    Starts from observational samples of ``gdag``; then for each batch asks
    ``strategy`` how many samples to draw under each candidate intervention
    and draws them from ``gdag``. All samples (and, optionally, GIES
    bootstrap DAGs computed on the initial and final datasets) are written
    under ``strategy_folder``.

    Args:
        strategy: callable mapping an IterationData to a dict of
            {intervention index: number of samples to draw}.
        simulator_config: simulation settings (starting_samples, n_samples,
            n_batches, max_interventions, intervention_type, target, ...).
        gdag: ground-truth DAG providing .sample / .sample_interventional.
        strategy_folder: output directory for this simulation.
        num_bootstrap_dags_final: number of GIES bootstrap DAGs to sample
            before and after the simulation (when ``save_gies``).
        save_gies: if True, run the GIES bootstrap on the initial and final
            datasets and save the resulting adjacency matrices.

    Raises:
        ValueError: for an unknown intervention type, or when the strategy
            returns an allocation with the wrong total or too many
            distinct interventions, or if samples go missing.
    """
    # Already completed: the 'samples' folder is only written at the very end,
    # so its existence marks a finished simulation.
    if os.path.exists(os.path.join(strategy_folder, 'samples')):
        return

    # === SAVE SIMULATION META-INFORMATION
    os.makedirs(strategy_folder, exist_ok=True)
    simulator_config.save(strategy_folder)

    # === SAMPLE SOME OBSERVATIONAL DATA TO START WITH
    # all_samples maps intervened-node -> sample matrix; key -1 holds the
    # observational data.
    n_nodes = len(gdag.nodes)
    all_samples = {i: np.zeros([0, n_nodes]) for i in range(n_nodes)}
    all_samples[-1] = gdag.sample(simulator_config.starting_samples)
    precision_matrix = np.linalg.inv(all_samples[-1].T @ all_samples[-1] / len(all_samples[-1]))

    # === GET GIES SAMPLES GIVEN JUST OBSERVATIONAL DATA
    if save_gies:
        initial_samples_path = os.path.join(strategy_folder, 'initial_samples.csv')
        initial_interventions_path = os.path.join(strategy_folder, 'initial_interventions')
        initial_gies_dags_path = os.path.join(strategy_folder, 'initial_dags/')
        graph_utils._write_data(all_samples, initial_samples_path, initial_interventions_path)
        graph_utils.run_gies_boot(num_bootstrap_dags_final, initial_samples_path, initial_interventions_path, initial_gies_dags_path)
        amats, dags = graph_utils._load_dags(initial_gies_dags_path, delete=True)
        for d, amat in enumerate(amats):
            np.save(os.path.join(initial_gies_dags_path, 'dag%d.npy' % d), amat)

    # === SPECIFY INTERVENTIONAL DISTRIBUTIONS BASED ON EACH NODE'S STANDARD DEVIATION
    intervention_set = list(range(n_nodes))
    if simulator_config.intervention_type == 'node-variance':
        # +/- strength, scaled by each node's marginal standard deviation.
        interventions = [
            cd.BinaryIntervention(
                intervention1=cd.ConstantIntervention(val=-simulator_config.intervention_strength * std),
                intervention2=cd.ConstantIntervention(val=simulator_config.intervention_strength * std)
            ) for std in np.diag(gdag.covariance) ** .5
        ]
    elif simulator_config.intervention_type == 'constant-all':
        interventions = [
            cd.BinaryIntervention(
                intervention1=cd.ConstantIntervention(val=-simulator_config.intervention_strength),
                intervention2=cd.ConstantIntervention(val=simulator_config.intervention_strength)
            ) for _ in intervention_set
        ]
    elif simulator_config.intervention_type == 'gauss':
        interventions = [
            cd.GaussIntervention(mean=0, variance=simulator_config.intervention_strength) for _ in intervention_set
        ]
    elif simulator_config.intervention_type == 'constant':
        interventions = [
            cd.ConstantIntervention(val=0) for _ in intervention_set
        ]
    else:
        # FIX: previously raised a bare ValueError with no message.
        raise ValueError('Unknown intervention type: %s' % simulator_config.intervention_type)

    if not simulator_config.target_allowed:
        # Remove the target node from the candidate interventions. Index ==
        # node here because intervention_set starts as list(range(n_nodes)).
        del intervention_set[simulator_config.target]
        del interventions[simulator_config.target]
    print(intervention_set)

    # === RUN STRATEGY ON EACH BATCH
    for batch in range(simulator_config.n_batches):
        print('Batch %d with %s' % (batch, simulator_config))
        batch_folder = os.path.join(strategy_folder, 'dags_batch=%d/' % batch)
        os.makedirs(batch_folder, exist_ok=True)
        iteration_data = IterationData(
            current_data=all_samples,
            max_interventions=simulator_config.max_interventions,
            n_samples=simulator_config.n_samples,
            batch_num=batch,
            n_batches=simulator_config.n_batches,
            intervention_set=intervention_set,
            interventions=interventions,
            batch_folder=batch_folder,
            precision_matrix=precision_matrix
        )
        recommended_interventions = strategy(iteration_data)

        # Validate the strategy's allocation before drawing any samples.
        # FIX: idiomatic `!=` instead of `not ... ==`.
        if sum(recommended_interventions.values()) != iteration_data.n_samples / iteration_data.n_batches:
            raise ValueError('Did not return correct amount of samples')
        rec_interventions_nonzero = {intv_ix for intv_ix, ns in recommended_interventions.items() if ns != 0}
        if simulator_config.max_interventions is not None and len(rec_interventions_nonzero) > simulator_config.max_interventions:
            raise ValueError('Returned too many interventions')

        # Draw the allocated interventional samples and append them.
        for intv_ix, nsamples in recommended_interventions.items():
            iv_node = intervention_set[intv_ix]
            new_samples = gdag.sample_interventional({iv_node: interventions[intv_ix]}, nsamples)
            all_samples[iv_node] = np.vstack((all_samples[iv_node], new_samples))

    # === SAVE ALL SAMPLES (this also marks the simulation as complete)
    samples_folder = os.path.join(strategy_folder, 'samples')
    os.makedirs(samples_folder, exist_ok=True)
    for i, samples in all_samples.items():
        np.savetxt(os.path.join(samples_folder, 'intervention=%d.csv' % i), samples)

    # === CHECK THE TOTAL NUMBER OF SAMPLES IS CORRECT
    nsamples_final = sum(all_samples[iv_node].shape[0] for iv_node in intervention_set + [-1])
    if nsamples_final != simulator_config.starting_samples + simulator_config.n_samples:
        raise ValueError('Did not use all samples')

    # === GET GIES SAMPLES GIVEN THE DATA FOR THIS SIMULATION
    if save_gies:
        final_samples_path = os.path.join(strategy_folder, 'final_samples.csv')
        final_interventions_path = os.path.join(strategy_folder, 'final_interventions')
        final_gies_dags_path = os.path.join(strategy_folder, 'final_dags/')
        graph_utils._write_data(all_samples, final_samples_path, final_interventions_path)
        graph_utils.run_gies_boot(num_bootstrap_dags_final, final_samples_path, final_interventions_path, final_gies_dags_path)
        amats, dags = graph_utils._load_dags(final_gies_dags_path, delete=True)
        for d, amat in enumerate(amats):
            np.save(os.path.join(final_gies_dags_path, 'dag%d.npy' % d), amat)
# Example 4
from utils import graph_utils
import numpy as np

# Re-run the GIES bootstrap on the saved samples of one simulation and
# collect the arc sets of the sampled DAGs.
folder = 'data/chain_test10/dags/dag0/entropy-dag-collection,n=1500,b=3,k=1/'
samples_folder = folder + 'samples'
temp_folder = folder + 'temp/'

# Load the observational (-1) and interventional (node 1) samples.
obs_samples = np.loadtxt(samples_folder + '/intervention=-1.csv')
iv_samples = np.loadtxt(samples_folder + '/intervention=1.csv')
data = {-1: obs_samples, 1: iv_samples}

# Temporary files used to hand the data off to the R code.
samples_path = temp_folder + 'samples.csv'
interventions_path = temp_folder + 'interventions.csv'
dags_path = temp_folder + 'sampled_dags/'

# Write the data out, bootstrap 100 DAGs with GIES, and read them back.
graph_utils._write_data(data, samples_path, interventions_path)
graph_utils.run_gies_boot(100, samples_path, interventions_path, dags_path)
_, sampled_dags = graph_utils._load_dags(dags_path, delete=False)
a = [d.arcs for d in sampled_dags]