def variance_strategy(iteration_data):
    """Greedily pick interventions scored by variance over bootstrapped DAGs.

    Returns a set of intervention indices (ints), at most
    ``iteration_data.max_interventions`` of them.
    """
    # Samples allotted to each intervention in this batch must divide evenly.
    samples_per_iv = iteration_data.n_samples / (iteration_data.n_batches * iteration_data.max_interventions)
    if int(samples_per_iv) != samples_per_iv:
        raise ValueError(
            'n_samples / (n_batches * max interventions) must be an integer'
        )

    # Temporary files used to hand the current data to the R GIES sampler.
    samples_path = os.path.join(iteration_data.batch_folder, 'samples.csv')
    interventions_path = os.path.join(iteration_data.batch_folder, 'interventions.csv')
    dags_path = os.path.join(iteration_data.batch_folder, 'TEMP_DAGS/')

    # Dump the data, then bootstrap-sample DAGs via the R code.
    graph_utils._write_data(iteration_data.current_data, samples_path, interventions_path)
    graph_utils.run_gies_boot(n_boot, samples_path, interventions_path, dags_path, delete=True)
    amats, dags = graph_utils._load_dags(dags_path, delete=True)
    dag_target_parents = [dag.parents[target] for dag in dags]
    if len(dags) != n_boot:
        raise RuntimeError(
            'Correct number of DAGs not saved, check R code')

    # Keep the sampled adjacency matrices around for later inspection.
    for ix, amat in enumerate(amats):
        np.save(os.path.join(iteration_data.batch_folder, 'dag%d.npy' % ix), amat)

    # Shrinkage score per node, from its empirical probability of being a
    # parent of the target across the bootstrapped DAGs.
    parent_counts = {node: 0 for node in dags[0].nodes}
    for parents_of_target in dag_target_parents:
        for parent in parents_of_target:
            parent_counts[parent] += 1
    n_dags = len(dags)
    parent_probs = {node: count / n_dags for node, count in parent_counts.items()}
    parent_shrinkage_scores = {
        node: graph_utils.probability_shrinkage(prob)
        for node, prob in parent_probs.items()
    }

    # Score function over interventions, given these parent shrinkage scores.
    var_score_fn = create_var_score_fn(parent_shrinkage_scores, target, amats, node_vars, iv_strengths)

    # Candidate interventions: every node except the target itself.
    n_nodes = amats[0].shape[0]
    iv_family = {iv for iv in range(n_nodes) if iv != target}

    # Greedy selection, capped at the allowed number of interventions.
    chosen = greedy_iv(var_score_fn, iv_family, iteration_data.max_interventions)
    return {int(iv) for iv in chosen}
def collect_dags(batch_folder, current_data, n_boot, save_dags=False):
    """Bootstrap-sample ``n_boot`` DAGs from ``current_data`` via the R GIES code.

    Writes the data into temporary CSVs under ``batch_folder``, runs the R
    sampler, and returns the loaded DAGs. When ``save_dags`` is true, the
    adjacency matrices are also saved as ``dag<i>.npy`` files for reference.
    """
    # Temporary files that shuttle the current data to the R sampler.
    data_csv = os.path.join(batch_folder, 'samples.csv')
    ivs_csv = os.path.join(batch_folder, 'interventions.csv')
    tmp_dag_dir = os.path.join(batch_folder, 'TEMP_DAGS/')

    # Write data out, run the bootstrap, and read the resulting DAGs back in.
    graph_utils._write_data(current_data, data_csv, ivs_csv)
    graph_utils.run_gies_boot(n_boot, data_csv, ivs_csv, tmp_dag_dir, delete=True)
    amats, dags = graph_utils._load_dags(tmp_dag_dir, delete=True)

    if len(dags) != n_boot:
        raise RuntimeError('Correct number of DAGs not saved, check R code')

    if save_dags:
        for ix, amat in enumerate(amats):
            np.save(os.path.join(batch_folder, 'dag%d.npy' % ix), amat)

    return dags
def simulate(strategy, simulator_config, gdag, strategy_folder, num_bootstrap_dags_final=100, save_gies=True):
    """Run an active-learning strategy against ``gdag`` and save all samples.

    Starts from observational samples, then for each batch asks ``strategy``
    which interventions to perform and how many samples to draw from each,
    accumulating everything in ``all_samples`` (keyed by intervened node,
    with -1 for observational data). Results are written under
    ``strategy_folder``; if a ``samples`` subfolder already exists there, the
    simulation is assumed done and the function returns immediately.

    Args:
        strategy: callable taking an IterationData and returning a dict
            mapping intervention index -> number of samples.
        simulator_config: configuration object (n_samples, n_batches,
            intervention_type, intervention_strength, target, etc.).
        gdag: the ground-truth causal DAG, used to sample data.
        strategy_folder: output folder for this simulation run.
        num_bootstrap_dags_final: number of GIES bootstrap DAGs to sample
            before and after the simulation (when ``save_gies``).
        save_gies: whether to run/save the GIES bootstrap DAGs.

    Raises:
        ValueError: on an unknown intervention type, or when the strategy
            returns the wrong number of samples or too many interventions.
    """
    # Already simulated: the samples folder is only written at the very end.
    if os.path.exists(os.path.join(strategy_folder, 'samples')):
        return

    # === SAVE SIMULATION META-INFORMATION
    os.makedirs(strategy_folder, exist_ok=True)
    simulator_config.save(strategy_folder)

    # === SAMPLE SOME OBSERVATIONAL DATA TO START WITH
    n_nodes = len(gdag.nodes)
    all_samples = {i: np.zeros([0, n_nodes]) for i in range(n_nodes)}
    all_samples[-1] = gdag.sample(simulator_config.starting_samples)
    # Empirical precision matrix of the observational data.
    precision_matrix = np.linalg.inv(all_samples[-1].T @ all_samples[-1] / len(all_samples[-1]))

    # === GET GIES SAMPLES GIVEN JUST OBSERVATIONAL DATA
    if save_gies:
        initial_samples_path = os.path.join(strategy_folder, 'initial_samples.csv')
        initial_interventions_path = os.path.join(strategy_folder, 'initial_interventions')
        initial_gies_dags_path = os.path.join(strategy_folder, 'initial_dags/')
        graph_utils._write_data(all_samples, initial_samples_path, initial_interventions_path)
        graph_utils.run_gies_boot(num_bootstrap_dags_final, initial_samples_path, initial_interventions_path, initial_gies_dags_path)
        amats, dags = graph_utils._load_dags(initial_gies_dags_path, delete=True)
        for d, amat in enumerate(amats):
            np.save(os.path.join(initial_gies_dags_path, 'dag%d.npy' % d), amat)

    # === SPECIFY INTERVENTIONAL DISTRIBUTIONS BASED ON EACH NODE'S STANDARD DEVIATION
    intervention_set = list(range(n_nodes))
    if simulator_config.intervention_type == 'node-variance':
        # Scale each node's intervention by its marginal standard deviation.
        interventions = [
            cd.BinaryIntervention(
                intervention1=cd.ConstantIntervention(val=-simulator_config.intervention_strength * std),
                intervention2=cd.ConstantIntervention(val=simulator_config.intervention_strength * std)
            ) for std in np.diag(gdag.covariance) ** .5
        ]
    elif simulator_config.intervention_type == 'constant-all':
        interventions = [
            cd.BinaryIntervention(
                intervention1=cd.ConstantIntervention(val=-simulator_config.intervention_strength),
                intervention2=cd.ConstantIntervention(val=simulator_config.intervention_strength)
            ) for _ in intervention_set
        ]
    elif simulator_config.intervention_type == 'gauss':
        interventions = [
            cd.GaussIntervention(mean=0, variance=simulator_config.intervention_strength)
            for _ in intervention_set
        ]
    elif simulator_config.intervention_type == 'constant':
        interventions = [
            cd.ConstantIntervention(val=0) for _ in intervention_set
        ]
    else:
        # BUGFIX: was a bare `raise ValueError` with no message.
        raise ValueError('Unknown intervention type: %s' % simulator_config.intervention_type)
    if not simulator_config.target_allowed:
        # Remove the target from the candidate interventions; note this shifts
        # the indices of every node after the target.
        del intervention_set[simulator_config.target]
        del interventions[simulator_config.target]
    print(intervention_set)

    # === RUN STRATEGY ON EACH BATCH
    for batch in range(simulator_config.n_batches):
        print('Batch %d with %s' % (batch, simulator_config))
        batch_folder = os.path.join(strategy_folder, 'dags_batch=%d/' % batch)
        os.makedirs(batch_folder, exist_ok=True)
        iteration_data = IterationData(
            current_data=all_samples,
            max_interventions=simulator_config.max_interventions,
            n_samples=simulator_config.n_samples,
            batch_num=batch,
            n_batches=simulator_config.n_batches,
            intervention_set=intervention_set,
            interventions=interventions,
            batch_folder=batch_folder,
            precision_matrix=precision_matrix
        )
        recommended_interventions = strategy(iteration_data)
        # Each batch must use exactly its share of the sample budget.
        if sum(recommended_interventions.values()) != iteration_data.n_samples / iteration_data.n_batches:
            raise ValueError('Did not return correct amount of samples')
        rec_interventions_nonzero = {intv_ix for intv_ix, ns in recommended_interventions.items() if ns != 0}
        if simulator_config.max_interventions is not None and len(rec_interventions_nonzero) > simulator_config.max_interventions:
            raise ValueError('Returned too many interventions')
        # Draw the requested interventional samples and accumulate them.
        for intv_ix, nsamples in recommended_interventions.items():
            iv_node = intervention_set[intv_ix]
            new_samples = gdag.sample_interventional({iv_node: interventions[intv_ix]}, nsamples)
            all_samples[iv_node] = np.vstack((all_samples[iv_node], new_samples))

    # === SAVE ALL COLLECTED SAMPLES (also serves as the "done" marker above)
    samples_folder = os.path.join(strategy_folder, 'samples')
    os.makedirs(samples_folder, exist_ok=True)
    for i, samples in all_samples.items():
        np.savetxt(os.path.join(samples_folder, 'intervention=%d.csv' % i), samples)

    # === CHECK THE TOTAL NUMBER OF SAMPLES IS CORRECT
    nsamples_final = sum(all_samples[iv_node].shape[0] for iv_node in intervention_set + [-1])
    if nsamples_final != simulator_config.starting_samples + simulator_config.n_samples:
        raise ValueError('Did not use all samples')

    # === GET GIES SAMPLES GIVEN THE DATA FOR THIS SIMULATION
    if save_gies:
        final_samples_path = os.path.join(strategy_folder, 'final_samples.csv')
        final_interventions_path = os.path.join(strategy_folder, 'final_interventions')
        final_gies_dags_path = os.path.join(strategy_folder, 'final_dags/')
        graph_utils._write_data(all_samples, final_samples_path, final_interventions_path)
        graph_utils.run_gies_boot(num_bootstrap_dags_final, final_samples_path, final_interventions_path, final_gies_dags_path)
        amats, dags = graph_utils._load_dags(final_gies_dags_path, delete=True)
        for d, amat in enumerate(amats):
            np.save(os.path.join(final_gies_dags_path, 'dag%d.npy' % d), amat)
# Re-run the GIES bootstrap on samples collected by a previous
# entropy-dag-collection run, then collect the arcs of each sampled DAG.
from utils import graph_utils
import numpy as np

folder = 'data/chain_test10/dags/dag0/entropy-dag-collection,n=1500,b=3,k=1/'
samples_folder = folder + 'samples'
temp_folder = folder + 'temp/'

# Load observational (-1) and interventional (node 1) samples.
data = {iv: np.loadtxt(samples_folder + '/intervention=%d.csv' % iv) for iv in (-1, 1)}

# Write the data to temp files and bootstrap 100 DAGs via the R code.
samples_csv = temp_folder + 'samples.csv'
interventions_csv = temp_folder + 'interventions.csv'
graph_utils._write_data(data, samples_csv, interventions_csv)
graph_utils.run_gies_boot(100, samples_csv, interventions_csv, temp_folder + 'sampled_dags/')
_, sampled_dags = graph_utils._load_dags(temp_folder + 'sampled_dags/', delete=False)

# Arc sets of the sampled DAGs, for inspection.
a = [d.arcs for d in sampled_dags]