def exp_cevae(model="dlvm", n=1000, d=3, p=100, prop_miss=0.1, citcio=False, seed=0, d_cevae=20, n_epochs=402, method="glm", **kwargs): # import here because of differents sklearn version used from cevae_tf import cevae_tf from sklearn.preprocessing import Imputer if model == "lrmf": Z, X, w, y, ps = gen_lrmf(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) elif model == "dlvm": Z, X, w, y, ps = gen_dlvm(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) else: raise NotImplementedError( "Other data generating models not implemented here yet.") X_miss = ampute(X, prop_miss=prop_miss, seed=seed) X_imp = Imputer().fit_transform(X_miss) y0_hat, y1_hat = cevae_tf(X_imp, w, y, d_cevae=d_cevae, n_epochs=n_epochs) # Tau estimated on Zhat=E[Z|X] ps_hat = np.ones(len(y0_hat)) / 2 # res_tau_ols = tau_ols(zhat, w, y) # res_tau_ols_ps = tau_ols_ps(zhat, w, y) #res_tau_dr = tau_dr(y, w, y0_hat, y1_hat, ps_hat, method) #res_tau_dr_true_ps = tau_dr(y, w, y0_hat, y1_hat, ps, method) res_tau = np.mean(y1_hat - y0_hat) return res_tau
def exp_mi(model="dlvm", n=1000, d=3, p=100, prop_miss=0.1, citcio=False, seed=0, m=10, d_cevae=20, n_epochs=402, method="glm", **kwargs): if model == "lrmf": Z, X, w, y, ps = gen_lrmf(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) elif model == "dlvm": Z, X, w, y, ps = gen_dlvm(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) else: raise NotImplementedError( "Other data generating models not implemented here yet.") X_miss = ampute(X, prop_miss=prop_miss, seed=seed) tau_dr_mi, tau_ols_mi, tau_ols_ps_mi, tau_resid_mi = tau_mi(X_miss, w, y, m=m, method=method) return tau_dr_mi, tau_ols_mi, tau_ols_ps_mi, tau_resid_mi
def main(unused_argv):
    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [False, ],
        'nuisance': [True, ],
        'n': [500, 1000, 5000, 10000]
             if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p': [5, 10, 50, 100] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'y_snr': [5.] if FLAGS.y_snr is None else [FLAGS.y_snr],
        'x_snr': [2.] if FLAGS.x_snr is None else [FLAGS.x_snr],
        'mu_z': [0.] if FLAGS.mu_z is None else [FLAGS.mu_z],
        'sig_z': [1.] if FLAGS.sig_z is None else [FLAGS.sig_z],
        'sig_xgivenz': [0.001]
                       if FLAGS.sig_xgivenz is None else [FLAGS.sig_xgivenz],
        'prop_miss': [0.0, 0.1, 0.3, 0.5]
                     if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize': [False] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed': np.arange(FLAGS.n_seeds),
    }
    range_d_over_p = ([0.002, 0.01, 0.1]
                      if FLAGS.d_over_p is None and FLAGS.d_latent is None
                      else [FLAGS.d_over_p])
    range_d = (None if range_d_over_p is not None and FLAGS.d_latent is None
               else FLAGS.d_latent)

    # MDC parameters
    range_d_offset = [0, 5] if FLAGS.miwae_d_offset is None else [FLAGS.miwae_d_offset]
    mdc_parameter_grid = {
        'mu_prior': [0.] if FLAGS.miwae_mu_prior is None else [FLAGS.miwae_mu_prior],
        'sig_prior': [1.] if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [500] if FLAGS.miwae_n_samples_zmul is None
                            else [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [0.0001, ] if FLAGS.miwae_learning_rate is None
                         else [FLAGS.miwae_learning_rate],
        'n_epochs': [5000, ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    # MI parameters
    range_m = [10, ] if FLAGS.n_imputations is None else [FLAGS.n_imputations]

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output
    FLAGS.log_dir = './sessions/logging/' if FLAGS.log_path is None else FLAGS.log_path
    logging.get_absl_handler().use_absl_log_file()

    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [
        dict(zip(exp_parameter_grid.keys(), vals))
        for vals in itertools.product(*exp_parameter_grid.values())
    ]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Note: this conversion is needed because DictReader creates an
                # OrderedDict and reads all values as str instead of bool or int.
                previous_runs.add(
                    str({
                        'model': row['model'],
                        'citcio': row['citcio'] == 'True',
                        'n': int(row['n']),
                        'p': int(row['p']),
                        'y_snr': float(row['y_snr']),
                        'x_snr': float(row['x_snr']),
                        'mu_z': float(row['mu_z']),
                        'sig_z': float(row['sig_z']),
                        'prop_miss': float(row['prop_miss']),
                        'regularize': row['regularize'] == 'True',
                        'seed': int(row['seed']),
                        'd': int(row['d']),
                        'sig_xgivenz': float(row['sig_xgivenz'])
                    }))
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # For given p, create range for d such that 1 < d < p,
        # starting with given ratios for d/p.
        if range_d is None:
            range_d = [
                np.maximum(2, int(np.floor(args['p'] * x)))
                for x in range_d_over_p
            ]
            range_d = np.unique(
                np.array(range_d)[np.array(range_d) < args['p']].tolist())

        exp_time = time.time()
        for args['d'] in range_d:
            # We only consider cases where latent dimension <= ambient dimension.
            if args['d'] > args['p']:
                continue
            res = []
            if str(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            else:
                logging.info(f'running exp with {args}')

            if args['model'] == "lrmf":
                Z, X, w, y, ps, mu0, mu1 = gen_lrmf(
                    n=args['n'], d=args['d'], p=args['p'],
                    y_snr=args['y_snr'], x_snr=args['x_snr'],
                    citcio=args['citcio'], prop_miss=args['prop_miss'],
                    seed=args['seed'], sig_xgivenz=args['sig_xgivenz'])
            elif args['model'] == "dlvm":
                Z, X, w, y, ps, mu0, mu1 = gen_dlvm(
                    n=args['n'], d=args['d'], p=args['p'],
                    y_snr=args['y_snr'], citcio=args['citcio'],
                    prop_miss=args['prop_miss'], seed=args['seed'],
                    mu_z=args['mu_z'], sig_z=args['sig_z'],
                    x_snr=args['x_snr'], sig_xgivenz=args['sig_xgivenz'])

            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])

            # On complete data
            t0 = time.time()
            if args['nuisance']:
                tau, nu = exp_complete(Z, X, w, y, args['regularize'],
                                       args['nuisance'])
            else:
                tau = exp_complete(Z, X, w, y, args['regularize'],
                                   args['nuisance'])
            args['time'] = int(time.time() - t0)

            row = {'Method': 'Z'}
            row.update(args)
            row.update(tau['Z'])
            print(tau['Z'])
            if args['nuisance']:
                row.update({'ps_hat_mse': mean_squared_error(ps, nu['Z']['ps_hat'])})
                row.update({'y0_hat_mse': mean_squared_error(mu0, nu['Z']['y0_hat'])})
                row.update({'y1_hat_mse': mean_squared_error(mu1, nu['Z']['y1_hat'])})
            res.append(row)

            row = {'Method': 'X'}
            row.update(args)
            row.update(tau['X'])
            if args['nuisance']:
                row.update({'ps_hat_mse': mean_squared_error(ps, nu['X']['ps_hat'])})
                row.update({'y0_hat_mse': mean_squared_error(mu0, nu['X']['y0_hat'])})
                row.update({'y1_hat_mse': mean_squared_error(mu1, nu['X']['y1_hat'])})
            res.append(row)

            # Mean-imputation
            t0 = time.time()
            if args['nuisance']:
                tau, nu = exp_mean(X_miss, w, y, args['regularize'],
                                   args['nuisance'])
            else:
                tau = exp_mean(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            row = {'Method': 'Mean_imp'}
            row.update(args)
            row.update(tau)
            if args['nuisance']:
                row.update({'ps_hat_mse': mean_squared_error(ps, nu['ps_hat'])})
                row.update({'y0_hat_mse': mean_squared_error(mu0, nu['y0_hat'])})
                row.update({'y1_hat_mse': mean_squared_error(mu1, nu['y1_hat'])})
            res.append(row)

            # Multiple imputation
            for m in range_m:
                t0 = time.time()
                if args['nuisance']:
                    tau, nu = exp_mi(X_miss, w, y, regularize=args['regularize'],
                                     m=m, nuisance=args['nuisance'])
                else:
                    tau = exp_mi(X_miss, w, y, regularize=args['regularize'], m=m)
                args['time'] = int(time.time() - t0)
                row = {'Method': 'MI', 'm': m}
                row.update(args)
                row.update(tau)
                if args['nuisance']:
                    row.update({'ps_hat_mse': mean_squared_error(ps, nu['ps_hat'])})
                    row.update({'y0_hat_mse': mean_squared_error(mu0, nu['y0_hat'])})
                    row.update({'y1_hat_mse': mean_squared_error(mu1, nu['y1_hat'])})
                res.append(row)

            # Matrix Factorization
            t0 = time.time()
            if args['nuisance']:
                tau, nu, r, zhat = exp_mf(X_miss, w, y, args['regularize'],
                                          args['nuisance'], return_zhat=True)
            else:
                tau, r = exp_mf(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            row = {'Method': 'MF', 'r': r}
            row.update(args)
            row.update(tau)
            if args['nuisance']:
                row.update({'ps_hat_mse': mean_squared_error(ps, nu['ps_hat'])})
                row.update({'y0_hat_mse': mean_squared_error(mu0, nu['y0_hat'])})
                row.update({'y1_hat_mse': mean_squared_error(mu1, nu['y1_hat'])})
            res.append(row)

            # MissDeepCausal
            mdc_parameter_grid['d_miwae'] = [
                args['d'] + x for x in range_d_offset
            ]
            mdc_arguments = [
                dict(zip(mdc_parameter_grid.keys(), vals))
                for vals in itertools.product(*mdc_parameter_grid.values())
            ]

            for mdc_arg in mdc_arguments:
                t0 = time.time()
                mdc_arg['mu_prior'] = args['mu_z']
                session_file = './sessions/' + \
                    args['model'] + '_' + \
                    '_sigXgivenZ' + str(args['sig_xgivenz']) + \
                    '_n' + str(args['n']) + \
                    '_p' + str(args['p']) + \
                    '_d' + str(args['d']) + \
                    '_ysnr' + str(args['y_snr']) + \
                    '_xsnr' + str(args['x_snr']) + \
                    '_propNA' + str(args['prop_miss']) + \
                    '_seed' + str(args['seed'])
                session_file_complete = session_file + \
                    '_dmiwae' + str(mdc_arg['d_miwae']) + \
                    '_sigprior' + str(mdc_arg['sig_prior'])

                if args['nuisance']:
                    tau, nu, elbo, zhat, zhat_mul = exp_mdc(
                        X_miss, w, y,
                        d_miwae=mdc_arg['d_miwae'],
                        mu_prior=mdc_arg['mu_prior'],
                        sig_prior=mdc_arg['sig_prior'],
                        num_samples_zmul=mdc_arg['num_samples_zmul'],
                        learning_rate=mdc_arg['learning_rate'],
                        n_epochs=mdc_arg['n_epochs'],
                        regularize=args['regularize'],
                        nuisance=args['nuisance'],
                        return_zhat=True,
                        save_session=True,
                        session_file=session_file,
                        session_file_complete=session_file_complete)
                else:
                    tau, elbo, zhat, zhat_mul = exp_mdc(
                        X_miss, w, y,
                        d_miwae=mdc_arg['d_miwae'],
                        mu_prior=mdc_arg['mu_prior'],
                        sig_prior=mdc_arg['sig_prior'],
                        num_samples_zmul=mdc_arg['num_samples_zmul'],
                        learning_rate=mdc_arg['learning_rate'],
                        n_epochs=mdc_arg['n_epochs'],
                        regularize=args['regularize'],
                        return_zhat=True,
                        save_session=True,
                        session_file=session_file,
                        session_file_complete=session_file_complete)
                args['training_time'] = int(time.time() - t0)

                row = {'Method': 'MDC.process', 'elbo': elbo}
                row.update(args)
                row.update(mdc_arg)
                row.update(tau['MDC.process'])
                if args['nuisance']:
                    row.update({'ps_hat_mse':
                                mean_squared_error(ps, nu['MDC.process']['ps_hat'])})
                    row.update({'y0_hat_mse':
                                mean_squared_error(mu0, nu['MDC.process']['y0_hat'])})
                    row.update({'y1_hat_mse':
                                mean_squared_error(mu1, nu['MDC.process']['y1_hat'])})
                res.append(row)

                row = {'Method': 'MDC.mi', 'elbo': elbo}
                row.update(args)
                row.update(mdc_arg)
                row.update(tau['MDC.mi'])
                if args['nuisance']:
                    row.update({'ps_hat_mse':
                                mean_squared_error(ps, nu['MDC.mi']['ps_hat'])})
                    row.update({'y0_hat_mse':
                                mean_squared_error(mu0, nu['MDC.mi']['y0_hat'])})
                    row.update({'y1_hat_mse':
                                mean_squared_error(mu1, nu['MDC.mi']['y1_hat'])})
                res.append(row)

            log_res(output, res,
                    ['Method'] + list(args.keys()) + l_method_params + l_tau + l_nu)
            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} successfully ended.')
    logging.info('*' * 20)
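# Invocation sketch (hypothetical script name; the flag names are the absl
# FLAGS referenced above):
# python launch_exp.py --exp_name=convergence --model=dlvm \
#     --n_observations=1000 --p_ambient=100 --prop_miss=0.3 --n_seeds=5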
def main(unused_argv):
    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [False, ],
        'n': [1000, 10000, 100000]
             if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p': [10, 100, 1000] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'y_snr': [5.] if FLAGS.y_snr is None else [FLAGS.y_snr],
        'x_snr': 1. * np.arange(2, 20, 4) if FLAGS.x_snr is None else [FLAGS.x_snr],
        'mu_z': [0.] if FLAGS.mu_z is None else [FLAGS.mu_z],
        'sig_z': [1.] if FLAGS.sig_z is None else [FLAGS.sig_z],
        'sig_xgivenz': ["fixed", ] if FLAGS.sig_xgivenz is None else [FLAGS.sig_xgivenz],
        'prop_miss': [0.0, 0.1, 0.3, 0.5] if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize': [False] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed': np.arange(FLAGS.n_seeds),
    }
    range_d_over_p = ([0.002, 0.01, 0.1]
                      if FLAGS.d_over_p is None and FLAGS.d_latent is None
                      else [FLAGS.d_over_p])
    range_d = (None if range_d_over_p is not None and FLAGS.d_latent is None
               else FLAGS.d_latent)

    # MDC parameters
    range_d_offset = [0, 5, 10] if FLAGS.miwae_d_offset is None else [FLAGS.miwae_d_offset]
    mdc_parameter_grid = {
        'mu_prior': [0.] if FLAGS.miwae_mu_prior is None else [FLAGS.miwae_mu_prior],
        'sig_prior': [1.] if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [500] if FLAGS.miwae_n_samples_zmul is None
                            else [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [0.0001, ] if FLAGS.miwae_learning_rate is None
                         else [FLAGS.miwae_learning_rate],
        'n_epochs': [5000, ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output
    FLAGS.log_dir = './sessions/logging/' if FLAGS.log_path is None else FLAGS.log_path
    logging.get_absl_handler().use_absl_log_file()

    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [dict(zip(exp_parameter_grid.keys(), vals))
                     for vals in itertools.product(*exp_parameter_grid.values())]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Note: this conversion is needed because DictReader creates an
                # OrderedDict and reads all values as str instead of bool or int.
                previous_runs.add(str({
                    'model': row['model'],
                    'citcio': row['citcio'] == 'True',
                    'n': int(row['n']),
                    'p': int(row['p']),
                    'y_snr': float(row['y_snr']),
                    'x_snr': float(row['x_snr']),
                    'mu_z': float(row['mu_z']),
                    'sig_z': float(row['sig_z']),
                    'prop_miss': float(row['prop_miss']),
                    'regularize': row['regularize'] == 'True',
                    'seed': int(row['seed']),
                    'd': int(row['d']),
                    'sig_xgivenz': row['sig_xgivenz']
                }))
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # For given p, if range_d is not yet specified,
        # create range for d such that 1 < d < p,
        # starting with given ratios for d/p.
        if range_d is None:
            range_d = [np.maximum(2, int(np.floor(args['p'] * x)))
                       for x in range_d_over_p]
            range_d = np.unique(
                np.array(range_d)[np.array(range_d) < args['p']].tolist())

        exp_time = time.time()
        for args['d'] in range_d:
            # We only consider cases where latent dimension <= ambient dimension.
            if args['d'] > args['p']:
                continue
            res = []
            if str(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            else:
                logging.info(f'running exp with {args}')

            if args['model'] == "lrmf":
                Z, X, w, y, ps, mu0, mu1 = gen_lrmf(
                    n=args['n'], d=args['d'], p=args['p'],
                    y_snr=args['y_snr'], x_snr=args['x_snr'],
                    citcio=args['citcio'], prop_miss=args['prop_miss'],
                    seed=args['seed'])
            elif args['model'] == "dlvm":
                Z, X, w, y, ps, mu0, mu1 = gen_dlvm(
                    n=args['n'], d=args['d'], p=args['p'],
                    y_snr=args['y_snr'], citcio=args['citcio'],
                    prop_miss=args['prop_miss'], seed=args['seed'],
                    mu_z=args['mu_z'], sig_z=args['sig_z'],
                    x_snr=args['x_snr'], sig_xgivenz=args['sig_xgivenz'])

            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])

            # MIWAE
            mdc_parameter_grid['d_miwae'] = [args['d'] + x for x in range_d_offset]
            mdc_arguments = [dict(zip(mdc_parameter_grid.keys(), vals))
                             for vals in itertools.product(*mdc_parameter_grid.values())]

            for mdc_arg in mdc_arguments:
                t0 = time.time()
                mdc_arg['mu_prior'] = args['mu_z']
                session_file = './sessions/' + \
                    args['model'] + '_' + \
                    args['sig_xgivenz'] + 'Sigma' + \
                    '_n' + str(args['n']) + \
                    '_p' + str(args['p']) + \
                    '_d' + str(args['d']) + \
                    '_ysnr' + str(args['y_snr']) + \
                    '_xsnr' + str(args['x_snr']) + \
                    '_propNA' + str(args['prop_miss']) + \
                    '_seed' + str(args['seed'])
                session_file_complete = session_file + \
                    '_dmiwae' + str(mdc_arg['d_miwae']) + \
                    '_sigprior' + str(mdc_arg['sig_prior'])

                epochs = -1
                tmp = glob.glob(session_file_complete + '.*')
                sess = tf.Session(graph=tf.reset_default_graph())
                if len(tmp) > 0:
                    # Session files already exist for this configuration.
                    continue
                else:
                    xhat, zhat, zhat_mul, elbo, epochs = miwae_es(
                        X_miss,
                        d_miwae=mdc_arg['d_miwae'],
                        mu_prior=mdc_arg['mu_prior'],
                        sig_prior=mdc_arg['sig_prior'],
                        num_samples_zmul=mdc_arg['num_samples_zmul'],
                        l_rate=mdc_arg['learning_rate'],
                        n_epochs=mdc_arg['n_epochs'],
                        save_session=True,
                        session_file=session_file)
                    with open(session_file_complete + '.pkl', 'wb') as file_data:
                        pickle.dump([xhat, zhat, zhat_mul, elbo, epochs], file_data)

            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} successfully ended.')
    logging.info('*' * 20)
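# Reload sketch for a cached run (same pickle layout as written above):
# with open(session_file_complete + '.pkl', 'rb') as f:
#     xhat, zhat, zhat_mul, elbo, epochs = pickle.load(f)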
l_tau = ['tau_dr', 'tau_ols', 'tau_ols_ps', 'tau_resid']
output = '../results/' + exp_name + '.csv'

l_scores = []
for args['model'] in range_model:
    for args['citcio'] in range_citcio:
        for args['n'] in range_n:
            for args['p'] in range_p:
                range_d = [int(np.floor(args['p'] * x)) for x in range_d_over_p]
                for args['d'] in range_d:
                    for args['prop_miss'] in range_prop_miss:
                        for args['seed'] in range_seed:
                            print(args)
                            if args['model'] == "lrmf":
                                Z, X, w, y, ps = gen_lrmf(
                                    n=args['n'], d=args['d'], p=args['p'],
                                    citcio=args['citcio'],
                                    prop_miss=args['prop_miss'],
                                    seed=args['seed'])
                            elif args['model'] == "dlvm":
                                Z, X, w, y, ps = gen_dlvm(
                                    n=args['n'], d=args['d'], p=args['p'],
                                    citcio=args['citcio'],
                                    prop_miss=args['prop_miss'],
                                    seed=args['seed'])

                            X_miss = ampute(X, prop_miss=args['prop_miss'],
                                            seed=args['seed'])

                            # Complete
                            t0 = time.time()
                            tau = exp_complete(Z, X, w, y)
                            args['time'] = int(time.time() - t0)
                            l_scores.append(np.concatenate(
                                (['Z'], list(args.values()), [None] * 7, tau['Z'])))
                            l_scores.append(np.concatenate(
                                (['X'], list(args.values()), [None] * 7, tau['X'])))
def main(unused_argv):
    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [False, ],
        'n': [1000, 10000, 100000]
             if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p': [10, 100, 1000] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'snr': [1., 5., 10.] if FLAGS.snr is None else [FLAGS.snr],
        'prop_miss': [0.0, 0.1, 0.3, 0.5, 0.7, 0.9]
                     if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize': [False, True] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed': np.arange(FLAGS.n_seeds),
    }
    range_d_over_p = [0.002, 0.01, 0.1] if FLAGS.d_over_p is None else [FLAGS.d_over_p]

    # MDC parameters
    range_d_offset = [0, 5, 10] if FLAGS.miwae_d_offset is None else [FLAGS.miwae_d_offset]
    mdc_parameter_grid = {
        'sig_prior': [0.1, 1, 10]
                     if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [50, 500] if FLAGS.miwae_n_samples_zmul is None
                            else [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [0.0001, ] if FLAGS.miwae_learning_rate is None
                         else [FLAGS.miwae_learning_rate],
        'n_epochs': [500, ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    # MI parameters
    range_m = [10, 20, 50] if FLAGS.n_imputations is None else [FLAGS.n_imputations]

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output

    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [
        dict(zip(exp_parameter_grid.keys(), vals))
        for vals in itertools.product(*exp_parameter_grid.values())
    ]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Note: this conversion is needed because DictReader creates an
                # OrderedDict and reads all values as str instead of bool or int.
                previous_runs.add(
                    str({
                        'model': row['model'],
                        'citcio': row['citcio'] == 'True',
                        'n': int(row['n']),
                        'p': int(row['p']),
                        'snr': float(row['snr']),
                        'prop_miss': float(row['prop_miss']),
                        'regularize': row['regularize'] == 'True',
                        'seed': int(row['seed']),
                        'd': int(row['d']),
                    }))
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # For given p, create range for d such that 1 < d < p,
        # starting with given ratios for d/p.
        range_d = [
            np.maximum(2, int(np.floor(args['p'] * x)))
            for x in range_d_over_p
        ]
        range_d = np.unique(
            np.array(range_d)[np.array(range_d) < args['p']].tolist())

        exp_time = time.time()
        for args['d'] in range_d:
            res = []
            if str(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            else:
                logging.info(f'running exp with {args}')

            if args['model'] == "lrmf":
                Z, X, w, y, ps = gen_lrmf(n=args['n'], d=args['d'], p=args['p'],
                                          y_snr=args['snr'], citcio=args['citcio'],
                                          prop_miss=args['prop_miss'],
                                          seed=args['seed'])
            elif args['model'] == "dlvm":
                Z, X, w, y, ps = gen_dlvm(n=args['n'], d=args['d'], p=args['p'],
                                          y_snr=args['snr'], citcio=args['citcio'],
                                          prop_miss=args['prop_miss'],
                                          seed=args['seed'])

            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])

            # On complete data
            t0 = time.time()
            tau = exp_complete(Z, X, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            row = {'Method': 'Z'}
            row.update(args)
            row.update(tau['Z'])
            res.append(row)
            row = {'Method': 'X'}
            row.update(args)
            row.update(tau['X'])
            res.append(row)

            # Mean-imputation
            t0 = time.time()
            tau = exp_mean(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            row = {'Method': 'Mean_imp'}
            row.update(args)
            row.update(tau)
            res.append(row)

            # Multiple imputation
            for m in range_m:
                t0 = time.time()
                tau = exp_mi(X_miss, w, y, regularize=args['regularize'], m=m)
                args['time'] = int(time.time() - t0)
                row = {'Method': 'MI', 'm': m}
                row.update(args)
                row.update(tau)
                res.append(row)

            # Matrix Factorization
            t0 = time.time()
            tau, r = exp_mf(X_miss, w, y, args['regularize'])
            args['time'] = int(time.time() - t0)
            row = {'Method': 'MF', 'r': r}
            row.update(args)
            row.update(tau)
            res.append(row)

            # MissDeepCausal
            mdc_parameter_grid['d_miwae'] = [
                args['d'] + x for x in range_d_offset
            ]
            mdc_arguments = [
                dict(zip(mdc_parameter_grid.keys(), vals))
                for vals in itertools.product(*mdc_parameter_grid.values())
            ]
            for mdc_arg in mdc_arguments:
                t0 = time.time()
                tau, elbo = exp_mdc(
                    X_miss, w, y,
                    d_miwae=mdc_arg['d_miwae'],
                    sig_prior=mdc_arg['sig_prior'],
                    num_samples_zmul=mdc_arg['num_samples_zmul'],
                    learning_rate=mdc_arg['learning_rate'],
                    n_epochs=mdc_arg['n_epochs'],
                    regularize=args['regularize'])
                args['time'] = int(time.time() - t0)

                row = {'Method': 'MDC.process', 'elbo': elbo}
                row.update(args)
                row.update(mdc_arg)
                row.update(tau['MDC.process'])
                res.append(row)

                row = {'Method': 'MDC.mi', 'elbo': elbo}
                row.update(args)
                row.update(mdc_arg)
                row.update(tau['MDC.mi'])
                res.append(row)

            log_res(output, res,
                    ['Method'] + list(args.keys()) + l_method_params + l_tau)
            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} successfully ended.')
    logging.info('*' * 20)
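# Downstream inspection sketch (assumes pandas is installed; the CSV columns
# follow log_res above, e.g. 'Method' plus the tau estimates in l_tau):
# import pandas as pd
# df = pd.read_csv(output)
# print(df.groupby('Method')['tau_dr'].mean())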
def exp_miwae(model="dlvm", n=1000, d=3, p=100, prop_miss=0.1, citcio=False, seed=0, d_miwae=3, n_epochs=602, sig_prior=1, add_wy=False, num_samples_zmul=200, method="glm", **kwargs): from miwae import miwae if model == "lrmf": Z, X, w, y, ps = gen_lrmf(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) elif model == "dlvm": Z, X, w, y, ps = gen_dlvm(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) else: raise NotImplementedError( "Other data generating models not implemented here yet.") X_miss = ampute(X, prop_miss=prop_miss, seed=seed) if add_wy: xhat, zhat, zhat_mul = miwae(X_miss, d=d_miwae, sig_prior=sig_prior, num_samples_zmul=num_samples_zmul, n_epochs=n_epochs, add_wy=add_wy, w=w, y=y) else: xhat, zhat, zhat_mul = miwae(X_miss, d=d_miwae, sig_prior=sig_prior, num_samples_zmul=num_samples_zmul, n_epochs=n_epochs, add_wy=add_wy) # print('shape of outputs miwae:') # print('xhat.shape, zhat.shape, zhat_mul.shape:') # (1000, 200) (1000, 3) (200, 1000, 3) print(xhat.shape, zhat.shape, zhat_mul.shape) # Tau estimated on Zhat=E[Z|X] ps_hat, y0_hat, y1_hat = get_ps_y01_hat(zhat, w, y) res_tau_ols = tau_ols(zhat, w, y) res_tau_ols_ps = tau_ols_ps(zhat, w, y) res_tau_dr = tau_dr(y, w, y0_hat, y1_hat, ps_hat, method) lr = LinearRegression() lr.fit(zhat, y) y_hat = lr.predict(zhat) res_tau_resid = tau_residuals(y, w, y_hat, ps_hat, method) # Tau estimated on Zhat^(b), l=1,...,B sampled from posterior res_mul_tau_dr = [] res_mul_tau_ols = [] res_mul_tau_ols_ps = [] res_mul_tau_resid = [] for zhat_b in zhat_mul: ps_hat, y0_hat, y1_hat = get_ps_y01_hat(zhat_b, w, y) res_mul_tau_dr.append(tau_dr(y, w, y0_hat, y1_hat, ps_hat, method)) res_mul_tau_ols.append(tau_ols(zhat_b, w, y)) res_mul_tau_ols_ps.append(tau_ols_ps(zhat_b, w, y)) lr = LinearRegression() lr.fit(zhat_b, y) y_hat = lr.predict(zhat_b) res_mul_tau_resid.append(tau_residuals(y, w, y_hat, ps_hat, method)) res_mul_tau_dr = np.mean(res_mul_tau_dr) res_mul_tau_ols = np.mean(res_mul_tau_ols) res_mul_tau_ols_ps = np.mean(res_mul_tau_ols_ps) res_mul_tau_resid = np.mean(res_mul_tau_resid) if Z.shape[1] == zhat.shape[1]: dcor_zhat = dcor(Z, zhat) dcor_zhat_mul = [] for zhat_b in zhat_mul: dcor_zhat_mul.append(dcor(Z, zhat_b)) dcor_zhat_mul = np.mean(dcor_zhat_mul) return res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid, res_mul_tau_dr, res_mul_tau_ols, res_mul_tau_ols_ps, res_mul_tau_resid, dcor_zhat, dcor_zhat_mul
def exp_baseline(model="dlvm", n=1000, d=3, p=100, prop_miss=0.1, citcio=False, seed=0, full_baseline=False, method="glm", **kwargs): if model == "lrmf": Z, X, w, y, ps = gen_lrmf(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) elif model == "dlvm": Z, X, w, y, ps = gen_dlvm(n=n, d=d, p=p, citcio=citcio, prop_miss=prop_miss, seed=seed) else: raise NotImplementedError( "Other data generating models not implemented here yet.") X_miss = ampute(X, prop_miss=prop_miss, seed=seed) from sklearn.impute import SimpleImputer X_imp_mean = SimpleImputer().fit_transform(X_miss) Z_perm = np.random.permutation(Z) # Z_rnd = np.random.randn(Z.shape[0], Z.shape[1]) algo_name = ['Z', 'X'] #, 'X_imp_mean'] algo_ = [Z, X] #, X_imp_mean] if full_baseline: # complete the baseline Z_mf = get_U_softimpute(X_miss) # need try-except for sklearn version try: from sklearn.impute import IterativeImputer X_imp = IterativeImputer().fit_transform(X_miss) except: from sklearn.experimental import enable_iterative_imputer from sklearn.impute import IterativeImputer X_imp = IterativeImputer().fit_transform(X_miss) algo_name += ['Z_mf'] #['X_imp','Z_mf']#, 'Z_perm'] algo_ += [Z_mf] #[X_imp, Z_mf]#, Z_perm] tau = dict() for name, zhat in zip(algo_name, algo_): if name == 'X_mi': res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid = tau_mi( zhat, w, y, method=method) else: ps_hat, y0_hat, y1_hat = get_ps_y01_hat(zhat, w, y) res_tau_ols = tau_ols(zhat, w, y) res_tau_ols_ps = tau_ols_ps(zhat, w, y) res_tau_dr = tau_dr(y, w, y0_hat, y1_hat, ps_hat, method) lr = LinearRegression() lr.fit(zhat, y) y_hat = lr.predict(zhat) res_tau_resid = tau_residuals(y, w, y_hat, ps_hat, method) tau[name] = res_tau_dr, res_tau_ols, res_tau_ols_ps, res_tau_resid return tau
def main(unused_argv):
    # Data generating process parameters
    exp_parameter_grid = {
        'model': ["dlvm", "lrmf"] if FLAGS.model is None else [FLAGS.model],
        'citcio': [False, ],
        'n': [1000, 5000, 10000]
             if FLAGS.n_observations is None else [FLAGS.n_observations],
        'p': [10, 50] if FLAGS.p_ambient is None else [FLAGS.p_ambient],
        'y_snr': [5.] if FLAGS.y_snr is None else [FLAGS.y_snr],
        'x_snr': [2.] if FLAGS.x_snr is None else [FLAGS.x_snr],
        'mu_z': [0.] if FLAGS.mu_z is None else [FLAGS.mu_z],
        'sig_z': [1.] if FLAGS.sig_z is None else [FLAGS.sig_z],
        'sig_xgivenz': [0.001] if FLAGS.sig_xgivenz is None else [FLAGS.sig_xgivenz],
        'prop_miss': [0.0, ] if FLAGS.prop_miss is None else [FLAGS.prop_miss],
        'regularize': [False] if FLAGS.regularize is None else [FLAGS.regularize],
        'seed': np.arange(FLAGS.n_seeds),
    }
    range_d_over_p = ([0.002, 0.01, 0.1]
                      if FLAGS.d_over_p is None and FLAGS.d_latent is None
                      else [FLAGS.d_over_p])
    range_d = (None if range_d_over_p is not None and FLAGS.d_latent is None
               else FLAGS.d_latent)

    # MDC parameters
    range_d_offset = [0, 5, 10] if FLAGS.miwae_d_offset is None else [FLAGS.miwae_d_offset]
    mdc_parameter_grid = {
        'mu_prior': [0.] if FLAGS.miwae_mu_prior is None else [FLAGS.miwae_mu_prior],
        'sig_prior': [1.] if FLAGS.miwae_sig_prior is None else [FLAGS.miwae_sig_prior],
        'num_samples_zmul': [500] if FLAGS.miwae_n_samples_zmul is None
                            else [FLAGS.miwae_n_samples_zmul],
        'learning_rate': [0.0001, ] if FLAGS.miwae_learning_rate is None
                         else [FLAGS.miwae_learning_rate],
        'n_epochs': [5000, ] if FLAGS.miwae_n_epochs is None else [FLAGS.miwae_n_epochs],
    }

    test_seeds = np.arange(FLAGS.n_test_seeds) + 1000
    save_test_data = True if FLAGS.save_test_data is None else FLAGS.save_test_data

    # Experiment and output file name
    output = f'results/{FLAGS.exp_name}.csv' if FLAGS.output is None else FLAGS.output
    FLAGS.log_dir = './sessions/logging/' if FLAGS.log_path is None else FLAGS.log_path
    logging.get_absl_handler().use_absl_log_file()

    logging.info('*' * 20)
    logging.info(f'Starting exp: {FLAGS.exp_name}')
    logging.info('*' * 20)

    exp_arguments = [
        dict(zip(exp_parameter_grid.keys(), vals))
        for vals in itertools.product(*exp_parameter_grid.values())
    ]

    previous_runs = set()
    if tf.io.gfile.exists(output):
        with tf.io.gfile.GFile(output, mode='r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Note: this conversion is needed because DictReader creates an
                # OrderedDict and reads all values as str instead of bool or int.
                previous_runs.add(
                    str({
                        'model': row['model'],
                        'citcio': row['citcio'] == 'True',
                        'n': int(row['n']),
                        'p': int(row['p']),
                        'y_snr': float(row['y_snr']),
                        'x_snr': float(row['x_snr']),
                        'mu_z': float(row['mu_z']),
                        'sig_z': float(row['sig_z']),
                        'prop_miss': float(row['prop_miss']),
                        'regularize': row['regularize'] == 'True',
                        'seed': int(row['seed']),
                        'd': int(row['d']),
                        'sig_xgivenz': float(row['sig_xgivenz'])
                    }))
    logging.info('Previous runs')
    logging.info(previous_runs)

    for args in exp_arguments:
        # For given p, if range_d is not yet specified,
        # create range for d such that 1 < d < p,
        # starting with given ratios for d/p.
        if range_d is None:
            range_d = [
                np.maximum(2, int(np.floor(args['p'] * x)))
                for x in range_d_over_p
            ]
            range_d = np.unique(
                np.array(range_d)[np.array(range_d) < args['p']].tolist())

        exp_time = time.time()
        for args['d'] in range_d:
            # We only consider cases where latent dimension <= ambient dimension.
            if args['d'] > args['p']:
                continue
            res = []
            if str(args) in previous_runs:
                logging.info(f'Skipped {args}')
                continue
            else:
                logging.info(f'running exp with {args}')

            if args['model'] == "lrmf":
                Z, X, w, y, ps, mu0, mu1 = gen_lrmf(
                    n=args['n'], d=args['d'], p=args['p'],
                    y_snr=args['y_snr'], x_snr=args['x_snr'],
                    citcio=args['citcio'], prop_miss=args['prop_miss'],
                    seed=args['seed'])
            elif args['model'] == "dlvm":
                Z, X, w, y, ps, mu0, mu1 = gen_dlvm(
                    n=args['n'], d=args['d'], p=args['p'],
                    y_snr=args['y_snr'], citcio=args['citcio'],
                    prop_miss=args['prop_miss'], seed=args['seed'],
                    mu_z=args['mu_z'], sig_z=args['sig_z'],
                    x_snr=args['x_snr'], sig_xgivenz=args['sig_xgivenz'])

            X_miss = ampute(X, prop_miss=args['prop_miss'], seed=args['seed'])

            # MIWAE
            mdc_parameter_grid['d_miwae'] = [
                args['d'] + x for x in range_d_offset
            ]
            mdc_arguments = [
                dict(zip(mdc_parameter_grid.keys(), vals))
                for vals in itertools.product(*mdc_parameter_grid.values())
            ]

            for mdc_arg in mdc_arguments:
                t0 = time.time()
                mdc_arg['mu_prior'] = args['mu_z']
                session_file = './sessions/' + \
                    args['model'] + '_' + \
                    '_sigXgivenZ' + str(args['sig_xgivenz']) + \
                    '_n' + str(args['n']) + \
                    '_p' + str(args['p']) + \
                    '_d' + str(args['d']) + \
                    '_ysnr' + str(args['y_snr']) + \
                    '_xsnr' + str(args['x_snr']) + \
                    '_propNA' + str(args['prop_miss']) + \
                    '_seed' + str(args['seed'])
                session_file_complete = session_file + \
                    '_dmiwae' + str(mdc_arg['d_miwae']) + \
                    '_sigprior' + str(mdc_arg['sig_prior'])

                epochs = -1
                tmp = glob.glob(session_file_complete + '.*')
                sess = tf.Session(graph=tf.reset_default_graph())
                if len(tmp) > 0:
                    # Restore a previously trained model.
                    new_saver = tf.train.import_meta_graph(
                        session_file_complete + '.meta')
                    new_saver.restore(sess, session_file_complete)
                    # with open(session_file_complete + '.pkl', 'rb') as f:
                    #     xhat, zhat, zhat_mul, elbo, epochs = pickle.load(f)
                else:
                    xhat, zhat, zhat_mul, elbo, epochs = miwae_es(
                        X_miss,
                        d_miwae=mdc_arg['d_miwae'],
                        mu_prior=mdc_arg['mu_prior'],
                        sig_prior=mdc_arg['sig_prior'],
                        num_samples_zmul=mdc_arg['num_samples_zmul'],
                        l_rate=mdc_arg['learning_rate'],
                        n_epochs=mdc_arg['n_epochs'],
                        save_session=True,
                        session_file=session_file)
                    new_saver = tf.train.import_meta_graph(
                        session_file_complete + '.meta')
                    new_saver.restore(sess, session_file_complete)
                    # tf.train.latest_checkpoint('./')
                    with open(session_file_complete + '.pkl', 'wb') as file_data:
                        pickle.dump([xhat, zhat, zhat_mul, elbo, epochs],
                                    file_data)
                args['training_time'] = int(time.time() - t0)

                # Evaluate performance of the trained model on new test sets.
                graph = tf.get_default_graph()
                K = graph.get_tensor_by_name('K:0')
                x = graph.get_tensor_by_name('x:0')
                batch_size = tf.shape(x)[0]
                xms = graph.get_tensor_by_name('xms:0')
                imp_weights = graph.get_tensor_by_name('imp_weights:0')
                xm = tf.einsum('ki,kij->ij', imp_weights, xms, name='xm')
                zgivenx_flat = graph.get_tensor_by_name('zgivenx_flat:0')
                zgivenx = tf.reshape(zgivenx_flat,
                                     [K, batch_size, zgivenx_flat.shape[1]])
                z_hat = tf.einsum('ki,kij->ij', imp_weights, zgivenx,
                                  name='z_hat')
                sir_logits = graph.get_tensor_by_name('sir_logits:0')
                sirz = tfd.Categorical(logits=sir_logits).sample(
                    mdc_arg['num_samples_zmul'])
                zmul = graph.get_tensor_by_name('zmul:0')

                for test_seed in test_seeds:
                    if args['model'] == "lrmf":
                        (Z_test, X_test, w_test, y_test, ps_test, mu0_test,
                         mu1_test) = gen_lrmf(n=args['n'], d=args['d'],
                                              p=args['p'], y_snr=args['y_snr'],
                                              citcio=args['citcio'],
                                              prop_miss=args['prop_miss'],
                                              seed=test_seed)
                    elif args['model'] == "dlvm":
                        (Z_test, X_test, w_test, y_test, ps_test, mu0_test,
                         mu1_test) = gen_dlvm(
                             n=args['n'], d=args['d'], p=args['p'],
                             y_snr=args['y_snr'], citcio=args['citcio'],
                             # prop_miss is only used if citcio=True
                             prop_miss=args['prop_miss'],
                             seed=test_seed, mu_z=args['mu_z'],
                             sig_z=args['sig_z'], x_snr=args['x_snr'],
                             sig_xgivenz=args['sig_xgivenz'])

                    X_miss_test = ampute(X_test, prop_miss=args['prop_miss'],
                                         seed=args['seed'])
                    # Binary mask that indicates which values are observed.
                    mask_test = np.isfinite(X_miss_test)

                    t0 = time.time()
                    tmp_elm_pkl = glob.glob(session_file_complete +
                                            '_testset_eval' + str(test_seed) +
                                            '.pkl')
                    if len(tmp_elm_pkl) > 0:
                        with open(session_file_complete + '_testset_eval' +
                                  str(test_seed) + '.pkl', 'rb') as f:
                            (xhat_test, zhat_test, zgivenx_test,
                             zhat_mul_test) = pickle.load(f)
                    else:
                        x_test_imp0 = np.copy(X_miss_test)
                        x_test_imp0[np.isnan(X_miss_test)] = 0
                        n_test = X_test.shape[0]
                        xhat_test = np.copy(x_test_imp0)
                        zhat_test = np.zeros([n_test, mdc_arg['d_miwae']])
                        zgivenx_test = np.tile(
                            zhat_test, [mdc_arg['num_samples_zmul'], 1, 1])
                        zhat_mul_test = np.tile(
                            zhat_test, [mdc_arg['num_samples_zmul'], 1, 1])
                        for i in range(n_test):
                            zgivenx_test[:, i, :] = np.squeeze(
                                zgivenx.eval(
                                    session=sess,
                                    feed_dict={
                                        'x:0': x_test_imp0[i, :].reshape([1, args['p']]),
                                        'K:0': mdc_arg['num_samples_zmul'],
                                        'xmask:0': mask_test[i, :].reshape([1, args['p']])
                                    })).reshape([mdc_arg['num_samples_zmul'],
                                                 mdc_arg['d_miwae']])
                            xhat_test[i, :] = xm.eval(
                                session=sess,
                                feed_dict={
                                    'x:0': x_test_imp0[i, :].reshape([1, args['p']]),
                                    'K:0': 10000,
                                    'xmask:0': mask_test[i, :].reshape([1, args['p']])
                                })
                            zhat_test[i, :] = z_hat.eval(
                                session=sess,
                                feed_dict={
                                    'x:0': x_test_imp0[i, :].reshape([1, args['p']]),
                                    'K:0': 10000,
                                    'xmask:0': mask_test[i, :].reshape([1, args['p']])
                                })
                            si, zmu = sess.run(
                                [sirz, zmul],
                                feed_dict={
                                    'x:0': x_test_imp0[i, :].reshape([1, args['p']]),
                                    'K:0': 10000,
                                    'xmask:0': mask_test[i, :].reshape([1, args['p']])
                                })
                            zhat_mul_test[:, i, :] = np.squeeze(
                                zmu[si, :, :]).reshape(
                                    (mdc_arg['num_samples_zmul'],
                                     mdc_arg['d_miwae']))
                        if save_test_data:
                            with open(session_file_complete + '_testset_eval' +
                                      str(test_seed) + '.pkl',
                                      'wb') as file_data:
                                pickle.dump([xhat_test, zhat_test,
                                             zgivenx_test, zhat_mul_test],
                                            file_data)
                    evaluation_time = int(time.time() - t0)

                    if args['d'] == 1 and mdc_arg['d_miwae'] == 1:
                        row = {
                            'Z_cor': pearsonr(
                                Z_test.reshape([args['n'], ]),
                                zhat_test.reshape([args['n'], ]))[0]
                        }
                    else:
                        row = {'Z_cor': np.NaN}
                    if args['d'] == mdc_arg['d_miwae']:
                        row.update({'Z_mmd': mmd(Z_test, zhat_test, beta=1.)})
                        row.update({'Z_rvcoef': compute_rv(Z_test, zhat_test)})
                    else:
                        row.update({'Z_mmd': np.NaN})
                        row.update({'Z_rvcoef': np.NaN})
                    row.update({'X_mse': mean_squared_error(X_test, xhat_test)})
                    row.update({'X_mmd': mmd(X_test, xhat_test, beta=1.)})
                    row.update({'X_rvcoef': compute_rv(X_test, xhat_test)})
                    row.update(args)
                    row.update(mdc_arg)
                    row.update({'epochs': epochs})
                    row.update({'test_seed': test_seed})
                    row.update({'evaluation_time': evaluation_time})
                    res.append(row)

            log_res(output, res,
                    l_metrics + list(args.keys()) + list(mdc_arg.keys()) +
                    ['epochs', 'test_seed', 'evaluation_time'])
            logging.info('........... DONE')
            logging.info(f'in {time.time() - exp_time} s \n\n')

    logging.info('*' * 20)
    logging.info(f'Exp: {FLAGS.exp_name} successfully ended.')
    logging.info('*' * 20)
                                          X_new=X_reconstruction,
                                          missing_mask=missing_mask)
            X_filled[missing_mask] = X_reconstruction[missing_mask]
            if converged:
                break

        if self.verbose:
            print("[SoftImpute] Stopped after iteration %d for lambda=%f" %
                  (i + 1, shrinkage_value))

        self.mae_obs = masked_mae(X_true=X_init,
                                  X_pred=X_reconstruction,
                                  mask=observed_mask)

        return X_filled


if __name__ == '__main__':
    from generate_data import gen_lrmf, gen_dlvm
    from generate_data import ampute
    import matplotlib.pyplot as plt
    import seaborn as sns

    Z, X, w, y, ps = gen_lrmf(d=3)
    X_obs = ampute(X)

    print('boxplot of get_U_softimpute with gen_lrmf(d=3)')
    U = get_U_softimpute(X_obs, boxplot=True)