Code example #1
def experiment(sname, seed, datasize, nystr=False, args=None):
    np.random.seed(1)
    random.seed(1)
    def LMO_err(params, M=10):
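        # Leave-M-out (LMO) risk for the kernel hyperparameters (al, bl):
        # random blocks of M points are held out, re-predicted from the rest via
        # the closed-form correction b_y, and scored with the instrument kernel K_i.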
        np.random.seed(2)
        random.seed(2)
        al, bl = np.exp(params)
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            tmp_mat = L @ eig_vec_K
            C = L - tmp_mat @ np.linalg.inv(eig_vec_K.T @ tmp_mat / N2 + inv_eig_val_K) @ tmp_mat.T / N2
            c = C @ W_nystr_Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            C = L @ LWL_inv @ L / N2
            c = C @ W @ Y * N2
        c_y = c - Y
        lmo_err = 0
        N = 0
        for ii in range(1):  # a single permutation pass; increase the range to average over more random splits
            permutation = np.random.permutation(X.shape[0])
            for i in range(0, X.shape[0], M):
                indices = permutation[i:i + M]
                K_i = W[np.ix_(indices, indices)] * N2
                C_i = C[np.ix_(indices, indices)]
                c_y_i = c_y[indices]
                b_y = np.linalg.inv(np.eye(len(indices)) - C_i @ K_i) @ c_y_i  # len(indices): the final block can be smaller than M
                lmo_err += b_y.T @ K_i @ b_y
                N += 1
        return lmo_err[0, 0] / N / M ** 2

    def callback0(params, timer=None):
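        # L-BFGS-B callback: tracks the RKHS norm alpha^T L alpha and the current
        # test error; a sharp norm increase triggers the abort in the block below.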
        global Nfeval, prev_norm, opt_params, opt_test_err
        np.random.seed(3)
        random.seed(3)
        if Nfeval % 1 == 0:  # modulus 1 evaluates on every call; raise it to log less often
            al, bl = params
            L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
            if nystr:
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    eig_vec_K.T @ L @ eig_vec_K / N2 + np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
                alpha = alpha @ W_nystr @ Y * N2
            else:
                LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
                alpha = LWL_inv @ L @ W @ Y
                # L_W_inv = chol_inv(W*N2+L_inv)
            test_L = bl * bl * np.exp(-test_L0 / al / al / 2)
            pred_mean = test_L @ alpha
            if timer:
                return
            test_err = ((pred_mean - test_Y) ** 2).mean()  # ((pred_mean-test_Y)**2/np.diag(pred_cov)).mean()+(np.log(np.diag(pred_cov))).mean()
            norm = alpha.T @ L @ alpha

        Nfeval += 1
        if prev_norm is not None:
            if norm[0, 0] / prev_norm >= 3:
                # a 3x jump in the RKHS norm signals divergence; abort the optimizer
                # (the exception is caught around minimize() below)
                if opt_params is None:
                    opt_test_err = test_err
                    opt_params = params
                print(True, opt_params, opt_test_err, prev_norm)
                raise Exception('LMO optimization aborted: norm ratio >= 3')

        if prev_norm is None or norm[0, 0] <= prev_norm:
            prev_norm = norm[0, 0]
        opt_test_err = test_err
        opt_params = params
        print('params,test_err, norm: ', opt_params, opt_test_err, prev_norm)

    def get_causal_effect(params, do_A, w):
        "to be called within experiment function."
        np.random.seed(4)
        random.seed(4)
        al, bl = params
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            alpha = EYEN - eig_vec_K @ np.linalg.inv(
                eig_vec_K.T @ L @ eig_vec_K / N2 + np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
            alpha = alpha @ W_nystr @ Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            alpha = LWL_inv @ L @ W @ Y
            # L_W_inv = chol_inv(W*N2+L_inv)

        EYhat_do_A = []
        for a in do_A:
            a = np.repeat(a, [w.shape[0]]).reshape(-1, 1)
            w = w.reshape(-1, 1)
            aw = np.concatenate([a, w], axis=-1)
            ate_L0 = _sqdist(aw, X)
            ate_L = bl * bl * np.exp(-ate_L0 / al / al / 2)
            h_out = ate_L @ alpha

            mean_h = np.mean(h_out).reshape(-1, 1)
            EYhat_do_A.append(mean_h)
            print('a = {}, beta_a = {}'.format(np.mean(a), mean_h))

        return np.concatenate(EYhat_do_A)

    # train,dev,test = load_data(ROOT_PATH+'/data/zoo/{}_{}.npz'.format(sname,datasize))

    # X = np.vstack((train.x,dev.x))
    # Y = np.vstack((train.y,dev.y))
    # Z = np.vstack((train.z,dev.z))
    # test_X = test.x
    # test_Y = test.g
    t1 = time.time()
    train, dev, test = load_data(ROOT_PATH + "/data/zoo/" + sname + '/main_{}.npz'.format(args.sem))
    # train, dev, test = train[:300], dev[:100], test[:100]
    t2 = time.time()
    print('t2 - t1 = ', t2 - t1)
    Y = np.concatenate((train.y, dev.y), axis=0).reshape(-1, 1)
    # test_Y = test.y
    AZ_train, AW_train = bundle_az_aw(train.a, train.z, train.w)
    AZ_test, AW_test = bundle_az_aw(test.a, test.z, test.w)
    AZ_dev, AW_dev = bundle_az_aw(dev.a, dev.z, dev.w)  # fixed: was test.w, which mixed the test split into the dev bundle

    X, Z = np.concatenate((AW_train, AW_dev), axis=0), np.concatenate((AZ_train, AZ_dev), axis=0)
    test_X, test_Y = AW_test, test.y.reshape(-1, 1)  # TODO: is test.g just test.y?

    t3 = time.time()
    print('t3 - t2', t3-t2)
    EYEN = np.eye(X.shape[0])
    # ak0, ak1 = get_median_inter_mnist(Z[:, 0:1]), get_median_inter_mnist(Z[:, 1:2])
    ak = get_median_inter_mnist(Z)
    N2 = X.shape[0] ** 2
    W0 = _sqdist(Z, None)
    print('av kernel indicator: ', args.av_kernel)
    # W = np.exp(-W0 / ak0 / ak0 / 2) / N2 if not args.av_kernel \
    #     else (np.exp(-W0 / ak0 / ak0 / 2) + np.exp(-W0 / ak0 / ak0 / 200) + np.exp(-W0 / ak0 / ak0 * 50)) / 3 / N2
    W = np.exp(-W0 / ak / ak / 2) / N2 if not args.av_kernel \
        else (np.exp(-W0 / ak / ak / 2) + np.exp(-W0 / ak / ak / 200) + np.exp(-W0 / ak / ak * 50)) / 3 / N2

    del W0
    L0, test_L0 = _sqdist(X, None), _sqdist(test_X, X)
    t4 = time.time()
    print('t4 - t3', t4-t3)
    # measure time
    # callback0(np.random.randn(2)/10,True)
    # np.save(ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + '/LMO_errs_{}_nystr_{}_time.npy'.format(seed,train.x.shape[0]),time.time()-t0)
    # return

    # params0 = np.random.randn(2)  # /10
    params0 = np.array([1., 1.])
    print('starting param: ', params0)
    bounds = None  # [[0.01,10],[0.01,5]]
    if nystr:
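        # Nystroem approximation: a rank-nystr_M eigendecomposition of the N x N
        # Gram matrix from random landmark columns, so the inverses in LMO_err
        # and callback0 involve nystr_M x nystr_M systems instead of N x N.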
        # advance the RNG (seed + 1) times so each seed draws a different landmark set
        for _ in range(seed + 1):
            random_indices = np.sort(np.random.choice(range(W.shape[0]), nystr_M, replace=False))
        eig_val_K, eig_vec_K = nystrom_decomp(W * N2, random_indices)
        inv_eig_val_K = np.diag(1 / eig_val_K / N2)
        W_nystr = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T / N2
        W_nystr_Y = W_nystr @ Y

    t5 = time.time()
    print('t5 - t4', t5-t4)
    obj_grad = value_and_grad(lambda params: LMO_err(params))
    try:
        # res stands for results (not residuals!)
        res = minimize(obj_grad, x0=params0, bounds=bounds, method='L-BFGS-B', jac=True,
                       options={'maxiter': 5000}, callback=callback0)
    except Exception as e:
        print(e)

    PATH = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    os.makedirs(os.path.join(PATH, str(date.today())), exist_ok=True)  # also creates PATH itself if missing

    assert opt_params is not None
    params = opt_params
    do_A_data = np.load(ROOT_PATH + "/data/zoo/" + sname + '/do_A_{}.npz'.format(args.sem))  # read the npz once
    do_A = do_A_data['do_A']
    EY_do_A_gt = do_A_data['gt_EY_do_A']
    w_sample = train.w
    EYhat_do_A = get_causal_effect(params=params, do_A=do_A, w=w_sample)
    plt.figure()
    plt.plot(do_A, EYhat_do_A)  # x-axis: the actual intervention values rather than a hard-coded range(20)
    plt.xlabel('A')
    plt.ylabel('EYdoA-est')
    plt.savefig(
        os.path.join(PATH, str(date.today()), 'causal_effect_estimates_nystr_{}'.format(AW_train.shape[0]) + '.png'))
    plt.close()
    print('ground truth ate: ', EY_do_A_gt)
    visualise_ATEs(EY_do_A_gt, EYhat_do_A,
                   x_name='E[Y|do(A)] - gt',
                   y_name='beta_A',
                   save_loc=os.path.join(PATH, str(date.today())) + '/',
                   save_name='ate_{}_nystr.png'.format(AW_train.shape[0]))
    causal_effect_mean_abs_err = np.mean(np.abs(EY_do_A_gt - EYhat_do_A))
    mae_path = os.path.join(PATH, str(date.today()),
                            "ate_mae_{}_nystrom.txt".format(AW_train.shape[0]))
    with open(mae_path, "a") as causal_effect_mae_file:
        causal_effect_mae_file.write("mae_: {}\n".format(causal_effect_mean_abs_err))

    np.save(os.path.join(PATH, str(date.today()), 'LMO_errs_{}_nystr_{}.npy'.format(seed, AW_train.shape[0])),
            [opt_params, prev_norm, opt_test_err])
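
# --- Supplementary sketch (not from the original repository) ----------------
# The listing above relies on helpers defined elsewhere (_sqdist,
# get_median_inter_mnist, chol_inv, nystrom_decomp) plus module constants such
# as JITTER and nystr_M. The minimal implementations below are assumptions
# that match how the helpers are called above; the repository's actual code
# may differ.
import numpy as np

def _sqdist(A, B=None):
    # pairwise squared Euclidean distances; B=None means distances within A
    if B is None:
        B = A
    sq_a = np.sum(A ** 2, axis=1, keepdims=True)
    sq_b = np.sum(B ** 2, axis=1, keepdims=True)
    return sq_a + sq_b.T - 2 * A @ B.T

def get_median_inter_mnist(Z):
    # median heuristic: the median pairwise distance as the RBF bandwidth
    D = np.sqrt(np.maximum(_sqdist(Z), 0))
    return np.median(D[np.triu_indices_from(D, k=1)])

def chol_inv(A):
    # inverse of a symmetric positive-definite matrix via its Cholesky factor
    L_chol = np.linalg.cholesky(A)
    L_inv = np.linalg.solve(L_chol, np.eye(A.shape[0]))
    return L_inv.T @ L_inv

def nystrom_decomp(K, idx):
    # rank-m Nystroem eigendecomposition of the n x n Gram matrix K from the
    # landmark rows/columns in idx (Williams & Seeger, 2001)
    n, m = K.shape[0], len(idx)
    eig_val_m, eig_vec_m = np.linalg.eigh(K[np.ix_(idx, idx)])
    eig_val_m = np.maximum(eig_val_m, 1e-12)   # guard against numerical negatives
    eig_val = eig_val_m * n / m                # approximate eigenvalues of K
    eig_vec = K[:, idx] @ eig_vec_m / eig_val_m * np.sqrt(m / n)
    return eig_val, eig_vec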
Code example #2
def run_experiment_nn(sname,
                      datasize,
                      indices=(),
                      seed=527,
                      training=True,
                      args=None):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if len(indices) == 2:
        lr_id, dw_id = indices
    elif len(indices) == 3:
        lr_id, dw_id, W_id = indices
    # an empty `indices` is fine in test mode (training=False): lr_id and dw_id are looped over below
    # load data
    folder = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    os.makedirs(folder, exist_ok=True)

    train, dev, test = load_data(ROOT_PATH + "/data/zoo/" + sname +
                                 '/main_{}.npz'.format(args.sem),
                                 Torch=True)
    Y = torch.cat((train.y, dev.y), dim=0).float()
    AZ_train, AW_train = bundle_az_aw(train.a, train.z, train.w, Torch=True)
    AZ_test, AW_test = bundle_az_aw(test.a, test.z, test.w, Torch=True)
    AZ_dev, AW_dev = bundle_az_aw(dev.a, dev.z, dev.w, Torch=True)  # fixed: was test.w, which mixed the test split into the dev bundle

    X, Z = torch.cat((AW_train, AW_dev), dim=0).float(), torch.cat(
        (AZ_train, AZ_dev), dim=0).float()
    test_X, test_Y = AW_test.float(), test.y.float()  # TODO: is test.g just test.y?
    n_train = train.a.shape[0]
    # training settings
    n_epochs = 1000
    batch_size = 1000 if train.a.shape[0] > 1000 else train.a.shape[0]

    # load expectation eval data
    cond_exp_data = np.load(ROOT_PATH + "/data/zoo/" + sname +
                            '/cond_exp_metric_{}.npz'.format(args.sem))  # read the npz once instead of three times
    axzy = cond_exp_data['axzy']
    w_samples = cond_exp_data['w_samples']
    y_samples = cond_exp_data['y_samples']
    y_axz = axzy[:, -1]
    ax = axzy[:, :2]

    # kernel
    kernel = Kernel('rbf', Torch=True)
    a = get_median_inter_mnist(AZ_train)
    a = torch.tensor(a).float()
    # training loop
    lrs = [2e-4, 1e-4, 5e-5]  # [3,5]
    decay_weights = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6]  # [11,5]

    def my_loss(output, target, indices, K):
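        # V-statistic form of the MMR objective: (h(x) - y)^T K (h(x) - y) / n^2,
        # where K is the instrument-kernel Gram matrix restricted to the batch.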
        d = output - target
        if indices is None:
            W = K
        else:
            W = K[indices[:, None], indices]
            # print((kernel(Z[indices],None,a,1)+kernel(Z[indices],None,a/10,1)+kernel(Z[indices],None,a*10,1))/3-W)
        loss = d.T @ W @ d / (d.shape[0])**2
        return loss[0, 0]

    def conditional_expected_loss(net, ax, w_samples, y_samples, y_axz, x_on):
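        # Monte-Carlo check of the conditional expectation: average the network
        # over the w-samples for each (a, x) row and compare against the
        # precomputed target y_axz.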
        if not x_on:
            ax = ax[:, 0:1]
        num_reps = w_samples.shape[1]
        assert len(ax.shape) == 2
        assert ax.shape[1] < 3
        assert ax.shape[0] == w_samples.shape[0]
        print('number of points: ', w_samples.shape[0])

        ax_rep = np.repeat(ax, [num_reps], axis=0)
        assert ax_rep.shape[0] == (w_samples.shape[1] * ax.shape[0])

        w_samples_flat = w_samples.flatten().reshape(-1, 1)
        nn_inp_np = np.concatenate([ax_rep, w_samples_flat], axis=-1)
        # print('nn_inp shape: ', nn_inp_np.shape)
        nn_inp = torch.as_tensor(nn_inp_np).float()
        nn_out = net(nn_inp).detach().cpu().numpy()
        nn_out = nn_out.reshape([-1, w_samples.shape[1]])
        y_axz_recon = np.mean(nn_out, axis=1)
        assert y_axz_recon.shape[0] == y_axz.shape[0]
        mean_abs_error = np.mean(np.abs(y_axz - y_axz_recon))

        # for debugging compute the mse between y samples and h
        y_samples_flat = y_samples.flatten()
        mse = np.mean((y_samples_flat - nn_out.flatten())**2)

        return mean_abs_error, mse

    def fit(x,
            y,
            z,
            dev_x,
            dev_y,
            dev_z,
            a,
            lr,
            decay_weight,
            ax,
            y_axz,
            w_samples,
            n_epochs=n_epochs):
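        # Train the network on the kernel-reweighted V-statistic (my_loss) with
        # Adam, early-stopping on the dev-set MMR loss.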
        if 'mnist' in sname:
            train_K = torch.eye(x.shape[0])
        else:
            train_K = (kernel(z, None, a, 1) + kernel(z, None, a / 10, 1) +
                       kernel(z, None, a * 10, 1)) / 3
        if dev_z is not None:
            if 'mnist' in sname:
                dev_K = torch.eye(dev_z.shape[0])  # fixed: was x.shape[0]; the dev Gram matrix must match the dev set size
            else:
                dev_K = (kernel(dev_z, None, a, 1) +
                         kernel(dev_z, None, a / 10, 1) +
                         kernel(dev_z, None, a * 10, 1)) / 3
        n_data = x.shape[0]
        net = FCNN(x.shape[1]) if sname not in ['mnist_x', 'mnist_xz'] else CNN()
        es = EarlyStopping(patience=10)  # 10 for small
        optimizer = optim.Adam(list(net.parameters()),
                               lr=lr,
                               weight_decay=decay_weight)

        test_errs, dev_errs, exp_errs, mse_s = [], [], [], []

        for epoch in range(n_epochs):
            permutation = torch.randperm(n_data)

            for i in range(0, n_data, batch_size):
                indices = permutation[i:i + batch_size]
                batch_x, batch_y = x[indices], y[indices]

                # training loop
                def closure():
                    optimizer.zero_grad()
                    pred_y = net(batch_x)
                    loss = my_loss(pred_y, batch_y, indices, train_K)
                    loss.backward()
                    return loss

                optimizer.step(closure)  # Does the update
            if epoch % 5 == 0 and epoch >= 50 and dev_x is not None:  # evaluate every 5 epochs after a 50-epoch burn-in (5/10 for small, 5/50 for large)
                g_pred = net(test_X)  # test_X is intentional: this tracks out-of-sample prediction error
                # plain MSE, deliberately not kernel-reweighted: it measures agreement
                # between predictions and labels rather than the MMR objective
                test_err = ((g_pred - test_Y) ** 2).mean()
                if epoch == 50 and 'mnist' in sname:
                    if z.shape[1] > 100:
                        # np.load returns NumPy arrays; convert to tensors before torch.exp
                        train_K0 = torch.as_tensor(np.load(
                            ROOT_PATH +
                            '/mnist_precomp/{}_train_K0.npy'.format(sname)))
                        train_K = (torch.exp(-train_K0 / a ** 2 / 2) +
                                   torch.exp(-train_K0 / a ** 2 * 50) +
                                   torch.exp(-train_K0 / a ** 2 / 200)) / 3
                        dev_K0 = torch.as_tensor(np.load(
                            ROOT_PATH +
                            '/mnist_precomp/{}_dev_K0.npy'.format(sname)))
                        dev_K = (torch.exp(-dev_K0 / a ** 2 / 2) +
                                 torch.exp(-dev_K0 / a ** 2 * 50) +
                                 torch.exp(-dev_K0 / a ** 2 / 200)) / 3
                    else:
                        train_K = (kernel(z, None, a, 1) +
                                   kernel(z, None, a / 10, 1) +
                                   kernel(z, None, a * 10, 1)) / 3
                        dev_K = (kernel(dev_z, None, a, 1) +
                                 kernel(dev_z, None, a / 10, 1) +
                                 kernel(dev_z, None, a * 10, 1)) / 3

                dev_err = my_loss(net(dev_x), dev_y, None, dev_K)
                err_in_expectation, mse = conditional_expected_loss(
                    net=net,
                    ax=ax,
                    w_samples=w_samples,
                    y_samples=y_samples,
                    y_axz=y_axz,
                    x_on=False)
                print('test', test_err, 'dev', dev_err, 'err_in_expectation',
                      err_in_expectation, 'mse: ', mse)
                test_errs.append(test_err)
                dev_errs.append(dev_err)
                exp_errs.append(err_in_expectation)
                mse_s.append(mse)

                if es.step(dev_err):
                    break
            losses = {
                'test': test_errs,
                'dev': dev_errs,
                'exp': exp_errs,
                'mse_': mse_s
            }
        return es.best, epoch, net, losses

    def get_causal_effect(net, do_A, w):
        """
        :param net: FCNN object
        :param do_A: a numpy array of interventions, size = B_a
        :param w: a torch tensor of w samples, size = B_w
        :return: a numpy array of interventional parameters
        """
        net.eval()
        # raise ValueError('have not tested get_causal_effect.')
        EYhat_do_A = []
        for a in do_A:
            a = np.repeat(a, [w.shape[0]]).reshape(-1, 1)
            a_tensor = torch.as_tensor(a).float()
            w = w.reshape(-1, 1).float()
            aw = torch.cat([a_tensor, w], dim=-1)  # already a tensor; re-wrapping with torch.tensor only copies it
            mean_h = torch.mean(net(aw)).reshape(-1, 1)
            EYhat_do_A.append(mean_h)
            print('a = {}, beta_a = {}'.format(np.mean(a), mean_h))
        return torch.cat(EYhat_do_A).detach().cpu().numpy()

    if training:
        print('training')
        for rep in range(3):
            print('*******REP: {}'.format(rep))
            save_path = os.path.join(
                folder,
                'mmr_iv_nn_{}_{}_{}_{}.npz'.format(rep, lr_id, dw_id,
                                                   AW_train.shape[0]))
            # if os.path.exists(save_path):
            #    continue
            lr, dw = lrs[lr_id], decay_weights[dw_id]
            print('lr, dw', lr, dw)
            t0 = time.time()
            err, _, net, losses = fit(X[:n_train],
                                      Y[:n_train],
                                      Z[:n_train],
                                      X[n_train:],
                                      Y[n_train:],
                                      Z[n_train:],
                                      a,
                                      lr,
                                      dw,
                                      ax=ax,
                                      y_axz=y_axz,
                                      w_samples=w_samples)
            t1 = time.time() - t0
            np.save(
                folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(
                    rep, lr_id, dw_id, AW_train.shape[0]), t1)
            g_pred = net(test_X).detach().numpy()
            test_err = ((g_pred - test_Y.numpy())**2).mean()
            np.savez(save_path,
                     err=err.detach().numpy(),
                     lr=lr,
                     dw=dw,
                     g_pred=g_pred,
                     test_err=test_err)

            # make loss curves
            for (name, ylabel) in [('test', 'test av ||y - h||^2'),
                                   ('dev', 'R_V'), ('exp', 'E[y-h|a,z,x]'),
                                   ('mse_', 'mse_alternative_sim')]:
                errs = losses[name]
                stps = [50 + i * 5 for i in range(len(errs))]
                plt.figure()
                plt.plot(stps, errs)
                plt.xlabel('epoch')
                plt.ylabel(ylabel)
                plt.savefig(
                    os.path.join(
                        folder, name + '_{}_{}_{}_{}'.format(
                            rep, lr_id, dw_id, AW_train.shape[0]) + '.png'))
                plt.close()

            # do causal effect estimates
            do_A_data = np.load(ROOT_PATH + "/data/zoo/" + sname +
                                '/do_A_{}.npz'.format(args.sem))  # read the npz once
            do_A = do_A_data['do_A']
            EY_do_A_gt = do_A_data['gt_EY_do_A']
            w_sample = train.w
            EYhat_do_A = get_causal_effect(net, do_A=do_A, w=w_sample)
            plt.figure()
            plt.plot(do_A, EYhat_do_A)  # x-axis: the actual intervention values rather than a hard-coded range(20)
            plt.xlabel('A')
            plt.ylabel('EYdoA-est')
            plt.savefig(
                os.path.join(
                    folder, 'causal_effect_estimates_{}_{}_{}'.format(
                        lr_id, dw_id, AW_train.shape[0]) + '.png'))
            plt.close()

            print('ground truth ate: ', EY_do_A_gt)
            visualise_ATEs(EY_do_A_gt,
                           EYhat_do_A,
                           x_name='E[Y|do(A)] - gt',
                           y_name='beta_A',
                           save_loc=folder,
                           save_name='ate_{}_{}_{}_{}.png'.format(
                               rep, lr_id, dw_id, AW_train.shape[0]))
            causal_effect_mean_abs_err = np.mean(
                np.abs(EY_do_A_gt - EYhat_do_A))
            mae_path = os.path.join(
                folder,
                "ate_mae_{}_{}_{}.txt".format(lr_id, dw_id, AW_train.shape[0]))
            with open(mae_path, "a") as causal_effect_mae_file:
                causal_effect_mae_file.write("mae_rep_{}: {}\n".format(
                    rep, causal_effect_mean_abs_err))

    else:
        print('test')
        opt_res = []
        times = []
        for rep in range(10):
            res_list = []
            other_list = []
            times2 = []
            for lr_id in range(len(lrs)):
                for dw_id in range(len(decay_weights)):
                    load_path = os.path.join(
                        folder, 'mmr_iv_nn_{}_{}_{}_{}.npz'.format(
                            rep, lr_id, dw_id, datasize))
                    if os.path.exists(load_path):
                        res = np.load(load_path)
                        res_list += [res['err'].astype(float)]
                        other_list += [[
                            res['lr'].astype(float), res['dw'].astype(float),
                            res['test_err'].astype(float)
                        ]]
                    time_path = folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(
                        rep, lr_id, dw_id, datasize)
                    if os.path.exists(time_path):
                        t = np.load(time_path)
                        times2 += [t]
            res_list = np.array(res_list)
            other_list = np.array(other_list)
            other_list = other_list[res_list > 0]
            res_list = res_list[res_list > 0]
            optim_id = np.argmin(res_list)
            print(rep, '--', other_list[optim_id], np.min(res_list))
            opt_res += [other_list[optim_id][-1]]
            times += times2  # accumulate per-rep timings; `times` was never filled, so the summary below was NaN
        print('time: ', np.mean(times), np.std(times))
        print(np.mean(opt_res), np.std(opt_res))
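
# --- Supplementary sketch (not from the original repository) ----------------
# A self-contained toy version of the training step above: a small network is
# fit by minimizing the V-statistic MMR objective d^T K d / n^2 from my_loss,
# with an RBF Gram matrix over the instruments and a median-heuristic
# bandwidth. The synthetic data and the tiny network are illustrative only.
import torch
import torch.nn as nn

torch.manual_seed(0)

# synthetic data: instrument z, input x correlated with z, outcome y
n = 256
z = torch.randn(n, 1)
x = z + 0.3 * torch.randn(n, 1)
y = 2.0 * x + 0.1 * torch.randn(n, 1)

# RBF Gram matrix over instruments, bandwidth from the median heuristic
with torch.no_grad():
    d2 = torch.cdist(z, z) ** 2
    bw2 = torch.median(d2[d2 > 0])
    K = torch.exp(-d2 / (2 * bw2))

net = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for step in range(200):
    optimizer.zero_grad()
    d = net(x) - y                          # residuals h(x) - y
    loss = (d.T @ K @ d)[0, 0] / n ** 2     # V-statistic MMR loss, as in my_loss
    loss.backward()
    optimizer.step()
print('final MMR loss:', float(loss))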