Example #1
0
def experiment(nystr=True, IV=True):
    """Fit a two-lengthscale RBF kernel model to the VitD data and plot it.

    Loads ``VitD.RData`` through ``robjects``, saves a scatter plot of the raw
    data, standardises every column, and then tunes the kernel
    hyper-parameters by minimising ``LMO_err`` with L-BFGS-B.  ``callback0``
    is passed to the optimiser as its callback; it tracks the fitted model in
    module globals and redraws the contour plot.

    Relies on names that are not defined in this function and must exist at
    module level: ``np`` (differentiated through ``value_and_grad``), ``plt``,
    ``robjects``, ``minimize``, ``ROOT_PATH``, ``JITTER``, ``seed``,
    ``nystr_M``, the helpers ``_sqdist``, ``chol_inv``,
    ``get_median_inter_mnist`` and ``nystrom_decomp``, and the globals
    ``Nfeval``, ``prev_norm``, ``opt_params`` and ``opt_test_err``.

    :param nystr: if true, use the eigen-decomposition returned by
        ``nystrom_decomp`` in place of the full matrix ``W``.
    :param IV: if true, ``W`` is built from a mixture of three RBF kernels on
        data columns 0 and 1; otherwise ``W`` is the scaled identity.  Also
        selects the contour plot file name.
    :raises Exception: ``callback0`` raises a bare ``Exception`` to stop the
        optimiser; nothing in this function catches it.
    """

    def LMO_err(params, M=2):
        """Objective for the optimiser, evaluated on random batches of size M.

        ``params`` are log-hyper-parameters: all but the last entry are
        length-scales, the last one is the kernel amplitude.
        """
        params = np.exp(params)
        al, bl = params[:-1], params[-1]
        L = bl * bl * np.exp(-L0[0] / al[0] / al[0] / 2) + bl * bl * np.exp(
            -L0[1] / al[1] / al[1] / 2) + 1e-6 * EYEN
        if nystr:
            tmp_mat = L @ eig_vec_K
            C = L - tmp_mat @ np.linalg.inv(eig_vec_K.T @ tmp_mat / N2 +
                                            inv_eig_val_K) @ tmp_mat.T / N2
            c = C @ W_nystr_Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            C = L @ LWL_inv @ L / N2
            c = C @ W @ Y * N2
        c_y = c - Y
        lmo_err = 0
        N = 0
        permutation = np.random.permutation(X.shape[0])
        for i in range(0, X.shape[0], M):
            indices = permutation[i:i + M]
            K_i = W[np.ix_(indices, indices)] * N2
            C_i = C[np.ix_(indices, indices)]
            c_y_i = c_y[indices]
            # Sized from C_i so a shorter final batch still works.
            b_y = np.linalg.inv(np.eye(C_i.shape[0]) - C_i @ K_i) @ c_y_i
            lmo_err += b_y.T @ K_i @ b_y
            N += 1
        return lmo_err[0, 0] / N / M**2

    def _padded_grid(column, num):
        """Return ``num`` evenly spaced points covering ``column`` plus 5%."""
        lo, hi = min(column), max(column)
        return np.linspace(lo - abs(lo) * 0.05, hi + abs(hi) * 0.05, num)

    def callback0(params, timer=None):
        """Optimiser callback: refit, update the globals, redraw the plot.

        Returns early (before touching any global) when ``timer`` is truthy.
        Raises ``Exception`` once ``alpha' L alpha`` is at least three times
        the smallest value seen so far.
        """
        global Nfeval, prev_norm, opt_params, opt_test_err
        if Nfeval % 1 == 0:
            params = np.exp(params)
            al, bl = params[:-1], params[-1]
            L = bl * bl * np.exp(
                -L0[0] / al[0] / al[0] / 2) + bl * bl * np.exp(
                    -L0[1] / al[1] / al[1] / 2) + 1e-6 * EYEN
            if nystr:
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    eig_vec_K.T @ L @ eig_vec_K / N2 +
                    np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
                alpha = alpha @ W_nystr @ Y * N2
            else:
                LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
                alpha = LWL_inv @ L @ W @ Y
            pred_mean = L @ alpha
            if timer:
                return
            norm = alpha.T @ L @ alpha

        Nfeval += 1
        if prev_norm is not None:
            if norm[0, 0] / prev_norm >= 3:
                if opt_params is None:
                    opt_params = params
                    opt_test_err = ((pred_mean - Y)**2).mean()
                print(True, opt_params, opt_test_err, prev_norm)
                raise Exception

        if prev_norm is None or norm[0, 0] <= prev_norm:
            prev_norm = norm[0, 0]
        opt_params = params
        opt_test_err = ((pred_mean - Y)**2).mean()
        print('params,test_err, norm:', opt_params, opt_test_err, prev_norm)

        # The model is evaluated on a grid in standardised coordinates ...
        X_mesh, Y_mesh = np.meshgrid(_padded_grid(X[:, 0], 32),
                                     _padded_grid(X[:, 1], 64))
        table = bl**2 * np.hstack([
            np.exp(-_sqdist(X_mesh[:, [i]], X[:, [0]]) / al[0]**2 / 2 -
                   _sqdist(Y_mesh[:, [i]], X[:, [1]]) / al[1]**2 / 2) @ alpha
            for i in range(X_mesh.shape[1])
        ])
        maxv = np.max(table[:])
        minv = np.min(table[:])
        fig = plt.figure()
        ax = fig.add_subplot(111)

        # ... but drawn against axes in the original (unstandardised) units.
        X0 = data0[:, [0, 2]]
        X_mesh, Y_mesh = np.meshgrid(_padded_grid(X0[:, 0], 32),
                                     _padded_grid(X0[:, 1], 64))
        cpf = ax.contourf(X_mesh, Y_mesh, (table - minv) / (maxv - minv))
        plt.colorbar(cpf, ax=ax)
        plt.xlabel('Age', fontsize=12)
        plt.ylabel('Vitamin D', fontsize=12)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        if IV:
            plt.savefig('VitD_IV.pdf', bbox_inches='tight')
        else:
            plt.savefig('VitD.pdf', bbox_inches='tight')
        plt.close('all')

    robjects.r['load'](ROOT_PATH + "/data/VitD.RData")
    data = np.array(robjects.r['VitD']).T

    # Scatter plot of the raw data, split on the sign of column 4.
    plt.figure()
    plt.scatter((data[:, 0])[data[:, 4] > 0], (data[:, 2])[data[:, 4] > 0],
                marker='s',
                s=3,
                c='r',
                label='dead')
    plt.scatter((data[:, 0])[data[:, 4] == 0], (data[:, 2])[data[:, 4] == 0],
                marker='o',
                s=1,
                c='b',
                label='alive')
    lgnd = plt.legend()
    # Matplotlib renamed ``Legend.legendHandles`` to ``legend_handles``; the
    # old attribute is missing on current releases, so try the new one first.
    handles = getattr(lgnd, 'legend_handles', None)
    if handles is None:
        handles = lgnd.legendHandles
    # Enlarge the legend markers so they are readable despite s=1 / s=3.
    handles[0].set_sizes([30])
    handles[1].set_sizes([30])
    plt.xlabel('Age', fontsize=12)
    plt.ylabel('Vitamin D', fontsize=12)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.savefig('VitD_data.pdf', bbox_inches='tight')
    plt.close('all')

    # Keep an unstandardised copy for the plot axes in callback0.
    data0 = data.copy()
    for i in range(data.shape[1]):
        data[:, i] = (data[:, i] - data[:, i].mean()) / data[:, i].std()
    Y = data[:, [4]]
    X = data[:, [0, 2]]
    Z = data[:, [0, 1]]
    EYEN = np.eye(X.shape[0])
    N2 = X.shape[0]**2
    if IV:
        ak = get_median_inter_mnist(Z)
        W0 = _sqdist(Z, None)
        W = (np.exp(-W0 / ak / ak / 2) + np.exp(-W0 / ak / ak / 200) +
             np.exp(-W0 / ak / ak * 50)) / 3 / N2
        del W0
    else:
        W = EYEN / N2
    # One squared-distance matrix per input column.
    L0 = np.array([_sqdist(X[:, [i]], None) for i in range(X.shape[1])])
    params0 = np.random.randn(3) / 10
    bounds = None  # [[0.01,10],[0.01,5]]
    if nystr:
        # Draw seed + 1 times and keep the last draw, so each value of the
        # module-level ``seed`` selects a different subset.
        for _ in range(seed + 1):
            random_indices = np.sort(
                np.random.choice(range(W.shape[0]), nystr_M, replace=False))
        eig_val_K, eig_vec_K = nystrom_decomp(W * N2, random_indices)
        inv_eig_val_K = np.diag(1 / eig_val_K / N2)
        W_nystr = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T / N2
        W_nystr_Y = W_nystr @ Y

    obj_grad = value_and_grad(lambda params: LMO_err(params))
    res = minimize(obj_grad,
                   x0=params0,
                   bounds=bounds,
                   method='L-BFGS-B',
                   jac=True,
                   options={'maxiter': 5000},
                   callback=callback0)
Example #2
0
def experiment(sname, seed, datasize, nystr=False, args=None):
    """Fit an RBF kernel model on a ``zoo`` dataset and save causal estimates.

    Loads ``main_<args.sem>.npz`` for scenario ``sname``, tunes the kernel
    length-scale and amplitude by minimising ``LMO_err`` with L-BFGS-B, then
    evaluates the fitted function on the interventions stored in
    ``do_A_<args.sem>.npz`` and writes plots, the mean absolute error and the
    selected parameters into a directory named after today's date.

    Relies on module-level names not defined here: ``np`` (differentiated
    through ``value_and_grad``), ``plt``, ``minimize``, ``ROOT_PATH``,
    ``JITTER``, ``nystr_M``, ``load_data``, ``bundle_az_aw``, ``_sqdist``,
    ``chol_inv``, ``get_median_inter_mnist``, ``nystrom_decomp``,
    ``visualise_ATEs`` and the globals ``Nfeval``, ``prev_norm``,
    ``opt_params`` and ``opt_test_err``.

    :param sname: scenario name, used as a sub-directory for data and results.
    :param seed: number of extra subset draws before the Nystrom subset is
        fixed; also part of the results file name.
    :param datasize: not used by this function.
    :param nystr: if true, use the eigen-decomposition returned by
        ``nystrom_decomp`` in place of the full matrix ``W``.
    :param args: object providing ``sem`` and ``av_kernel``; required despite
        the ``None`` default.
    """
    np.random.seed(1)
    random.seed(1)

    def LMO_err(params, M=10):
        """Objective for the optimiser; ``params`` are log-hyper-parameters."""
        # Re-seeding makes the batch permutation identical on every call.
        np.random.seed(2)
        random.seed(2)
        al, bl = np.exp(params)
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            tmp_mat = L @ eig_vec_K
            C = L - tmp_mat @ np.linalg.inv(eig_vec_K.T @ tmp_mat / N2 + inv_eig_val_K) @ tmp_mat.T / N2
            c = C @ W_nystr_Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            C = L @ LWL_inv @ L / N2
            c = C @ W @ Y * N2
        c_y = c - Y
        lmo_err = 0
        N = 0
        permutation = np.random.permutation(X.shape[0])
        for i in range(0, X.shape[0], M):
            indices = permutation[i:i + M]
            K_i = W[np.ix_(indices, indices)] * N2
            C_i = C[np.ix_(indices, indices)]
            c_y_i = c_y[indices]
            # Sized from C_i rather than M: the final batch is shorter
            # whenever the sample count is not a multiple of M.
            b_y = np.linalg.inv(np.eye(C_i.shape[0]) - C_i @ K_i) @ c_y_i
            lmo_err += b_y.T @ K_i @ b_y
            N += 1
        return lmo_err[0, 0] / N / M ** 2

    def callback0(params, timer=None):
        """Optimiser callback: refit at ``params`` and update the globals.

        ``opt_params`` is stored in natural (exponentiated) scale, which is
        the scale ``get_causal_effect`` expects.  Raises ``Exception`` once
        ``alpha' L alpha`` is at least three times the smallest value seen.
        """
        global Nfeval, prev_norm, opt_params, opt_test_err
        np.random.seed(3)
        random.seed(3)
        if Nfeval % 1 == 0:
            # The optimiser works on log-parameters (LMO_err exponentiates
            # them), so the model must be evaluated at exp(params) as well.
            params = np.exp(params)
            al, bl = params
            L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
            if nystr:
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    eig_vec_K.T @ L @ eig_vec_K / N2 + np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
                alpha = alpha @ W_nystr @ Y * N2
            else:
                LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
                alpha = LWL_inv @ L @ W @ Y
            test_L = bl * bl * np.exp(-test_L0 / al / al / 2)
            pred_mean = test_L @ alpha
            if timer:
                return
            test_err = ((pred_mean - test_Y) ** 2).mean()
            norm = alpha.T @ L @ alpha

        Nfeval += 1
        if prev_norm is not None:
            if norm[0, 0] / prev_norm >= 3:
                if opt_params is None:
                    opt_test_err = test_err
                    opt_params = params
                print(True, opt_params, opt_test_err, prev_norm)
                raise Exception('early stop: norm increased by a factor of >= 3')

        if prev_norm is None or norm[0, 0] <= prev_norm:
            prev_norm = norm[0, 0]
        opt_test_err = test_err
        opt_params = params
        print('params,test_err, norm: ', opt_params, opt_test_err, prev_norm)

    def get_causal_effect(params, do_A, w):
        """Average the fitted function over ``w`` for each value in ``do_A``.

        To be called within the experiment function.  ``params`` is
        ``(al, bl)`` in natural scale.  Returns the per-intervention means
        concatenated along axis 0.
        """
        np.random.seed(4)
        random.seed(4)
        al, bl = params
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            alpha = EYEN - eig_vec_K @ np.linalg.inv(
                eig_vec_K.T @ L @ eig_vec_K / N2 + np.diag(1 / eig_val_K / N2)) @ eig_vec_K.T @ L / N2
            alpha = alpha @ W_nystr @ Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            alpha = LWL_inv @ L @ W @ Y

        w = w.reshape(-1, 1)
        EYhat_do_A = []
        for a in do_A:
            a = np.repeat(a, [w.shape[0]]).reshape(-1, 1)
            aw = np.concatenate([a, w], axis=-1)
            ate_L0 = _sqdist(aw, X)
            ate_L = bl * bl * np.exp(-ate_L0 / al / al / 2)
            h_out = ate_L @ alpha

            mean_h = np.mean(h_out).reshape(-1, 1)
            EYhat_do_A.append(mean_h)
            print('a = {}, beta_a = {}'.format(np.mean(a), mean_h))

        return np.concatenate(EYhat_do_A)

    t1 = time.time()
    train, dev, test = load_data(ROOT_PATH + "/data/zoo/" + sname + '/main_{}.npz'.format(args.sem))
    t2 = time.time()
    print('t2 - t1 = ', t2 - t1)
    Y = np.concatenate((train.y, dev.y), axis=0).reshape(-1, 1)
    AZ_train, AW_train = bundle_az_aw(train.a, train.z, train.w)
    _, AW_test = bundle_az_aw(test.a, test.z, test.w)
    # The dev split must be bundled with its own w; the previous code passed
    # test.w here, mixing test inputs into the fitting data.
    AZ_dev, AW_dev = bundle_az_aw(dev.a, dev.z, dev.w)

    X, Z = np.concatenate((AW_train, AW_dev), axis=0), np.concatenate((AZ_train, AZ_dev), axis=0)
    test_X, test_Y = AW_test, test.y.reshape(-1, 1)  # TODO: is test.g just test.y?

    t3 = time.time()
    print('t3 - t2', t3-t2)
    EYEN = np.eye(X.shape[0])
    ak = get_median_inter_mnist(Z)
    N2 = X.shape[0] ** 2
    W0 = _sqdist(Z, None)
    print('av kernel indicator: ', args.av_kernel)
    # Single RBF kernel, or the average of three bandwidths when av_kernel.
    W = np.exp(-W0 / ak / ak / 2) / N2 if not args.av_kernel \
        else (np.exp(-W0 / ak / ak / 2) + np.exp(-W0 / ak / ak / 200) + np.exp(-W0 / ak / ak * 50)) / 3 / N2

    del W0
    L0, test_L0 = _sqdist(X, None), _sqdist(test_X, X)
    t4 = time.time()
    print('t4 - t3', t4-t3)

    params0 = np.array([1., 1.])
    print('starting param: ', params0)
    bounds = None  # [[0.01,10],[0.01,5]]
    if nystr:
        # Draw seed + 1 times and keep the last draw, so each seed selects a
        # different subset.
        for _ in range(seed + 1):
            random_indices = np.sort(np.random.choice(range(W.shape[0]), nystr_M, replace=False))
        eig_val_K, eig_vec_K = nystrom_decomp(W * N2, random_indices)
        inv_eig_val_K = np.diag(1 / eig_val_K / N2)
        W_nystr = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T / N2
        W_nystr_Y = W_nystr @ Y

    t5 = time.time()
    print('t5 - t4', t5-t4)
    obj_grad = value_and_grad(lambda params: LMO_err(params))
    # callback0 stops the optimiser by raising, so the broad catch is the
    # normal exit path.  NOTE(review): it also hides genuine failures; the
    # assert below is the only safeguard.
    try:
        res = minimize(obj_grad, x0=params0, bounds=bounds, method='L-BFGS-B', jac=True, options={'maxiter': 5000},
                       callback=callback0)
    except Exception as e:
        print(e)

    PATH = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    # Resolve the dated directory once (a run spanning midnight would
    # otherwise scatter its files) and create missing parents too.
    out_dir = os.path.join(PATH, str(date.today()))
    os.makedirs(out_dir, exist_ok=True)

    assert opt_params is not None
    params = opt_params
    with np.load(ROOT_PATH + "/data/zoo/" + sname + '/do_A_{}.npz'.format(args.sem)) as do_A_file:
        do_A = do_A_file['do_A']
        EY_do_A_gt = do_A_file['gt_EY_do_A']
    w_sample = train.w
    EYhat_do_A = get_causal_effect(params=params, do_A=do_A, w=w_sample)
    plt.figure()
    # One x position per intervention instead of a hard-coded 20.
    plt.plot(list(range(1, len(EYhat_do_A) + 1)), EYhat_do_A)
    plt.xlabel('A')
    plt.ylabel('EYdoA-est')
    plt.savefig(
        os.path.join(out_dir, 'causal_effect_estimates_nystr_{}'.format(AW_train.shape[0]) + '.png'))
    plt.close()
    print('ground truth ate: ', EY_do_A_gt)
    visualise_ATEs(EY_do_A_gt, EYhat_do_A,
                   x_name='E[Y|do(A)] - gt',
                   y_name='beta_A',
                   save_loc=out_dir + '/',
                   save_name='ate_{}_nystr.png'.format(AW_train.shape[0]))
    causal_effect_mean_abs_err = np.mean(np.abs(EY_do_A_gt - EYhat_do_A))
    mae_path = os.path.join(out_dir, "ate_mae_{}_nystrom.txt".format(AW_train.shape[0]))
    with open(mae_path, "a") as causal_effect_mae_file:
        causal_effect_mae_file.write("mae_: {}\n".format(causal_effect_mean_abs_err))

    # The three results have different shapes; NumPy >= 1.24 refuses to build
    # an array from such a ragged list, so fill an object array explicitly.
    summary = np.empty(3, dtype=object)
    summary[0] = opt_params
    summary[1] = prev_norm
    summary[2] = opt_test_err
    np.save(os.path.join(out_dir, 'LMO_errs_{}_nystr_{}.npy'.format(seed, AW_train.shape[0])), summary)
Example #3
0
def run_experiment_nn(sname, datasize, indices=(), seed=527, training=True):
    """Train, or summarise saved runs of, a network fitted with a kernel loss.

    In training mode the network is fitted ten times for the learning rate and
    weight decay selected by ``indices`` and each run is saved under
    ``<ROOT_PATH>/MMR_IVs/results/<sname>/``.  Otherwise the saved runs for
    every (learning rate, weight decay) pair are loaded and, per repetition,
    the test error of the run with the smallest positive dev error is
    reported.

    Relies on module-level names not defined here: ``ROOT_PATH``,
    ``load_data``, ``Kernel``, ``get_median_inter_mnist``, ``FCNN``, ``CNN``,
    ``EarlyStopping``, ``optim``, ``torch``, ``np``, ``os`` and ``time``.

    :param sname: scenario name; selects data and results directories.
    :param datasize: training-set size used in the file names that test mode
        looks for.
    :param indices: ``(lr_id, dw_id)`` or ``(lr_id, dw_id, W_id)``; required
        when ``training`` is true.
    :param seed: seed for torch and numpy.
    :param training: ``True`` to train, anything else to summarise.
    :raises ValueError: if ``training`` is true and ``indices`` does not have
        two or three entries.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if len(indices) == 2:
        lr_id, dw_id = indices
    elif len(indices) == 3:
        lr_id, dw_id, W_id = indices
    elif training is True:
        # Fail here rather than with an unbound-name error mid-training.
        raise ValueError(
            'indices must hold (lr_id, dw_id) or (lr_id, dw_id, W_id) '
            'when training, got {!r}'.format(indices))
    # load data
    folder = ROOT_PATH + "/MMR_IVs/results/" + sname + "/"
    os.makedirs(folder, exist_ok=True)

    train, dev, test = load_data(ROOT_PATH + "/data/" + sname + '/main.npz',
                                 Torch=True)
    X = torch.cat((train.x, dev.x), dim=0).float()
    Z = torch.cat((train.z, dev.z), dim=0).float()
    Y = torch.cat((train.y, dev.y), dim=0).float()
    test_X, test_G = test.x.float(), test.g.float()
    n_train = train.x.shape[0]
    # training settings
    n_epochs = 1000
    batch_size = 1000 if train.x.shape[0] > 1000 else train.x.shape[0]

    # kernel
    kernel = Kernel('rbf', Torch=True)
    if Z.shape[1] < 5:
        a = get_median_inter_mnist(train.z)
    else:
        # Wide instruments: read the precomputed bandwidth from disk.
        a = np.load(ROOT_PATH + '/mnist_precomp/{}_ak.npy'.format(sname))
    a = torch.tensor(a).float()
    # hyper-parameter grids indexed by lr_id / dw_id
    lrs = [2e-4, 1e-4, 5e-5]
    decay_weights = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6]

    def my_loss(output, target, indices, K):
        """Return ``d' W d / n**2`` for the residual ``d = output - target``.

        ``W`` is ``K`` itself when ``indices`` is None, otherwise the
        sub-matrix of ``K`` at the given batch indices.
        """
        d = output - target
        if indices is None:
            W = K
        else:
            W = K[indices[:, None], indices]
        loss = d.T @ W @ d / (d.shape[0])**2
        return loss[0, 0]

    def fit(x,
            y,
            z,
            dev_x,
            dev_y,
            dev_z,
            a,
            lr,
            decay_weight,
            n_epochs=n_epochs):
        """Train one network with Adam; return ``(es.best, epoch, net)``.

        From epoch 50 on, every fifth epoch evaluates the dev loss and feeds
        it to the early stopper.
        """

        def _mixed_kernel(points):
            """Average of the kernel at bandwidths a, a / 10 and a * 10."""
            return (kernel(points, None, a, 1) +
                    kernel(points, None, a / 10, 1) +
                    kernel(points, None, a * 10, 1)) / 3

        def _mixed_from_sqdist(sq):
            """Same three-term mixture, from precomputed values on disk."""
            # NOTE(review): ``sq`` comes straight from np.load; confirm the
            # array/tensor arithmetic yields what torch.exp expects.
            return (torch.exp(-sq / a**2 / 2) + torch.exp(-sq / a**2 * 50) +
                    torch.exp(-sq / a**2 / 200)) / 3

        is_mnist = 'mnist' in sname
        # mnist scenarios start from an identity weighting; the kernel
        # weighting is switched in at epoch 50 below.
        train_K = torch.eye(x.shape[0]) if is_mnist else _mixed_kernel(z)
        if dev_z is not None:
            # The placeholder must match the dev set, not the training set.
            dev_K = torch.eye(
                dev_z.shape[0]) if is_mnist else _mixed_kernel(dev_z)
        n_data = x.shape[0]
        net = FCNN(x.shape[1]) if sname not in ['mnist_x', 'mnist_xz'
                                                ] else CNN()
        es = EarlyStopping(patience=5)  # 10 for small
        optimizer = optim.Adam(list(net.parameters()),
                               lr=lr,
                               weight_decay=decay_weight)

        for epoch in range(n_epochs):
            permutation = torch.randperm(n_data)

            for i in range(0, n_data, batch_size):
                indices = permutation[i:i + batch_size]
                batch_x, batch_y = x[indices], y[indices]

                def closure():
                    optimizer.zero_grad()
                    pred_y = net(batch_x)
                    loss = my_loss(pred_y, batch_y, indices, train_K)
                    loss.backward()
                    return loss

                optimizer.step(closure)  # Does the update
            if epoch % 5 == 0 and epoch >= 50 and dev_x is not None:  # 5, 10 for small # 5,50 for large
                g_pred = net(test_X)
                test_err = ((g_pred - test_G)**2).mean()
                if epoch == 50 and is_mnist:
                    if z.shape[1] > 100:
                        train_K = _mixed_from_sqdist(
                            np.load(ROOT_PATH +
                                    '/mnist_precomp/{}_train_K0.npy'.format(
                                        sname)))
                        dev_K = _mixed_from_sqdist(
                            np.load(ROOT_PATH +
                                    '/mnist_precomp/{}_dev_K0.npy'.format(
                                        sname)))
                    else:
                        train_K = _mixed_kernel(z)
                        dev_K = _mixed_kernel(dev_z)

                dev_err = my_loss(net(dev_x), dev_y, None, dev_K)
                print('test', test_err, 'dev', dev_err)
                if es.step(dev_err):
                    break
        return es.best, epoch, net

    if training is True:
        print('training')
        for rep in range(10):
            save_path = os.path.join(
                folder,
                'mmr_iv_nn_{}_{}_{}_{}.npz'.format(rep, lr_id, dw_id,
                                                   train.x.shape[0]))
            lr, dw = lrs[lr_id], decay_weights[dw_id]
            print('lr, dw', lr, dw)
            t0 = time.time()
            # X/Y/Z are train followed by dev, so n_train splits them again.
            err, _, net = fit(X[:n_train], Y[:n_train], Z[:n_train],
                              X[n_train:], Y[n_train:], Z[n_train:], a, lr, dw)
            t1 = time.time() - t0
            np.save(
                folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(
                    rep, lr_id, dw_id, train.x.shape[0]), t1)
            g_pred = net(test_X).detach().numpy()
            test_err = ((g_pred - test_G.numpy())**2).mean()
            np.savez(save_path,
                     err=err.detach().numpy(),
                     lr=lr,
                     dw=dw,
                     g_pred=g_pred,
                     test_err=test_err)
    else:
        print('test')
        opt_res = []
        times = []
        for rep in range(10):
            res_list = []
            other_list = []
            times2 = []
            for lr_id in range(len(lrs)):
                for dw_id in range(len(decay_weights)):
                    load_path = os.path.join(
                        folder, 'mmr_iv_nn_{}_{}_{}_{}.npz'.format(
                            rep, lr_id, dw_id, datasize))
                    if os.path.exists(load_path):
                        res = np.load(load_path)
                        res_list.append(res['err'].astype(float))
                        other_list.append([
                            res['lr'].astype(float), res['dw'].astype(float),
                            res['test_err'].astype(float)
                        ])
                    time_path = folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(
                        rep, lr_id, dw_id, datasize)
                    if os.path.exists(time_path):
                        times2.append(np.load(time_path))
            res_list = np.array(res_list)
            other_list = np.array(other_list)
            # Only runs with a strictly positive dev error are candidates.
            valid = res_list > 0
            other_list = other_list[valid]
            res_list = res_list[valid]
            if res_list.size == 0:
                # Nothing usable for this repetition; skip it instead of
                # failing on an empty argsort.
                print(rep, '--', 'no valid results found')
                continue
            optim_id = np.argsort(res_list)[0]
            print(rep, '--', other_list[optim_id], np.min(res_list))
            opt_res.append(other_list[optim_id][-1])
        # NOTE(review): ``times`` is never filled (times2 is collected but
        # not used), so this line currently reports nan.
        print('time: ', np.mean(times), np.std(times))
        print(np.mean(opt_res), np.std(opt_res))
Example #4
0
def run_experiment_nn(sname,
                      datasize,
                      indices=[],
                      seed=527,
                      training=True,
                      args=None):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if len(indices) == 2:
        lr_id, dw_id = indices
    elif len(indices) == 3:
        lr_id, dw_id, W_id = indices
    # load data
    folder = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    os.makedirs(folder, exist_ok=True)

    train, dev, test = load_data(ROOT_PATH + "/data/zoo/" + sname +
                                 '/main_{}.npz'.format(args.sem),
                                 Torch=True)
    Y = torch.cat((train.y, dev.y), dim=0).float()
    AZ_train, AW_train = bundle_az_aw(train.a, train.z, train.w, Torch=True)
    AZ_test, AW_test = bundle_az_aw(test.a, test.z, test.w, Torch=True)
    AZ_dev, AW_dev = bundle_az_aw(dev.a, dev.z, test.w, Torch=True)

    X, Z = torch.cat((AW_train, AW_dev), dim=0).float(), torch.cat(
        (AZ_train, AZ_dev), dim=0).float()
    test_X, test_Y = AW_test.float(), test.y.float(
    )  # TODO: is test.g just test.y?
    n_train = train.a.shape[0]
    # training settings
    n_epochs = 1000
    batch_size = 1000 if train.a.shape[0] > 1000 else train.a.shape[0]

    # load expectation eval data
    axzy = np.load(ROOT_PATH + "/data/zoo/" + sname +
                   '/cond_exp_metric_{}.npz'.format(args.sem))['axzy']
    w_samples = np.load(
        ROOT_PATH + "/data/zoo/" + sname +
        '/cond_exp_metric_{}.npz'.format(args.sem))['w_samples']
    y_samples = np.load(
        ROOT_PATH + "/data/zoo/" + sname +
        '/cond_exp_metric_{}.npz'.format(args.sem))['y_samples']
    y_axz = axzy[:, -1]
    ax = axzy[:, :2]

    # kernel
    kernel = Kernel('rbf', Torch=True)
    a = get_median_inter_mnist(AZ_train)
    a = torch.tensor(a).float()
    # training loop
    lrs = [2e-4, 1e-4, 5e-5]  # [3,5]
    decay_weights = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6]  # [11,5]

    def my_loss(output, target, indices, K):
        """Return the kernel-weighted quadratic form of the residual.

        Computes ``r' W r / n**2`` with ``r = output - target`` and ``n`` the
        number of rows of ``r``.  ``W`` is ``K`` when ``indices`` is None,
        otherwise the sub-matrix of ``K`` selected by ``indices`` on both
        axes.
        """
        residual = output - target
        weights = K if indices is None else K[indices[:, None], indices]
        n_rows = residual.shape[0]
        quad_form = residual.T @ weights @ residual
        return (quad_form / n_rows**2)[0, 0]

    def conditional_expected_loss(net, ax, w_samples, y_samples, y_axz, x_on):
        """Compare the network's average over w samples against ``y_axz``.

        Every row of ``ax`` is paired with each w sample in the matching row
        of ``w_samples``, the network is evaluated on all pairs, and its
        outputs are averaged per row.

        :param net: callable mapping a float tensor to a tensor of outputs.
        :param ax: 2-D array with fewer than three columns; only column 0 is
            kept when ``x_on`` is false.
        :param w_samples: 2-D array, one row of samples per row of ``ax``.
        :param y_samples: array with as many entries as ``w_samples``.
        :param y_axz: target value per row of ``ax``.
        :param x_on: whether to keep every column of ``ax``.
        :return: ``(mean_abs_error, mse)`` - mean absolute gap between
            ``y_axz`` and the per-row averages, and the mean squared gap
            between ``y_samples`` and the raw outputs.
        """
        inputs = ax if x_on else ax[:, 0:1]
        n_points, n_reps = w_samples.shape[0], w_samples.shape[1]
        assert len(inputs.shape) == 2
        assert inputs.shape[1] < 3
        assert inputs.shape[0] == n_points
        print('number of points: ', n_points)

        # Repeat each row once per w sample so rows and samples line up.
        tiled = np.repeat(inputs, [n_reps], axis=0)
        assert tiled.shape[0] == (n_reps * inputs.shape[0])

        w_column = w_samples.reshape(-1, 1)
        features = torch.as_tensor(
            np.concatenate([tiled, w_column], axis=-1)).float()
        outputs = net(features).detach().cpu().numpy().reshape([-1, n_reps])

        per_point_mean = outputs.mean(axis=1)
        assert per_point_mean.shape[0] == y_axz.shape[0]
        mean_abs_error = np.mean(np.abs(y_axz - per_point_mean))

        # Debugging aid: squared gap between the y samples and the outputs.
        mse = np.mean((y_samples.flatten() - outputs.flatten())**2)

        return mean_abs_error, mse

    def fit(x,
            y,
            z,
            dev_x,
            dev_y,
            dev_z,
            a,
            lr,
            decay_weight,
            ax,
            y_axz,
            w_samples,
            n_epochs=n_epochs):
        """Train one network with Adam and record evaluation curves.

        From epoch 50 on, every fifth epoch evaluates the test error, the dev
        loss and ``conditional_expected_loss``; the dev loss drives early
        stopping.  ``y_samples`` and ``batch_size`` are read from the
        enclosing scope.

        :return: ``(es.best, epoch, net, losses)`` where ``losses`` maps
            ``'test'``, ``'dev'``, ``'exp'`` and ``'mse_'`` to the lists of
            recorded values.
        """

        def _mixed_kernel(points):
            """Average of the kernel at bandwidths a, a / 10 and a * 10."""
            return (kernel(points, None, a, 1) +
                    kernel(points, None, a / 10, 1) +
                    kernel(points, None, a * 10, 1)) / 3

        def _mixed_from_sqdist(sq):
            """Same three-term mixture, from precomputed values on disk."""
            # NOTE(review): ``sq`` comes straight from np.load; confirm the
            # array/tensor arithmetic yields what torch.exp expects.
            return (torch.exp(-sq / a**2 / 2) + torch.exp(-sq / a**2 * 50) +
                    torch.exp(-sq / a**2 / 200)) / 3

        is_mnist = 'mnist' in sname
        # mnist scenarios start from an identity weighting; the kernel
        # weighting is switched in at epoch 50 below.
        train_K = torch.eye(x.shape[0]) if is_mnist else _mixed_kernel(z)
        if dev_z is not None:
            # The placeholder must match the dev set, not the training set.
            dev_K = torch.eye(
                dev_z.shape[0]) if is_mnist else _mixed_kernel(dev_z)
        n_data = x.shape[0]
        net = FCNN(x.shape[1]) if sname not in ['mnist_x', 'mnist_xz'
                                                ] else CNN()
        es = EarlyStopping(patience=10)  # 10 for small
        optimizer = optim.Adam(list(net.parameters()),
                               lr=lr,
                               weight_decay=decay_weight)

        test_errs, dev_errs, exp_errs, mse_s = [], [], [], []
        # Built once: the dict shares the lists, so it always reflects every
        # value appended below.
        losses = {
            'test': test_errs,
            'dev': dev_errs,
            'exp': exp_errs,
            'mse_': mse_s
        }

        for epoch in range(n_epochs):
            permutation = torch.randperm(n_data)

            for i in range(0, n_data, batch_size):
                indices = permutation[i:i + batch_size]
                batch_x, batch_y = x[indices], y[indices]

                def closure():
                    optimizer.zero_grad()
                    pred_y = net(batch_x)
                    loss = my_loss(pred_y, batch_y, indices, train_K)
                    loss.backward()
                    return loss

                optimizer.step(closure)  # Does the update
            if epoch % 5 == 0 and epoch >= 50 and dev_x is not None:  # 5, 10 for small # 5,50 for large
                g_pred = net(test_X)
                # Plain (unweighted) squared error against the test labels.
                test_err = ((g_pred - test_Y)**2).mean()
                if epoch == 50 and is_mnist:
                    if z.shape[1] > 100:
                        train_K = _mixed_from_sqdist(
                            np.load(ROOT_PATH +
                                    '/mnist_precomp/{}_train_K0.npy'.format(
                                        sname)))
                        dev_K = _mixed_from_sqdist(
                            np.load(ROOT_PATH +
                                    '/mnist_precomp/{}_dev_K0.npy'.format(
                                        sname)))
                    else:
                        train_K = _mixed_kernel(z)
                        dev_K = _mixed_kernel(dev_z)

                dev_err = my_loss(net(dev_x), dev_y, None, dev_K)
                err_in_expectation, mse = conditional_expected_loss(
                    net=net,
                    ax=ax,
                    w_samples=w_samples,
                    y_samples=y_samples,
                    y_axz=y_axz,
                    x_on=False)
                print('test', test_err, 'dev', dev_err, 'err_in_expectation',
                      err_in_expectation, 'mse: ', mse)
                test_errs.append(test_err)
                dev_errs.append(dev_err)
                exp_errs.append(err_in_expectation)
                mse_s.append(mse)

                if es.step(dev_err):
                    break
        return es.best, epoch, net, losses

    def get_causal_effect(net, do_A, w):
        """
        Average the network output over the w samples for each intervention.

        The network is switched to eval mode and evaluated without gradient
        tracking, so the printed tensors no longer carry a ``grad_fn``.

        :param net: FCNN object
        :param do_A: a numpy array of interventions, size = B_a
        :param w: a torch tensor of w samples, size = B_w
        :return: a numpy array of interventional parameters, one row per
            entry of do_A
        """
        net.eval()
        # Loop-invariant: reshape the samples into a float column once.
        w = w.reshape(-1, 1).float()
        EYhat_do_A = []
        with torch.no_grad():
            for a in do_A:
                a = np.repeat(a, [w.shape[0]]).reshape(-1, 1)
                a_tensor = torch.as_tensor(a).float()
                # torch.cat already returns a fresh tensor; wrapping it in
                # torch.tensor() only added a copy and a UserWarning.
                aw = torch.cat([a_tensor, w], dim=-1)
                mean_h = torch.mean(net(aw)).reshape(-1, 1)
                EYhat_do_A.append(mean_h)
                print('a = {}, beta_a = {}'.format(np.mean(a), mean_h))
        return torch.cat(EYhat_do_A).detach().cpu().numpy()

    # Driver: either train/evaluate the network for a single (lr_id, dw_id)
    # setting, or aggregate previously saved results over the whole grid.
    if training is True:
        print('training')
        for rep in range(3):
            print('*******REP: {}'.format(rep))
            save_path = os.path.join(
                folder,
                'mmr_iv_nn_{}_{}_{}_{}.npz'.format(rep, lr_id, dw_id,
                                                   AW_train.shape[0]))
            lr, dw = lrs[lr_id], decay_weights[dw_id]
            print('lr, dw', lr, dw)
            t0 = time.time()
            # First n_train rows are used for fitting, the rest as dev set.
            err, _, net, losses = fit(X[:n_train],
                                      Y[:n_train],
                                      Z[:n_train],
                                      X[n_train:],
                                      Y[n_train:],
                                      Z[n_train:],
                                      a,
                                      lr,
                                      dw,
                                      ax=ax,
                                      y_axz=y_axz,
                                      w_samples=w_samples)
            t1 = time.time() - t0
            # NOTE: the timing path is built by plain concatenation (not
            # os.path.join); the test branch below rebuilds it the same way.
            np.save(
                folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(
                    rep, lr_id, dw_id, AW_train.shape[0]), t1)
            g_pred = net(test_X).detach().numpy()
            test_err = ((g_pred - test_Y.numpy())**2).mean()
            np.savez(save_path,
                     err=err.detach().numpy(),
                     lr=lr,
                     dw=dw,
                     g_pred=g_pred,
                     test_err=test_err)

            # make loss curves
            for (name, ylabel) in [('test', 'test av ||y - h||^2'),
                                   ('dev', 'R_V'), ('exp', 'E[y-h|a,z,x]'),
                                   ('mse_', 'mse_alternative_sim')]:
                errs = losses[name]
                # NOTE(review): x-axis assumes the losses were recorded from
                # epoch 50 onwards, every 5 epochs - confirm against fit().
                stps = [50 + i * 5 for i in range(len(errs))]
                plt.figure()
                plt.plot(stps, errs)
                plt.xlabel('epoch')
                plt.ylabel(ylabel)
                plt.savefig(
                    os.path.join(
                        folder, name + '_{}_{}_{}_{}'.format(
                            rep, lr_id, dw_id, AW_train.shape[0]) + '.png'))
                plt.close()

            # do causal effect estimates
            # Open the archive once and close it again after both reads.
            do_A_path = (ROOT_PATH + "/data/zoo/" + sname +
                         '/do_A_{}.npz'.format(args.sem))
            with np.load(do_A_path) as do_A_data:
                do_A = do_A_data['do_A']
                EY_do_A_gt = do_A_data['gt_EY_do_A']
            w_sample = train.w
            EYhat_do_A = get_causal_effect(net, do_A=do_A, w=w_sample)
            plt.figure()
            # 1-based intervention index on the x-axis, sized to the estimates
            # rather than a hard-coded 20.
            plt.plot(np.arange(1, len(EYhat_do_A) + 1), EYhat_do_A)
            plt.xlabel('A')
            plt.ylabel('EYdoA-est')
            # NOTE(review): this file name carries no rep index, so each rep
            # overwrites the previous plot - confirm that this is intended.
            plt.savefig(
                os.path.join(
                    folder, 'causal_effect_estimates_{}_{}_{}'.format(
                        lr_id, dw_id, AW_train.shape[0]) + '.png'))
            plt.close()

            print('ground truth ate: ', EY_do_A_gt)
            visualise_ATEs(EY_do_A_gt,
                           EYhat_do_A,
                           x_name='E[Y|do(A)] - gt',
                           y_name='beta_A',
                           save_loc=folder,
                           save_name='ate_{}_{}_{}_{}.png'.format(
                               rep, lr_id, dw_id, AW_train.shape[0]))
            causal_effect_mean_abs_err = np.mean(
                np.abs(EY_do_A_gt - EYhat_do_A))
            # Append mode: one line per rep accumulates in the same file.
            mae_path = os.path.join(
                folder,
                "ate_mae_{}_{}_{}.txt".format(lr_id, dw_id,
                                              AW_train.shape[0]))
            with open(mae_path, "a") as causal_effect_mae_file:
                causal_effect_mae_file.write("mae_rep_{}: {}\n".format(
                    rep, causal_effect_mean_abs_err))

    else:
        print('test')
        opt_res = []
        times = []
        for rep in range(10):
            res_list = []
            other_list = []
            times2 = []
            # Collect whatever results exist for this rep across the grid.
            for lr_id in range(len(lrs)):
                for dw_id in range(len(decay_weights)):
                    load_path = os.path.join(
                        folder, 'mmr_iv_nn_{}_{}_{}_{}.npz'.format(
                            rep, lr_id, dw_id, datasize))
                    if os.path.exists(load_path):
                        res = np.load(load_path)
                        res_list += [res['err'].astype(float)]
                        other_list += [[
                            res['lr'].astype(float), res['dw'].astype(float),
                            res['test_err'].astype(float)
                        ]]
                    time_path = folder + 'mmr_iv_nn_{}_{}_{}_{}_time.npy'.format(
                        rep, lr_id, dw_id, datasize)
                    if os.path.exists(time_path):
                        t = np.load(time_path)
                        times2 += [t]
            if times2:
                # NOTE(review): a rep's time is taken as the total over the
                # hyper-parameter grid; previously `times` was never filled
                # and the summary below always printed nan. Confirm the
                # intended statistic.
                times.append(np.sum(times2))
            res_list = np.array(res_list)
            other_list = np.array(other_list)
            # Keep only runs with a strictly positive dev error.
            other_list = other_list[res_list > 0]
            res_list = res_list[res_list > 0]
            if res_list.size == 0:
                # Nothing usable saved for this rep: skip it instead of
                # failing on the argsort of an empty array.
                print(rep, '--', 'no saved results, skipping')
                continue
            optim_id = np.argsort(res_list)[0]  # np.argmin(res_list)
            print(rep, '--', other_list[optim_id], np.min(res_list))
            # Last column of other_list is the test error of the best run.
            opt_res += [other_list[optim_id][-1]]
        print('time: ', np.mean(times), np.std(times))
        print(np.mean(opt_res), np.std(opt_res))
def experiment(sname, seed, nystr=True):
    """Tune kernel hyper-parameters by minimising a leave-M-out error.

    Loads ``ROOT_PATH/data/<sname>/main.npz`` plus precomputed kernel inputs
    from ``ROOT_PATH/mnist_precomp``, optimises the log-parameters with
    L-BFGS-B and saves ``[opt_params, prev_norm, opt_test_err]`` to
    ``ROOT_PATH/MMR_IVs/results/<sname>/LMO_errs_<seed>_nystr.npy``.

    The module-level globals ``Nfeval``, ``prev_norm``, ``opt_params`` and
    ``opt_test_err`` are read and updated by the optimiser callback and must
    exist before this function is called.

    :param sname: scenario name used to build the data and result paths
    :param seed: the Nystrom index set is drawn ``seed + 1`` times and the
        last draw is kept; also part of the output file name
    :param nystr: if True the callback solves with the Nystrom
        approximation, otherwise with the Cholesky-based solve (the
        objective itself always uses the Nystrom form)
    """

    class _EarlyStop(Exception):
        """Raised by callback0 to abort the optimiser."""

    def LMO_err(params, M=2, verbal=False):
        """Leave-M-out objective evaluated at log-parameters ``params``.

        :param params: log length-scale(s) followed by the log amplitude
        :param M: batch size of the held-out groups
        :param verbal: unused, kept for call compatibility
        :return: scalar objective value
        """
        params = np.exp(params)
        al, bl = params[:-1], params[-1]
        if train.x.shape[1] < 5:
            train_L = bl**2 * np.exp(-train_L0 / al**2 / 2) + 1e-4 * EYEN
        else:
            # One squared-distance matrix per length-scale.
            train_L = 0
            for i in range(len(al)):
                train_L += train_L0[i] / al[i]**2
            train_L = bl * bl * np.exp(-train_L / 2) + 1e-4 * EYEN

        tmp_mat = train_L @ eig_vec_K
        C = train_L - tmp_mat @ np.linalg.inv(eig_vec_K.T @ tmp_mat / N2 +
                                              inv_eig_val) @ tmp_mat.T / N2
        c = C @ W_nystr_Y * N2
        c_y = c - train.y
        lmo_err = 0
        n_samples = train.x.shape[0]
        permutation = np.random.permutation(n_samples)
        for start in range(0, n_samples, M):
            indices = permutation[start:start + M]
            K_i = train_W[np.ix_(indices, indices)] * N2
            C_i = C[np.ix_(indices, indices)]
            c_y_i = c_y[indices]
            # Size the identity from the batch itself: the final batch is
            # shorter than M when n_samples is not a multiple of M.
            b_y = np.linalg.inv(np.eye(C_i.shape[0]) - C_i @ K_i) @ c_y_i
            lmo_err += b_y.T @ K_i @ b_y
        # NOTE(review): unlike the sibling implementations this value is not
        # divided by the number of batches; kept as-is so the objective scale
        # seen by the optimiser does not change.
        return lmo_err[0, 0] / M**2

    def callback0(params):
        """Per-iteration hook: track test error and solution norm.

        Updates the module-level ``opt_params``/``opt_test_err``/``prev_norm``
        and raises ``_EarlyStop`` once the norm reaches three times the
        smallest norm seen so far.
        """
        global Nfeval, prev_norm, opt_params, opt_test_err
        if Nfeval % 1 == 0:
            params = np.exp(params)
            print('params:', params)
            al, bl = params[:-1], params[-1]

            if train.x.shape[1] < 5:
                train_L = bl**2 * np.exp(-train_L0 / al**2 / 2) + 1e-4 * EYEN
                test_L = bl**2 * np.exp(-test_L0 / al**2 / 2)
            else:
                train_L, test_L = 0, 0
                for i in range(len(al)):
                    train_L += train_L0[i] / al[i]**2
                    test_L += test_L0[i] / al[i]**2
                train_L = bl * bl * np.exp(-train_L / 2) + 1e-4 * EYEN
                test_L = bl * bl * np.exp(-test_L / 2)

            if nystr:
                tmp_mat = eig_vec_K.T @ train_L
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    tmp_mat @ eig_vec_K / N2 + inv_eig_val) @ tmp_mat / N2
                alpha = alpha @ W_nystr_Y * N2
            else:
                LWL_inv = chol_inv(train_L @ train_W @ train_L + train_L / N2 +
                                   JITTER * EYEN)
                alpha = LWL_inv @ train_L @ train_W @ train.y
            pred_mean = test_L @ alpha
            test_err = ((pred_mean - test.g)**2).mean()
            norm = alpha.T @ train_L @ alpha
        Nfeval += 1
        if prev_norm is not None:
            if norm[0, 0] / prev_norm >= 3:
                if opt_test_err is None:
                    opt_test_err = test_err
                    opt_params = params
                print(True, opt_params, opt_test_err, prev_norm, norm[0, 0])
                raise _EarlyStop

        # prev_norm only ever decreases: it is the smallest norm seen so far.
        if prev_norm is None or norm[0, 0] <= prev_norm:
            prev_norm = norm[0, 0]
        opt_test_err = test_err
        opt_params = params
        print(True, opt_params, opt_test_err, prev_norm, norm[0, 0])

    train, dev, test = load_data(ROOT_PATH + '/data/' + sname + '/main.npz')
    del dev

    # avoid same indices when run on the cluster
    for _ in range(seed + 1):
        random_indices = np.sort(
            np.random.choice(range(train.x.shape[0]), nystr_M, replace=False))

    EYEN = np.eye(train.x.shape[0])
    N2 = train.x.shape[0]**2

    # precompute to save time on parallelized computation
    if train.z.shape[1] < 5:
        ak = get_median_inter_mnist(train.z)
    else:
        ak = np.load(ROOT_PATH + '/mnist_precomp/{}_ak.npy'.format(sname))
    train_W = np.load(ROOT_PATH +
                      '/mnist_precomp/{}_train_K0.npy'.format(sname))
    # Average of three RBF kernels with different bandwidths, scaled by 1/N2.
    train_W = (np.exp(-train_W / ak / ak / 2) + np.exp(
        -train_W / ak / ak / 200) + np.exp(-train_W / ak / ak * 50)) / 3 / N2
    if train.x.shape[1] < 5:
        train_L0 = _sqdist(train.x, None)
        test_L0 = _sqdist(test.x, train.x)
        params0 = np.random.randn(2) * 0.1
    else:
        L0s = np.load(ROOT_PATH + '/mnist_precomp/{}_Ls.npz'.format(sname))
        train_L0 = L0s['train_L0']
        test_L0 = L0s['test_L0']
        del L0s
        params0 = np.random.randn(len(train_L0) + 1) * 0.1
    bounds = None
    eig_val_K, eig_vec_K = nystrom_decomp(train_W * N2, random_indices)
    W_nystr_Y = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T @ train.y / N2
    inv_eig_val = np.diag(1 / eig_val_K / N2)
    obj_grad = value_and_grad(lambda params: LMO_err(params))
    try:
        minimize(obj_grad,
                 x0=params0,
                 bounds=bounds,
                 method='L-BFGS-B',
                 jac=True,
                 options={
                     'maxiter': 5000,
                     'disp': True,
                     'ftol': 0
                 },
                 callback=callback0)
    except _EarlyStop:
        # callback0 has already recorded the best parameters in the globals;
        # fall through so they are still written to disk.
        pass
    PATH = ROOT_PATH + "/MMR_IVs/results/" + sname + "/"
    os.makedirs(PATH, exist_ok=True)
    # The three entries have different shapes; an explicit object array is
    # required because NumPy >= 1.24 refuses to build ragged arrays
    # implicitly.
    np.save(PATH + 'LMO_errs_{}_nystr.npy'.format(seed),
            np.array([opt_params, prev_norm, opt_test_err], dtype=object))
Example #6
0
def experiment(sname, seed, datasize, nystr=False):
    """Tune RBF hyper-parameters by minimising a leave-M-out error.

    Loads ``ROOT_PATH/data/zoo/<sname>_<datasize>.npz``, stacks the train and
    dev splits, optimises the two log-parameters with L-BFGS-B and saves
    ``[opt_params, prev_norm, opt_test_err]`` to
    ``ROOT_PATH/MMR_IVs/results/zoo/<sname>/LMO_errs_<seed>_nystr_<n>.npy``
    where ``n`` is the number of training rows.

    The module-level globals ``Nfeval``, ``prev_norm``, ``opt_params`` and
    ``opt_test_err`` are read and updated by the optimiser callback and must
    exist before this function is called.

    :param sname: scenario name used to build the data and result paths
    :param seed: with ``nystr`` the Nystrom index set is drawn ``seed + 1``
        times and the last draw is kept; also part of the output file name
    :param datasize: data size tag used in the input file name
    :param nystr: use the Nystrom approximation instead of the
        Cholesky-based solve
    """

    def LMO_err(params, M=2):
        """Leave-M-out objective at log-parameters ``params`` (batch size M)."""
        al, bl = np.exp(params)
        L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
        if nystr:
            tmp_mat = L @ eig_vec_K
            C = L - tmp_mat @ np.linalg.inv(eig_vec_K.T @ tmp_mat / N2 +
                                            inv_eig_val_K) @ tmp_mat.T / N2
            c = C @ W_nystr_Y * N2
        else:
            LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
            C = L @ LWL_inv @ L / N2
            c = C @ W @ Y * N2
        c_y = c - Y
        lmo_err = 0
        N = 0
        n_samples = X.shape[0]
        permutation = np.random.permutation(n_samples)
        for start in range(0, n_samples, M):
            indices = permutation[start:start + M]
            K_i = W[np.ix_(indices, indices)] * N2
            C_i = C[np.ix_(indices, indices)]
            c_y_i = c_y[indices]
            # Size the identity from the batch itself: the final batch is
            # shorter than M when n_samples is not a multiple of M.
            b_y = np.linalg.inv(np.eye(C_i.shape[0]) - C_i @ K_i) @ c_y_i
            lmo_err += b_y.T @ K_i @ b_y
            N += 1
        return lmo_err[0, 0] / N / M**2

    def callback0(params, timer=None):
        """Per-iteration hook: track test error and solution norm.

        Updates the module-level ``opt_params``/``opt_test_err``/``prev_norm``
        and raises ``Exception`` once the norm reaches three times the
        smallest norm seen so far.

        :param params: current log-parameters
        :param timer: if truthy, return right after the prediction is
            computed, without touching any global state
        """
        global Nfeval, prev_norm, opt_params, opt_test_err
        if Nfeval % 1 == 0:
            al, bl = np.exp(params)
            L = bl * bl * np.exp(-L0 / al / al / 2) + 1e-6 * EYEN
            if nystr:
                # inv_eig_val_K is the precomputed diag(1 / eig_val_K / N2).
                alpha = EYEN - eig_vec_K @ np.linalg.inv(
                    eig_vec_K.T @ L @ eig_vec_K / N2 +
                    inv_eig_val_K) @ eig_vec_K.T @ L / N2
                alpha = alpha @ W_nystr @ Y * N2
            else:
                LWL_inv = chol_inv(L @ W @ L + L / N2 + JITTER * EYEN)
                alpha = LWL_inv @ L @ W @ Y
            test_L = bl * bl * np.exp(-test_L0 / al / al / 2)
            pred_mean = test_L @ alpha
            if timer:
                return
            test_err = ((pred_mean - test_G)**2).mean()
            norm = alpha.T @ L @ alpha

        Nfeval += 1
        if prev_norm is not None:
            if norm[0, 0] / prev_norm >= 3:
                if opt_params is None:
                    opt_test_err = test_err
                    opt_params = params
                print(True, opt_params, opt_test_err, prev_norm)
                raise Exception

        # prev_norm only ever decreases: it is the smallest norm seen so far.
        if prev_norm is None or norm[0, 0] <= prev_norm:
            prev_norm = norm[0, 0]
        opt_test_err = test_err
        opt_params = params
        print('params,test_err, norm: ', opt_params, opt_test_err, prev_norm)

    train, dev, test = load_data(ROOT_PATH +
                                 '/data/zoo/{}_{}.npz'.format(sname, datasize))

    # Train and dev splits are pooled for hyper-parameter selection.
    X = np.vstack((train.x, dev.x))
    Y = np.vstack((train.y, dev.y))
    Z = np.vstack((train.z, dev.z))
    test_X = test.x
    test_G = test.g

    EYEN = np.eye(X.shape[0])
    ak = get_median_inter_mnist(Z)
    N2 = X.shape[0]**2
    W0 = _sqdist(Z, None)
    # Average of three RBF kernels with different bandwidths, scaled by 1/N2.
    W = (np.exp(-W0 / ak / ak / 2) + np.exp(-W0 / ak / ak / 200) +
         np.exp(-W0 / ak / ak * 50)) / 3 / N2
    del W0
    L0, test_L0 = _sqdist(X, None), _sqdist(test_X, X)

    params0 = np.random.randn(2) / 10
    bounds = None
    if nystr:
        # Repeated draws so that different seeds end up with different
        # index sets; only the last draw is used.
        for _ in range(seed + 1):
            random_indices = np.sort(
                np.random.choice(range(W.shape[0]), nystr_M, replace=False))
        eig_val_K, eig_vec_K = nystrom_decomp(W * N2, random_indices)
        inv_eig_val_K = np.diag(1 / eig_val_K / N2)
        W_nystr = eig_vec_K @ np.diag(eig_val_K) @ eig_vec_K.T / N2
        W_nystr_Y = W_nystr @ Y

    obj_grad = value_and_grad(lambda params: LMO_err(params))
    try:
        minimize(obj_grad,
                 x0=params0,
                 bounds=bounds,
                 method='L-BFGS-B',
                 jac=True,
                 options={'maxiter': 5000},
                 callback=callback0)
    except Exception as e:
        # Deliberate best effort: callback0 stops the optimiser by raising,
        # and whatever was recorded so far is still saved below.
        print(e)
    PATH = ROOT_PATH + "/MMR_IVs/results/zoo/" + sname + "/"
    os.makedirs(PATH, exist_ok=True)
    # The three entries have different shapes; an explicit object array is
    # required because NumPy >= 1.24 refuses to build ragged arrays
    # implicitly.
    np.save(PATH + 'LMO_errs_{}_nystr_{}.npy'.format(seed, train.x.shape[0]),
            np.array([opt_params, prev_norm, opt_test_err], dtype=object))