def experiment(sname, seed, datasize, nystr=False, args=None):
    np.random.seed(1)
    random.seed(1)

    def LMO_err(params, M=10):
        np.random.seed(2)
        random.seed(2)
        al, bl = np.exp(params)
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            tmp_mat = L @ eig_vec_K
            C = L - tmp_mat @ np.linalg.inv(eig_vec_K.T @ tmp_mat / N2 + inv_eig_val_K) @ tmp_mat.T / N2
            c = C @ W_nystr_Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            C = L @ LWL_inv @ L / N2
            c = C @ W @ Y * N2
        c_y = c - Y
        lmo_err = 0
        N = 0
        for ii in range(1):  # single random partition pass
            permutation = np.random.permutation(X.shape[0])
            for i in range(0, X.shape[0], M):
                indices = permutation[i:i + M]
                K_i = W[np.ix_(indices, indices)] * N2
                C_i = C[np.ix_(indices, indices)]
                c_y_i = c_y[indices]
                # use the actual batch size so the final (possibly smaller) batch works too
                b_y = np.linalg.inv(np.eye(indices.shape[0]) - C_i @ K_i) @ c_y_i
                lmo_err += b_y.T @ K_i @ b_y
                N += 1
        return lmo_err[0, 0] / N / M ** 2

    def callback0(params, timer=None):
        global Nfeval, prev_norm, opt_params, opt_test_err
        np.random.seed(3)
        random.seed(3)
        if Nfeval % 1 == 0:  # always true; kept so the evaluation cadence is easy to change
            al, bl = params
            L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
            if nystr:
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    eig_vec_K.T @ L @ eig_vec_K / N2 + np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
                alpha = alpha @ W_nystr @ Y * N2
            else:
                LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
                alpha = LWL_inv @ L @ W @ Y
                # L_W_inv = chol_inv(W*N2+L_inv)
            test_L = bl * bl * np.exp(-test_L0 / al / al / 2)
            pred_mean = test_L @ alpha
            if timer:
                return
            test_err = ((pred_mean - test_Y) ** 2).mean()
            # ((pred_mean-test_Y)**2/np.diag(pred_cov)).mean()+(np.log(np.diag(pred_cov))).mean()
            norm = alpha.T @ L @ alpha
        Nfeval += 1
        if prev_norm is not None:
            if norm[0, 0] / prev_norm >= 3:
                if opt_params is None:
                    opt_test_err = test_err
                    opt_params = params
                print(True, opt_params, opt_test_err, prev_norm)
                # abort the L-BFGS run once the RKHS norm blows up; caught by the caller
                raise Exception
        if prev_norm is None or norm[0, 0] <= prev_norm:
            prev_norm = norm[0, 0]
        opt_test_err = test_err
        opt_params = params
        print('params,test_err, norm: ', opt_params, opt_test_err, prev_norm)

    def get_causal_effect(params, do_A, w):
        """To be called within the experiment function."""
        np.random.seed(4)
        random.seed(4)
        al, bl = params
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            alpha = EYEN - eig_vec_K @ np.linalg.inv(
                eig_vec_K.T @ L @ eig_vec_K / N2 + np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
            alpha = alpha @ W_nystr @ Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            alpha = LWL_inv @ L @ W @ Y
            # L_W_inv = chol_inv(W*N2+L_inv)

        EYhat_do_A = []
        for a in do_A:
            a = np.repeat(a, [w.shape[0]]).reshape(-1, 1)
            w = w.reshape(-1, 1)
            aw = np.concatenate([a, w], axis=-1)
            ate_L0 = _sqdist(aw, X)
            ate_L = bl * bl * np.exp(-ate_L0 / al / al / 2)
            h_out = ate_L @ alpha
            mean_h = np.mean(h_out).reshape(-1, 1)
            EYhat_do_A.append(mean_h)
            print('a = {}, beta_a = {}'.format(np.mean(a), mean_h))
        return np.concatenate(EYhat_do_A)

    # train,dev,test = load_data(ROOT_PATH+'/data/zoo/{}_{}.npz'.format(sname,datasize))
    # X = np.vstack((train.x,dev.x))
    # Y = np.vstack((train.y,dev.y))
    # Z = np.vstack((train.z,dev.z))
    # test_X = test.x
    # test_Y = test.g
    t1 = time.time()
    train, dev, test = load_data(ROOT_PATH + "/data/zoo/" + sname + '/main_{}.npz'.format(args.sem))
    # train, dev, test = train[:300], dev[:100], test[:100]
    t2 = time.time()
    print('t2 - t1 = ', t2 - t1)
    Y = np.concatenate((train.y, dev.y), axis=0).reshape(-1, 1)
    # test_Y = test.y
    AZ_train, AW_train = bundle_az_aw(train.a, train.z, train.w)
    AZ_test, AW_test = bundle_az_aw(test.a, test.z, test.w)
    AZ_dev, AW_dev = bundle_az_aw(dev.a, dev.z, dev.w)  # fixed: was test.w, which mixed test data into the dev bundle

    X, Z = np.concatenate((AW_train, AW_dev), axis=0), np.concatenate((AZ_train, AZ_dev), axis=0)
    test_X, test_Y = AW_test, test.y.reshape(-1, 1)  # TODO: is test.g just test.y?
    t3 = time.time()
    print('t3 - t2', t3 - t2)

    EYEN = np.eye(X.shape[0])
    # ak0, ak1 = get_median_inter_mnist(Z[:, 0:1]), get_median_inter_mnist(Z[:, 1:2])
    ak = get_median_inter_mnist(Z)
    N2 = X.shape[0] ** 2
    W0 = _sqdist(Z, None)
    print('av kernel indicator: ', args.av_kernel)
    # W = np.exp(-W0 / ak0 / ak0 / 2) / N2 if not args.av_kernel \
    #     else (np.exp(-W0 / ak0 / ak0 / 2) + np.exp(-W0 / ak0 / ak0 / 200) + np.exp(-W0 / ak0 / ak0 * 50)) / 3 / N2
    W = np.exp(-W0 / ak / ak / 2) / N2 if not args.av_kernel \
        else (np.exp(-W0 / ak / ak / 2) + np.exp(-W0 / ak / ak / 200) + np.exp(-W0 / ak / ak * 50)) / 3 / N2
    del W0
    L0, test_L0 = _sqdist(X, None), _sqdist(test_X, X)
    t4 = time.time()
    print('t4 - t3', t4 - t3)

    # measure time
    # callback0(np.random.randn(2)/10,True)
    # np.save(ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + '/LMO_errs_{}_nystr_{}_time.npy'.format(seed,train.x.shape[0]),time.time()-t0)
    # return

    # params0 = np.random.randn(2)  # /10
    params0 = np.array([1., 1.])
    print('starting param: ', params0)
    bounds = None  # [[0.01,10],[0.01,5]]
    if nystr:
        # advance the RNG so each seed draws a different set of landmark points
        for _ in range(seed + 1):
            random_indices = np.sort(np.random.choice(range(W.shape[0]), nystr_M, replace=False))
        eig_val_K, eig_vec_K = nystrom_decomp(W * N2, random_indices)
        inv_eig_val_K = np.diag(1 / eig_val_K / N2)
        W_nystr = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T / N2
        W_nystr_Y = W_nystr @ Y

    t5 = time.time()
    print('t5 - t4', t5 - t4)
    obj_grad = value_and_grad(lambda params: LMO_err(params))
    try:
        res = minimize(obj_grad, x0=params0, bounds=bounds, method='L-BFGS-B',
                       jac=True, options={'maxiter': 5000}, callback=callback0)
        # res stands for results (not residuals!)
    except Exception as e:
        print(e)

    PATH = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    # makedirs creates PATH itself if missing; the previous os.mkdir failed in that case
    os.makedirs(os.path.join(PATH, str(date.today())), exist_ok=True)
    assert opt_params is not None
    params = opt_params

    do_A = np.load(ROOT_PATH + "/data/zoo/" + sname + '/do_A_{}.npz'.format(args.sem))['do_A']
    EY_do_A_gt = np.load(ROOT_PATH + "/data/zoo/" + sname + '/do_A_{}.npz'.format(args.sem))['gt_EY_do_A']
    w_sample = train.w
    EYhat_do_A = get_causal_effect(params=params, do_A=do_A, w=w_sample)

    plt.figure()
    plt.plot([i + 1 for i in range(20)], EYhat_do_A)  # assumes 20 intervention points in do_A
    plt.xlabel('A')
    plt.ylabel('EYdoA-est')
    plt.savefig(
        os.path.join(PATH, str(date.today()),
                     'causal_effect_estimates_nystr_{}'.format(AW_train.shape[0]) + '.png'))
    plt.close()

    print('ground truth ate: ', EY_do_A_gt)
    visualise_ATEs(EY_do_A_gt, EYhat_do_A,
                   x_name='E[Y|do(A)] - gt',
                   y_name='beta_A',
                   save_loc=os.path.join(PATH, str(date.today())) + '/',
                   save_name='ate_{}_nystr.png'.format(AW_train.shape[0]))
    causal_effect_mean_abs_err = np.mean(np.abs(EY_do_A_gt - EYhat_do_A))
    causal_effect_mae_file = open(
        os.path.join(PATH, str(date.today()), "ate_mae_{}_nystrom.txt".format(AW_train.shape[0])), "a")
    causal_effect_mae_file.write("mae_: {}\n".format(causal_effect_mean_abs_err))
    causal_effect_mae_file.close()

    np.save(os.path.join(PATH, str(date.today()), 'LMO_errs_{}_nystr_{}.npy'.format(seed, AW_train.shape[0])),
            [opt_params, prev_norm, opt_test_err])
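

# The closed-form solve used in callback0 / get_causal_effect above can be hard
# to read inline, so here is a self-contained toy sketch of the non-Nystrom
# branch: alpha = (L W L + L / N^2 + jitter * I)^{-1} L W Y, with
# np.linalg.solve standing in for chol_inv. Everything in this helper
# (_alpha_solve_sketch, the toy data, the unit bandwidths) is illustrative only
# and is not used by the pipeline.
def _alpha_solve_sketch():
    rng = np.random.RandomState(0)
    n = 50
    X_toy = rng.randn(n, 2)
    Y_toy = rng.randn(n, 1)
    sq = ((X_toy[:, None, :] - X_toy[None, :, :]) ** 2).sum(-1)  # pairwise squared distances
    L = np.exp(-sq / 2)            # RBF Gram matrix of the hypothesis space
    W = np.exp(-sq / 2) / n ** 2   # instrument kernel, scaled by 1/N^2 as above
    A = L @ W @ L + L / n ** 2 + 1e-6 * np.eye(n)
    alpha = np.linalg.solve(A, L @ W @ Y_toy)
    return alpha  # predictions at new points x are k(x, X_toy) @ alpha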
def run_experiment_nn(sname, datasize, indices=[], seed=527, training=True, args=None):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if len(indices) == 2:
        lr_id, dw_id = indices
    elif len(indices) == 3:
        lr_id, dw_id, W_id = indices

    # load data
    folder = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    os.makedirs(folder, exist_ok=True)
    train, dev, test = load_data(ROOT_PATH + "/data/zoo/" + sname + '/main_{}.npz'.format(args.sem), Torch=True)

    Y = torch.cat((train.y, dev.y), dim=0).float()
    AZ_train, AW_train = bundle_az_aw(train.a, train.z, train.w, Torch=True)
    AZ_test, AW_test = bundle_az_aw(test.a, test.z, test.w, Torch=True)
    AZ_dev, AW_dev = bundle_az_aw(dev.a, dev.z, dev.w, Torch=True)  # fixed: was test.w

    X, Z = torch.cat((AW_train, AW_dev), dim=0).float(), torch.cat((AZ_train, AZ_dev), dim=0).float()
    test_X, test_Y = AW_test.float(), test.y.float()  # TODO: is test.g just test.y?
    n_train = train.a.shape[0]

    # training settings
    n_epochs = 1000
    batch_size = 1000 if train.a.shape[0] > 1000 else train.a.shape[0]

    # load expectation eval data
    axzy = np.load(ROOT_PATH + "/data/zoo/" + sname + '/cond_exp_metric_{}.npz'.format(args.sem))['axzy']
    w_samples = np.load(ROOT_PATH + "/data/zoo/" + sname + '/cond_exp_metric_{}.npz'.format(args.sem))['w_samples']
    y_samples = np.load(ROOT_PATH + "/data/zoo/" + sname + '/cond_exp_metric_{}.npz'.format(args.sem))['y_samples']
    y_axz = axzy[:, -1]
    ax = axzy[:, :2]

    # kernel
    kernel = Kernel('rbf', Torch=True)
    a = get_median_inter_mnist(AZ_train)
    a = torch.tensor(a).float()

    # hyperparameter grids
    lrs = [2e-4, 1e-4, 5e-5]  # [3,5]
    decay_weights = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6]  # [11,5]

    def my_loss(output, target, indices, K):
        d = output - target
        if indices is None:
            W = K
        else:
            W = K[indices[:, None], indices]
            # print((kernel(Z[indices],None,a,1)+kernel(Z[indices],None,a/10,1)+kernel(Z[indices],None,a*10,1))/3-W)
        loss = d.T @ W @ d / (d.shape[0]) ** 2  # MMR V-statistic: d^T K d / n^2
        return loss[0, 0]

    def conditional_expected_loss(net, ax, w_samples, y_samples, y_axz, x_on):
        if not x_on:
            ax = ax[:, 0:1]
        num_reps = w_samples.shape[1]
        assert len(ax.shape) == 2
        assert ax.shape[1] < 3
        assert ax.shape[0] == w_samples.shape[0]
        print('number of points: ', w_samples.shape[0])
        ax_rep = np.repeat(ax, [num_reps], axis=0)
        assert ax_rep.shape[0] == (w_samples.shape[1] * ax.shape[0])
        w_samples_flat = w_samples.flatten().reshape(-1, 1)
        nn_inp_np = np.concatenate([ax_rep, w_samples_flat], axis=-1)
        # print('nn_inp shape: ', nn_inp_np.shape)
        nn_inp = torch.as_tensor(nn_inp_np).float()
        nn_out = net(nn_inp).detach().cpu().numpy()
        nn_out = nn_out.reshape([-1, w_samples.shape[1]])
        y_axz_recon = np.mean(nn_out, axis=1)
        assert y_axz_recon.shape[0] == y_axz.shape[0]
        mean_abs_error = np.mean(np.abs(y_axz - y_axz_recon))
        # for debugging, compute the mse between y samples and h
        y_samples_flat = y_samples.flatten()
        mse = np.mean((y_samples_flat - nn_out.flatten()) ** 2)
        return mean_abs_error, mse

    def fit(x, y, z, dev_x, dev_y, dev_z, a, lr, decay_weight, ax, y_axz, w_samples, n_epochs=n_epochs):
        if 'mnist' in sname:
            train_K = torch.eye(x.shape[0])
        else:
            train_K = (kernel(z, None, a, 1) + kernel(z, None, a / 10, 1) + kernel(z, None, a * 10, 1)) / 3
        if dev_z is not None:
            if 'mnist' in sname:
                dev_K = torch.eye(dev_x.shape[0])  # fixed: was x.shape[0], which breaks when dev and train sizes differ
            else:
                dev_K = (kernel(dev_z, None, a, 1) + kernel(dev_z, None, a / 10, 1)
                         + kernel(dev_z, None, a * 10, 1)) / 3
        n_data = x.shape[0]
        net = FCNN(x.shape[1]) if sname not in ['mnist_x', 'mnist_xz'] else CNN()
        es = EarlyStopping(patience=10)  # 10 for small
        optimizer = optim.Adam(list(net.parameters()), lr=lr,
                               weight_decay=decay_weight)

        test_errs, dev_errs, exp_errs, mse_s = [], [], [], []
        for epoch in range(n_epochs):
            permutation = torch.randperm(n_data)
            for i in range(0, n_data, batch_size):
                indices = permutation[i:i + batch_size]
                batch_x, batch_y = x[indices], y[indices]

                # training step
                def closure():
                    optimizer.zero_grad()
                    pred_y = net(batch_x)
                    loss = my_loss(pred_y, batch_y, indices, train_K)
                    loss.backward()
                    return loss

                optimizer.step(closure)  # does the update

            if epoch % 5 == 0 and epoch >= 50 and dev_x is not None:  # 5, 10 for small; 5, 50 for large
                g_pred = net(test_X)  # TODO: is it supposed to be test_X here? A: yes, I think so.
                # not reweighted: this measures agreement between predictions and labels
                test_err = ((g_pred - test_Y) ** 2).mean()
                if epoch == 50 and 'mnist' in sname:
                    if z.shape[1] > 100:
                        # precomputed squared distances; converted to tensors so torch.exp applies
                        train_K0 = torch.as_tensor(
                            np.load(ROOT_PATH + '/mnist_precomp/{}_train_K0.npy'.format(sname))).float()
                        train_K = (torch.exp(-train_K0 / a ** 2 / 2) + torch.exp(-train_K0 / a ** 2 * 50)
                                   + torch.exp(-train_K0 / a ** 2 / 200)) / 3
                        dev_K0 = torch.as_tensor(
                            np.load(ROOT_PATH + '/mnist_precomp/{}_dev_K0.npy'.format(sname))).float()
                        dev_K = (torch.exp(-dev_K0 / a ** 2 / 2) + torch.exp(-dev_K0 / a ** 2 * 50)
                                 + torch.exp(-dev_K0 / a ** 2 / 200)) / 3
                    else:
                        train_K = (kernel(z, None, a, 1) + kernel(z, None, a / 10, 1)
                                   + kernel(z, None, a * 10, 1)) / 3
                        dev_K = (kernel(dev_z, None, a, 1) + kernel(dev_z, None, a / 10, 1)
                                 + kernel(dev_z, None, a * 10, 1)) / 3
                dev_err = my_loss(net(dev_x), dev_y, None, dev_K)
                err_in_expectation, mse = conditional_expected_loss(
                    net=net, ax=ax, w_samples=w_samples, y_samples=y_samples, y_axz=y_axz, x_on=False)
                print('test', test_err, 'dev', dev_err, 'err_in_expectation', err_in_expectation, 'mse: ', mse)
                test_errs.append(test_err)
                dev_errs.append(dev_err)
                exp_errs.append(err_in_expectation)
                mse_s.append(mse)
                if es.step(dev_err):
                    break
        losses = {'test': test_errs, 'dev': dev_errs, 'exp': exp_errs, 'mse_': mse_s}
        return es.best, epoch, net, losses

    def get_causal_effect(net, do_A, w):
        """
        :param net: FCNN object
        :param do_A: a numpy array of interventions, size = B_a
        :param w: a torch tensor of w samples, size = B_w
        :return: a numpy array of interventional parameters
        """
        net.eval()
        # raise ValueError('have not tested get_causal_effect.')
        EYhat_do_A = []
        for a in do_A:
            a = np.repeat(a, [w.shape[0]]).reshape(-1, 1)
            a_tensor = torch.as_tensor(a).float()
            w = w.reshape(-1, 1).float()
            aw = torch.cat([a_tensor, w], dim=-1)  # already a tensor; re-wrapping with torch.tensor was redundant
            mean_h = torch.mean(net(aw)).reshape(-1, 1)
            EYhat_do_A.append(mean_h)
            print('a = {}, beta_a = {}'.format(np.mean(a), mean_h))
        return torch.cat(EYhat_do_A).detach().cpu().numpy()

    if training is True:
        print('training')
        for rep in range(3):
            print('*******REP: {}'.format(rep))
            save_path = os.path.join(folder, 'mmr_iv_nn_{}_{}_{}_{}.npz'.format(rep, lr_id, dw_id, AW_train.shape[0]))
            # if os.path.exists(save_path):
            #     continue
            lr, dw = lrs[lr_id], decay_weights[dw_id]
            print('lr, dw', lr, dw)
            t0 = time.time()
            err, _, net, losses = fit(X[:n_train], Y[:n_train], Z[:n_train],
                                      X[n_train:], Y[n_train:], Z[n_train:],
                                      a, lr, dw, ax=ax, y_axz=y_axz, w_samples=w_samples)
            t1 = time.time() - t0
            np.save(folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(rep, lr_id, dw_id, AW_train.shape[0]), t1)
            g_pred = net(test_X).detach().numpy()
            test_err = ((g_pred - test_Y.numpy()) ** 2).mean()
            np.savez(save_path, err=err.detach().numpy(), lr=lr, dw=dw, g_pred=g_pred, test_err=test_err)

            # make loss curves
            for (name, ylabel) in [('test', 'test av ||y - h||^2'), ('dev', 'R_V'),
                                   ('exp', 'E[y-h|a,z,x]'), ('mse_', 'mse_alternative_sim')]:
                errs = losses[name]
                stps = [50 + i * 5 for i in range(len(errs))]  # eval runs every 5 epochs from epoch 50
                plt.figure()
                plt.plot(stps, errs)
                plt.xlabel('epoch')
                plt.ylabel(ylabel)
                plt.savefig(os.path.join(
                    folder, name + '_{}_{}_{}_{}'.format(rep, lr_id, dw_id, AW_train.shape[0]) + '.png'))
                plt.close()

            # do causal effect estimates
            do_A = np.load(ROOT_PATH + "/data/zoo/" + sname + '/do_A_{}.npz'.format(args.sem))['do_A']
            EY_do_A_gt = np.load(ROOT_PATH + "/data/zoo/" + sname + '/do_A_{}.npz'.format(args.sem))['gt_EY_do_A']
            w_sample = train.w
            EYhat_do_A = get_causal_effect(net, do_A=do_A, w=w_sample)
            plt.figure()
            plt.plot([i + 1 for i in range(20)], EYhat_do_A)  # assumes 20 intervention points in do_A
            plt.xlabel('A')
            plt.ylabel('EYdoA-est')
            plt.savefig(os.path.join(
                folder, 'causal_effect_estimates_{}_{}_{}'.format(lr_id, dw_id, AW_train.shape[0]) + '.png'))
            plt.close()

            print('ground truth ate: ', EY_do_A_gt)
            visualise_ATEs(EY_do_A_gt, EYhat_do_A,
                           x_name='E[Y|do(A)] - gt',
                           y_name='beta_A',
                           save_loc=folder,
                           save_name='ate_{}_{}_{}_{}.png'.format(rep, lr_id, dw_id, AW_train.shape[0]))
            causal_effect_mean_abs_err = np.mean(np.abs(EY_do_A_gt - EYhat_do_A))
            causal_effect_mae_file = open(
                os.path.join(folder, "ate_mae_{}_{}_{}.txt".format(lr_id, dw_id, AW_train.shape[0])), "a")
            causal_effect_mae_file.write("mae_rep_{}: {}\n".format(rep, causal_effect_mean_abs_err))
            causal_effect_mae_file.close()
    else:
        print('test')
        opt_res = []
        times = []
        for rep in range(10):
            res_list = []
            other_list = []
            times2 = []
            for lr_id in range(len(lrs)):
                for dw_id in range(len(decay_weights)):
                    load_path = os.path.join(folder, 'mmr_iv_nn_{}_{}_{}_{}.npz'.format(rep, lr_id, dw_id, datasize))
                    if os.path.exists(load_path):
                        res = np.load(load_path)
                        res_list += [res['err'].astype(float)]
                        other_list += [[res['lr'].astype(float), res['dw'].astype(float),
                                        res['test_err'].astype(float)]]
                    time_path = folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(rep, lr_id, dw_id, datasize)
                    if os.path.exists(time_path):
                        t = np.load(time_path)
                        times2 += [t]
            res_list = np.array(res_list)
            other_list = np.array(other_list)
            other_list = other_list[res_list > 0]
            res_list = res_list[res_list > 0]
            optim_id = np.argsort(res_list)[0]  # np.argmin(res_list)
            print(rep, '--', other_list[optim_id], np.min(res_list))
            opt_res += [other_list[optim_id][-1]]
            times += times2  # fixed: per-rep timings were collected in times2 but never aggregated
        print('time: ', np.mean(times), np.std(times))
        print(np.mean(opt_res), np.std(opt_res))
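

# A minimal invocation sketch, assuming this module is run as a script. The
# --sem and --av_kernel flags mirror the two attributes read from `args` above;
# the dataset name 'sim', the seed, and the hyperparameter indices are
# placeholder assumptions, not values prescribed by this file.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--sem', type=str, default='linear')
    parser.add_argument('--av_kernel', action='store_true')
    cli_args = parser.parse_args()

    # kernelised variant:
    # experiment(sname='sim', seed=527, datasize=5000, nystr=True, args=cli_args)

    # neural-network variant; indices pick (lr, weight decay) from the grids:
    run_experiment_nn('sim', datasize=5000, indices=[0, 0], seed=527,
                      training=True, args=cli_args)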