Example #1
    def _test_u_distance_correlation_vector_generic(self,
                                                    vector_type=None,
                                                    type_cov=None,
                                                    type_cor=None):
        """
        Auxiliary function for testing U-distance correlation on vectors.

        This function is provided to check that the results are the
        same with different dtypes, and that the dtype of the result is
        the right one.
        """
        if type_cov is None:
            type_cov = vector_type
        if type_cor is None:
            type_cor = vector_type

        arr1 = np.array([
            vector_type(1),
            vector_type(2),
            vector_type(3),
            vector_type(4),
            vector_type(5),
            vector_type(6)
        ])
        arr2 = np.array([
            vector_type(1),
            vector_type(7),
            vector_type(5),
            vector_type(5),
            vector_type(6),
            vector_type(2)
        ])

        covariance = dcor.u_distance_covariance_sqr(arr1, arr2)
        self.assertIsInstance(covariance, type_cov)
        self.assertAlmostEqual(covariance, type_cov(-0.88889), places=5)

        correlation = dcor.u_distance_correlation_sqr(arr1, arr2)
        self.assertIsInstance(correlation, type_cor)
        self.assertAlmostEqual(correlation, type_cor(-0.41613), places=5)

        covariance = dcor.u_distance_covariance_sqr(arr1, arr1)
        self.assertIsInstance(covariance, type_cov)
        self.assertAlmostEqual(covariance, type_cov(1.5556), places=4)

        correlation = dcor.u_distance_correlation_sqr(arr1, arr1)
        self.assertIsInstance(correlation, type_cor)
        self.assertAlmostEqual(correlation, type_cor(1), places=5)
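
A minimal sketch of how a dtype-specific test might drive the generic helper above; the test names and dtype choices here are illustrative, not taken from the original suite:

    def test_u_distance_correlation_vector_float32(self):
        # Results are expected to keep the floating input dtype.
        self._test_u_distance_correlation_vector_generic(
            vector_type=np.float32)

    def test_u_distance_correlation_vector_int(self):
        # Integer inputs are assumed to promote to a floating-point result.
        self._test_u_distance_correlation_vector_generic(
            vector_type=int,
            type_cov=float,
            type_cor=float)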
Example #2
import numpy as np
import matplotlib.pyplot as plt
from nilearn import plotting

import dcor


def compute_correlation_matrix_from_hypergraph(hypergraph, time_series, delay=0, savefigure=None):
    correlation_matrix = np.zeros(hypergraph.shape)

    k_circular = [0] if delay == 0 else range(-delay, delay + 1)

    hypergraph = hypergraph.astype(bool)

    for i in range(hypergraph.shape[0]):
        print(i)
        for j in range(i + 1, hypergraph.shape[0], 1):
            m = []
            for lag in k_circular:
                time_serie_lagged = np.roll(time_series[:, hypergraph[j, :]], lag)
                m.append(dcor.u_distance_correlation_sqr(time_series[:, hypergraph[i, :]], time_serie_lagged))

            correlation_matrix[i, j] = max(m)
            correlation_matrix[j, i] = correlation_matrix[i, j]

    if savefigure is not None:
        figure = plt.figure(figsize=(6, 6))
        plotting.plot_matrix(correlation_matrix, figure=figure, vmax=1., vmin=0.)
        figure.savefig(savefigure, dpi=200)
        plt.close(figure)

    return correlation_matrix
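
A minimal usage sketch for the function above, assuming a small synthetic hypergraph (each row marks the regions belonging to one hyperedge) and a random time-series matrix of shape (time points, regions); all values are illustrative:

import numpy as np

rng = np.random.default_rng(0)
time_series = rng.standard_normal((100, 4))   # 100 time points, 4 regions
hypergraph = np.array([[1, 1, 0, 0],
                       [0, 1, 1, 0],
                       [0, 0, 1, 1],
                       [1, 0, 0, 1]])          # 4 hyperedges over 4 regions
corr = compute_correlation_matrix_from_hypergraph(hypergraph, time_series, delay=2)
print(corr.shape)   # (4, 4)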
Example #3
    def _compute_correlation(self, metrics_vals_1: pd.Series,
                             metrics_vals_2: pd.Series, lag: int):

        return dcor.u_distance_correlation_sqr(
            metrics_vals_1.astype(float),
            metrics_vals_2.shift(lag).fillna(0).astype(float),
            exponent=self.exponent)
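
The same lag-and-correlate pattern as the method above, shown standalone on illustrative series (the exponent argument is the optional dcor distance exponent that the method takes from self.exponent):

import numpy as np
import pandas as pd

import dcor

s1 = pd.Series(np.sin(np.linspace(0, 10, 200)))
s2 = pd.Series(np.cos(np.linspace(0, 10, 200)))

print(dcor.u_distance_correlation_sqr(
    s1.astype(float),
    s2.shift(5).fillna(0).astype(float),
    exponent=1))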
Example #4
    def test_u_statistic(self):
        """Test that the fast and naive algorithms for unbiased dcor match"""
        for seed in range(5):

            random_state = np.random.RandomState(seed)

            for i in range(4, self.test_max_size + 1):
                arr1 = random_state.rand(i, 1)
                arr2 = random_state.rand(i, 1)

                u_stat = dcor.u_distance_correlation_sqr(
                    arr1, arr2, method='naive')
                for method in dcor.DistanceCovarianceMethod:
                    with self.subTest(method=method):
                        u_stat2 = dcor.u_distance_correlation_sqr(
                            arr1, arr2, method=method)

                        self.assertAlmostEqual(u_stat, u_stat2)
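
For reference, the methods that the loop above iterates over can be listed directly from the enum; a quick sketch:

import dcor

for method in dcor.DistanceCovarianceMethod:
    print(method)   # includes the naive method used explicitly above and the fast alternatives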
Example #5
    def get_dcorr(self, is_train=True):
        if is_train:
            settype = "train"
            loader = self.trainloader
        else:
            settype = "val"
            loader = self.valloader

        # Get alphas
        state = self._get_scalars_per_tau(loader, settype)
        preds = np.array(state["preds"])
        ntaus = preds.shape[0]

        metas = state["metas"]
        X = preds.T
        Y = metas["age"].reshape(-1, 1)

        # print("Dcorr {}: {}".format(settype, utils.dcor.dcor(X, Y)))
        u_dcor_sqr = dcor.u_distance_correlation_sqr(X, Y)
        u_dcor = math.sqrt(u_dcor_sqr)
        # print(dcor.u_distance_correlation_sqr(X, Y))
        dcor_scores = np.array([u_dcor_sqr, u_dcor])
        scores_dir = os.path.join(self.fig_dir, "scores", "dcorr")
        utils.fs.create_dir(scores_dir)
        score_fp = os.path.join(
            scores_dir, "{}_{}_{}_dcorr".format(self.agent.config.dataset,
                                                settype, self.agent.k))
        np.savetxt(score_fp, dcor_scores, delimiter=",")

        if self.agent.k == 4:
            u_dcor_scores = []
            u_dcor_sqr_scores = []
            for i in range(5):
                fold_score_fp = os.path.join(
                    scores_dir,
                    "{}_{}_{}_dcorr".format(self.agent.config.dataset, settype,
                                            i))
                scores = np.loadtxt(fold_score_fp, delimiter=",")
                u_dcor_sqr_scores.append(scores[0])
                u_dcor_scores.append(scores[1])
            u_dcor_scores = np.array(u_dcor_scores)
            u_dcor_sqr_scores = np.array(u_dcor_sqr_scores)
            mean_dcorr = u_dcor_scores.mean()
            std_dcorr = u_dcor_scores.std()
            mean_dcorr_sqr = u_dcor_sqr_scores.mean()
            std_dcorr_sqr = u_dcor_sqr_scores.std()
            u_dcor_scores = np.append(u_dcor_scores, [mean_dcorr, std_dcorr])
            u_dcor_sqr_scores = np.append(u_dcor_sqr_scores,
                                          [mean_dcorr_sqr, std_dcorr_sqr])
            # combine to one array for saving
            stacked_scores = np.stack((u_dcor_scores, u_dcor_sqr_scores))
            summary_score_fp = os.path.join(
                scores_dir, "{}_{}_{}_dcorr".format(self.agent.config.dataset,
                                                    settype, "summary"))
            np.savetxt(summary_score_fp, stacked_scores, delimiter=",")
Example #6
def compute_functional_connectivity(time_courses, metric='pearson'):
    import numpy as np
    from dcor import u_distance_correlation_sqr
    from scipy.stats import pearsonr
    fc_matrix = np.zeros((len(time_courses), len(time_courses)))

    for i in range(len(time_courses)):
        for j in range(i, len(time_courses)):
            if metric == 'pearson':
                # pearsonr returns (statistic, p-value); keep only the statistic
                fc_matrix[i, j] = pearsonr(time_courses[i], time_courses[j])[0]
            elif metric == 'distance':
                fc_matrix[i, j] = u_distance_correlation_sqr(
                    time_courses[i], time_courses[j])

    return fc_matrix
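
A minimal usage sketch for the helper above (synthetic time courses; only the upper triangle is filled, so the result is symmetrised explicitly):

import numpy as np

rng = np.random.default_rng(0)
time_courses = [rng.standard_normal(200) for _ in range(5)]
fc = compute_functional_connectivity(time_courses, metric='distance')
fc = fc + fc.T - np.diag(np.diag(fc))   # mirror the upper triangle onto the lower one
print(fc.shape)   # (5, 5)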
Example #7
    def _run_interface(self, runtime):
        import numpy as np
        from nilearn import plotting
        import matplotlib.pyplot as plt
        import dcor
        from datetime import datetime
        from utils.utils import absmax

        hypergraph = np.loadtxt(self.inputs.hypergraph_path, delimiter=',')
        time_series = np.loadtxt(self.inputs.time_series_path, delimiter=',')

        k_circular = [0] if self.inputs.lag == 0 else range(-1 * self.inputs.lag, self.inputs.lag + 1, 1)

        correlation_matrix = np.zeros(hypergraph.shape)

        threshold = 0.3

        hypergraph[np.where(hypergraph > threshold)] = 1
        hypergraph[np.where(hypergraph != 1)] = 0
        hypergraph = hypergraph.astype(bool)

        then = datetime.now()

        for i in range(hypergraph.shape[0]):
            print(i)
            for j in range(i + 1, hypergraph.shape[0], 1):
                correlation_values_from_laggeds = []
                for lag in k_circular:
                    correlation_values_from_laggeds.append(
                        dcor.u_distance_correlation_sqr(time_series[:, hypergraph[i, :]],
                                                        np.roll(time_series[:, hypergraph[j, :]], lag)))

                    # correlation_matrix[i, j] = dcor.u_distance_correlation_sqr(time_series[:, hypergraph[i, :]],
                    #                                                           time_series[:, hypergraph[j, :]])
                correlation_matrix[i, j] = np.max(correlation_values_from_laggeds)
                # correlation_matrix[j, i] = correlation_matrix[i, j]

        figure = plt.figure(figsize=(6, 6))
        plotting.plot_matrix(correlation_matrix, figure=figure, vmax=1., vmin=0.)
        figure.savefig(self.inputs.correlation_matrix_plot_out_file, dpi=300)

        np.savetxt(self.inputs.correlation_matrix_out_file, correlation_matrix, delimiter=',', fmt='%10.2f')

        print('Total time: ', (datetime.now() - then).total_seconds())
        return runtime
Example #8
import numpy as np
import matplotlib.pyplot as plt
from nilearn import plotting

import dcor


def compute_faster_correlation_matrix_from_hypergraph(hypergraph, time_series, savefigure=None):
    hypergraph_shape = hypergraph.shape
    correlation_matrix = np.zeros(hypergraph_shape)
    hypergraph = hypergraph.astype(bool)

    for i in range(hypergraph_shape[0]):
        print(i)
        for j in range(i + 1, hypergraph_shape[0], 1):
            correlation_matrix[i, j] = dcor.u_distance_correlation_sqr(time_series[:, hypergraph[i, :]],
                                                                       time_series[:, hypergraph[j, :]])
            # correlation_matrix[j, i] = correlation_matrix[i, j]

    if savefigure is not None:
        figure = plt.figure(figsize=(6, 6))
        plotting.plot_matrix(correlation_matrix, figure=figure, vmax=1., vmin=0.)
        figure.savefig(savefigure, dpi=200)
        plt.close(figure)

    return correlation_matrix
Example #9
#X = numpy.array([numpy.linspace(-1, 1, N) for _ in range(D)]).T
X = numpy.array([numpy.random.uniform(-1, 1, N) for _ in range(D)]).T
TWO_D = 2 * numpy.array(range(D))
Y = numpy.matmul(numpy.multiply(X, X), TWO_D)
# ---

# --- Transform data
M = numpy.array([numpy.random.uniform(-10, 10, D) for _ in range(D)])
N = numpy.array([numpy.random.uniform(-10, 10, N) for _ in range(D)]).T
X_TRANS1 = numpy.matmul(X, M)
X_TRANS2 = numpy.matmul(X, M) + N

print("Distance correlation:")
print(dcor.distance_correlation(Y, X))
print("Unbiased dcor:")
print(numpy.sqrt(dcor.u_distance_correlation_sqr(Y, X)))

#for _ in range(10000):
#    AIDC(X, Y)
#    dcor.distance_correlation_af_inv(Y, X)
#print("done")
#sys.exit()

print("AIDC original X:")
print(AIDC(X, Y))
print("AIDC built-in X:")
print(dcor.distance_correlation_af_inv(Y, X))
print("AIDC X = M*X:")
print(AIDC(X_TRANS1, Y))
print(dcor.distance_correlation_af_inv(Y, X_TRANS1))
print("AIDC X = M*X + N:")
Example #10
    def train(self,
              epochs,
              training,
              testing,
              testing_raw,
              batch_size=64,
              fold=0):
        [train_data_aug, train_dx_aug, train_age_aug, train_sex_aug] = training
        [test_data_aug, test_dx_aug, test_age_aug, test_sex_aug] = testing
        [test_data, test_dx, test_age, test_sex] = testing_raw

        test_data_aug_flip = np.flip(test_data_aug, 1)
        test_data_flip = np.flip(test_data, 1)

        idx_perm = np.random.permutation(int(train_data_aug.shape[0] / 2))

        dc_age = np.zeros((int(epochs / 10) + 1, ))
        min_dc = 0
        for epoch in range(epochs):

            ## Uncomment to decay the learning rate manually
            # if epoch % 200 == 0:
            #    self.lr = self.lr * 0.75
            #    optimizer = Adam(self.lr)
            #    self.workflow.compile(loss='binary_crossentropy', optimizer=optimizer,metrics=['accuracy'])
            #    self.distiller.compile(loss=correlation_coefficient_loss, optimizer=optimizer)
            #    self.regressor.compile(loss='mse', optimizer=optimizer)

            # Select a random batch of images

            idx_perm = np.random.permutation(int(train_data_aug.shape[0] / 2))
            ctrl_idx = idx_perm[:int(batch_size)]
            idx_perm = np.random.permutation(int(train_data_aug.shape[0] / 2))
            idx = idx_perm[:int(batch_size / 2)]
            idx = np.concatenate((idx, idx + int(train_data_aug.shape[0] / 2)))

            training_feature_batch = train_data_aug[idx]
            dx_batch = train_dx_aug[idx]
            age_batch = train_age_aug[idx]

            training_feature_ctrl_batch = train_data_aug[ctrl_idx]
            age_ctrl_batch = train_age_aug[ctrl_idx]

            # ---------------------
            #  Train regressor (cf predictor)
            # ---------------------

            encoded_feature_ctrl_batch = self.encoder.predict(
                training_feature_ctrl_batch[:, :32, :, :])
            r_loss = self.regressor.train_on_batch(encoded_feature_ctrl_batch,
                                                   age_ctrl_batch)

            # ---------------------
            #  Train Distiller
            # ---------------------

            g_loss = self.distiller.train_on_batch(
                training_feature_ctrl_batch[:, :32, :, :], age_ctrl_batch)

            # ---------------------
            #  Train Encoder & Classifier
            # ---------------------

            c_loss = self.workflow.train_on_batch(
                training_feature_batch[:, :32, :, :], dx_batch)

            # ---------------------
            #  flip & re-do everything
            # ---------------------

            training_feature_batch = np.flip(training_feature_batch, 1)
            training_feature_ctrl_batch = np.flip(training_feature_ctrl_batch,
                                                  1)

            encoded_feature_ctrl_batch = self.encoder.predict(
                training_feature_ctrl_batch[:, :32, :, :])
            r_loss = self.regressor.train_on_batch(encoded_feature_ctrl_batch,
                                                   age_ctrl_batch)
            g_loss = self.distiller.train_on_batch(
                training_feature_ctrl_batch[:, :32, :, :], age_ctrl_batch)
            c_loss = self.workflow.train_on_batch(
                training_feature_batch[:, :32, :, :], dx_batch)

            # Plot the progress
            if epoch % 50 == 0:
                c_loss_test_1 = self.workflow.evaluate(
                    test_data_aug[:, :32, :, :],
                    test_dx_aug,
                    verbose=0,
                    batch_size=batch_size)
                c_loss_test_2 = self.workflow.evaluate(
                    test_data_aug_flip[:, :32, :, :],
                    test_dx_aug,
                    verbose=0,
                    batch_size=batch_size)

                # feature dist corr
                features_dense = self.encoder.predict(
                    train_data_aug[train_dx_aug == 0, :32, :, :],
                    batch_size=batch_size)
                dc_age[int(epoch / 10)] = dcor.u_distance_correlation_sqr(
                    features_dense, train_age_aug[train_dx_aug == 0])
                print("%d [Acc: %f,  Test Acc: %f %f,  dc: %f]" %
                      (epoch, c_loss[1], c_loss_test_1[1], c_loss_test_2[1],
                       dc_age[int(epoch / 10)]))
                sys.stdout.flush()

                self.classifier.save_weights("res_cf_5cv/classifier.h5")
                self.encoder.save_weights("res_cf_5cv/encoder.h5")
                self.workflow.save_weights("res_cf_5cv/workflow.h5")

                ## Turn on to save all intermediate features for posthoc MI computation
                #features_dense = self.encoder.predict(test_data[:,:32,:,:],  batch_size = 64)
                #filename = 'res_cf/features_'+str(fold)+'.txt'
                #np.savetxt(filename,features_dense)
                #score = self.classifier.predict(features_dense,  batch_size = 64)
                #filename = 'res_cf/scores_'+str(fold)+'_'+str(epoch)+'.txt'
                #np.savetxt(filename,score)

                #features_dense = self.encoder.predict(test_data_flip[:,:32,:,:],  batch_size = 64)
                #filename = 'res_cf/features_flip_'+str(fold)+'.txt'
                #np.savetxt(filename,features_dense)
                #score = self.classifier.predict(features_dense,  batch_size = 64)
                #filename = 'res_cf/scores_flip_'+str(fold)+'_'+str(epoch)+'.txt'
                #np.savetxt(filename,score)

                # save intermediate predictions
                prediction = self.workflow.predict(test_data[:, :32, :, :],
                                                   batch_size=64)
                filename = 'res_cf_5cv/prediction_' + str(fold) + '_' + str(
                    epoch) + '.txt'
                np.savetxt(filename, prediction)
                prediction = self.workflow.predict(
                    test_data_flip[:, :32, :, :], batch_size=64)
                filename = 'res_cf_5cv/prediction_flip_' + str(
                    fold) + '_' + str(epoch) + '.txt'
                np.savetxt(filename, prediction)

                # save ground-truth
                filename = 'res_cf_5cv/dx_' + str(fold) + '.txt'
                np.savetxt(filename, test_dx)
                filename = 'res_cf_5cv/cf_' + str(fold) + '.txt'
                np.savetxt(filename, test_age)
Example #11
    def test_distance_correlation_comparison(self):
        """
        Compare all implementations of the distance covariance and correlation.
        """
        arr1 = np.array(((1.,), (2.,), (3.,), (4.,), (5.,), (6.,)))
        arr2 = np.array(((1.,), (7.,), (5.,), (5.,), (6.,), (2.,)))

        for method in dcor.DistanceCovarianceMethod:
            with self.subTest(method=method):

                compile_modes = [dcor.CompileMode.AUTO,
                                 dcor.CompileMode.COMPILE_CPU,
                                 dcor.CompileMode.NO_COMPILE]

                if method is not dcor.DistanceCovarianceMethod.NAIVE:
                    compile_modes += [dcor.CompileMode.COMPILE_CPU]

                for compile_mode in compile_modes:
                    with self.subTest(compile_mode=compile_mode):

                        # Unbiased versions

                        covariance = dcor.u_distance_covariance_sqr(
                            arr1, arr2, method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(covariance, -0.88889, places=5)

                        correlation = dcor.u_distance_correlation_sqr(
                            arr1, arr2, method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(correlation, -0.41613, places=5)

                        covariance = dcor.u_distance_covariance_sqr(
                            arr1, arr1,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(covariance, 1.55556, places=5)

                        correlation = dcor.u_distance_correlation_sqr(
                            arr1, arr1,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(correlation, 1, places=5)

                        covariance = dcor.u_distance_covariance_sqr(
                            arr2, arr2,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(covariance, 2.93333, places=5)

                        correlation = dcor.u_distance_correlation_sqr(
                            arr2, arr2,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(correlation, 1, places=5)

                        stats = dcor.u_distance_stats_sqr(
                            arr1, arr2, method=method,
                            compile_mode=compile_mode)
                        np.testing.assert_allclose(
                            stats, (-0.88889, -0.41613, 1.55556, 2.93333),
                            rtol=1e-4)

                        # Biased

                        covariance = dcor.distance_covariance_sqr(
                            arr1, arr2, method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(covariance, 0.68519, places=5)

                        correlation = dcor.distance_correlation_sqr(
                            arr1, arr2, method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(correlation, 0.30661, places=5)

                        covariance = dcor.distance_covariance_sqr(
                            arr1, arr1,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(covariance, 1.70679, places=5)

                        correlation = dcor.distance_correlation_sqr(
                            arr1, arr1,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(correlation, 1, places=5)

                        covariance = dcor.distance_covariance_sqr(
                            arr2, arr2,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(covariance, 2.92593, places=5)

                        correlation = dcor.distance_correlation_sqr(
                            arr2, arr2,  method=method,
                            compile_mode=compile_mode)
                        self.assertAlmostEqual(correlation, 1, places=5)

                        stats = dcor.distance_stats_sqr(
                            arr1, arr2, method=method,
                            compile_mode=compile_mode)
                        np.testing.assert_allclose(
                            stats, (0.68519, 0.30661, 1.70679, 2.92593),
                            rtol=1e-4)
Example #12
import numpy as np

import dcor


def correlate(x, y):
    # X = pdist(x, "euclidean")
    # Y = pdist(y, "euclidean")
    # return scipy.stats.pearsonr(X.flatten(), Y.flatten())[0]
    return np.sqrt(dcor.u_distance_correlation_sqr(x, y))
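
A quick check of the helper above on dependent and independent samples (illustrative data; with the unbiased estimator the squared correlation can come out slightly negative for independent inputs, so the square root may yield NaN):

import numpy as np

import dcor

rng = np.random.default_rng(0)
x = rng.standard_normal((100, 2))

print(correlate(x, x ** 2))                           # nonlinear dependence: clearly positive
print(correlate(x, rng.standard_normal((100, 2))))    # independent: near zero, possibly NaN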
Example #13
def run_experiment(mdn,
                   run_name_base,
                   batch_size,
                   learning_rate,
                   run,
                   x,
                   labels,
                   cf,
                   x_val,
                   labels_val,
                   cf_val,
                   epochs=5000,
                   N=1000):

    trainset_size = 2 * N
    experiment_name = os.path.join(run_name_base,
                                   'batch_size' + str(batch_size))
    run_name = os.path.join(experiment_name, 'run' + str(run))
    log_file = os.path.join(experiment_name, 'metrics.txt')
    skmetrics_file = os.path.join(experiment_name, 'skmetrics.txt')
    run_log_file = os.path.join(run_name, 'metrics.txt')
    if not os.path.exists(run_name):
        os.makedirs(run_name)
    with open(run_log_file, 'w') as f:
        f.write('acc' + "\t" + 'dc0_val' + '\t' + 'dc1_val' + '\t' + 'loss' +
                "\n")
    print("run name:", run_name)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Calculate the confounder kernel, the precomputed matrix (X^T X)^-1 used by the MDN, based on
    # the design matrix X of confounders. It only needs to be computed once before training.
    if mdn == 'Conv' or mdn == 'Linear':
        X = np.zeros((N * 2, 3))
        X[:, 0] = labels
        X[:, 1] = cf
        X[:, 2] = np.ones((N * 2, ))
        XTX = np.transpose(X).dot(X)
        kernel = np.linalg.inv(XTX)
        cf_kernel = nn.Parameter(torch.tensor(kernel).float().to(device),
                                 requires_grad=False)

    # Create model
    if mdn == 'Baseline':
        model = BaselineNet()
    elif mdn == 'Linear':
        model = MDN_Linear(2 * N, batch_size, cf_kernel)
    elif mdn == 'Conv':
        model = MDN_Conv(2 * N, batch_size, cf_kernel)
    else:
        print('mdn type not supported')
        return

    model.to(device)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              'min',
                                                              min_lr=1e-6,
                                                              factor=0.5)
    iterations = 2 * N // batch_size
    print(model)

    # Make dataloaders
    print("Making dataloaders...")
    train_set = SyntheticDataset(x, labels, cf)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True)
    val_set = SyntheticDataset(x_val, labels_val, cf_val)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             pin_memory=True)

    # Run training
    acc_history = []
    acc_history_val = []
    dc0s_val = []
    dc1s_val = []
    dcors = []
    losses = []
    losses_val = []
    patience = 0  # number of epochs where val loss doesn't decrease
    min_loss = float('inf')
    run_dir = os.path.join('plot_data', run_name)
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)

    for e in range(epochs):
        cfs_val = []
        feature_val = []
        epoch_acc = 0
        epoch_acc_val = 0
        epoch_loss = 0
        epoch_loss_val = 0
        pred_vals = []
        target_vals = []

        # Training pass
        model = model.train()
        for i, sample_batched in enumerate(train_loader):
            data = sample_batched['image'].float()
            target = sample_batched['label'].float()
            cf_batch = sample_batched['cfs'].float()
            data, target = data.cuda(), target.cuda()

            # Add the confounder input features (cfs) to the model. cfs are stored in the dataset
            # and need to be set for each batch during training.
            X_batch = np.zeros((batch_size, 3))
            X_batch[:, 0] = target.cpu().detach().numpy()
            X_batch[:, 1] = cf_batch.cpu().detach().numpy()
            X_batch[:, 2] = np.ones((batch_size, ))
            with torch.no_grad():
                model.cfs = nn.Parameter(torch.Tensor(X_batch).to(device),
                                         requires_grad=False)

            # Forward pass
            optimizer.zero_grad()
            y_pred, fc = model(data)
            loss = criterion(y_pred, target.unsqueeze(1))
            acc = binary_acc(y_pred, target.unsqueeze(1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_acc += acc.item()

        # Validation pass
        model = model.eval()
        for i, sample_batched in enumerate(val_loader):
            data = sample_batched['image'].float()
            target = sample_batched['label'].float()
            cf_batch = sample_batched['cfs'].float()
            data, target = data.cuda(), target.cuda()

            X_batch = np.zeros((batch_size, 3))
            X_batch[:, 0] = target.cpu().detach().numpy()
            X_batch[:, 1] = cf_batch.cpu().detach().numpy()
            X_batch[:, 2] = np.ones((batch_size, ))

            with torch.no_grad():
                model.cfs = nn.Parameter(torch.Tensor(X_batch).to(device),
                                         requires_grad=False)
                y_pred, fc = model(data)
                loss = criterion(y_pred, target.unsqueeze(1))
                acc = binary_acc(y_pred, target.unsqueeze(1))
                epoch_loss_val += loss.item()
                epoch_acc_val += acc.item()

            # Save learned features
            feature_val.append(fc.cpu())  # move features to CPU for the numpy concatenation below
            cfs_val.append(cf_batch)
            target_vals.append(target.cpu())
            pred_vals.append(y_pred.cpu())

        # Calculate distance correlation between confounders and learned features
        epoch_targets = np.concatenate(target_vals, axis=0)
        epoch_preds = np.concatenate(pred_vals, axis=0)
        i0_val = np.where(epoch_targets == 0)[0]
        i1_val = np.where(epoch_targets == 1)[0]
        epoch_layer = np.concatenate(feature_val, axis=0)
        epoch_cf = np.concatenate(cfs_val, axis=0)
        dc0_val = dcor.u_distance_correlation_sqr(epoch_layer[i0_val],
                                                  epoch_cf[i0_val])
        dc1_val = dcor.u_distance_correlation_sqr(epoch_layer[i1_val],
                                                  epoch_cf[i1_val])
        print("correlations for feature 0:", dc0_val)
        print("correlations for feature 1:", dc1_val)
        dc0s_val.append(dc0_val)
        dc1s_val.append(dc1_val)

        curr_acc = epoch_acc / iterations
        acc_history.append(curr_acc)
        losses.append(epoch_loss)
        curr_acc_val = epoch_acc_val / iterations
        acc_history_val.append(curr_acc_val)
        losses_val.append(epoch_loss_val)

        print('learning rate:', optimizer.param_groups[0]['lr'])
        lr_scheduler.step(epoch_loss_val)

        # Save model with lowest loss
        if epoch_loss_val + 0.001 < min_loss:
            print("Best loss so far! Saving model...")
            min_loss = epoch_loss_val
            #if run <= 5:
            #torch.save(model, os.path.join('plot_data', run_name, 'best_model.pth'))
            patience = 0

        print(
            f'Train: Epoch {e+0:03}: | Loss: {epoch_loss/iterations:.5f} | Acc: {epoch_acc/iterations:.3f}'
        )
        print(
            f'Val: Epoch {e+0:03}: | Loss: {epoch_loss_val/iterations:.5f} | Acc: {epoch_acc_val/iterations:.3f}'
        )
        np.save(os.path.join(run_name, 'd' + str(trainset_size) + '.npy'),
                acc_history)
        np.save(
            os.path.join(run_name, 'loss_d' + str(trainset_size)) + '.npy',
            losses)
        np.save(
            os.path.join(run_name, 'val_loss_d' + str(trainset_size)) + '.npy',
            losses_val)
        np.save(
            os.path.join(run_name, 'val_acc_d' + str(trainset_size) + '.npy'),
            acc_history_val)
        np.save(
            os.path.join(run_name, 'val_dc0s_d' + str(trainset_size) + '.npy'),
            dc0s_val)
        np.save(
            os.path.join(run_name, 'val_dc1s_d' + str(trainset_size) + '.npy'),
            dc1s_val)

        with open(run_log_file, 'a') as f:
            f.write(
                str(curr_acc_val) + '\t' + str(dc0_val) + '\t' + str(dc1_val) +
                '\t' + str(epoch_loss_val) + '\n')

        print("patience:", patience)
        patience += 1
        if patience > 200:
            print("out of patience")
            break

    y_test, y_pred_list = test(model, mdn, batch_size, run=run, N=N)

    with open(log_file, 'a') as f:
        f.write(
            str(run) + '\t' + str(curr_acc_val) + '\t' + str(dc0_val) + '\t' +
            str(dc1_val) + '\t' + str(epoch_loss_val) + "\n")

    with open(skmetrics_file, 'a') as f:
        f.write("Report for run " + str(run))
        f.write(classification_report(y_test, y_pred_list, digits=3) + '\n')