Example #1
def misgan_impute(args,
                  data_gen,
                  mask_gen,
                  imputer,
                  data_critic,
                  mask_critic,
                  impu_critic,
                  data,
                  output_dir,
                  checkpoint=None):
    n_critic = args.n_critic
    gp_lambda = args.gp_lambda
    batch_size = args.batch_size
    nz = args.n_latent
    epochs = args.epoch
    plot_interval = args.plot_interval
    save_model_interval = args.save_interval
    alpha = args.alpha
    beta = args.beta
    gamma = args.gamma
    tau = args.tau
    update_all_networks = not args.imputeronly

    gen_data_dir = mkdir(output_dir / 'img')
    gen_mask_dir = mkdir(output_dir / 'mask')
    impute_dir = mkdir(output_dir / 'impute')
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    data_loader = DataLoader(data,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             num_workers=args.workers)
    n_batch = len(data_loader)
    data_shape = data[0][0].shape

    data_noise = torch.FloatTensor(batch_size, nz).to(device)
    mask_noise = torch.FloatTensor(batch_size, nz).to(device)
    impu_noise = torch.FloatTensor(batch_size, *data_shape).to(device)

    # Interpolation coefficient
    eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)

    # For computing gradient penalty
    ones = torch.ones(batch_size).to(device)

    lrate = 1e-4
    imputer_lrate = 2e-4
    data_gen_optimizer = optim.Adam(data_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))
    mask_gen_optimizer = optim.Adam(mask_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))
    imputer_optimizer = optim.Adam(imputer.parameters(),
                                   lr=imputer_lrate,
                                   betas=(.5, .9))

    data_critic_optimizer = optim.Adam(data_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))
    mask_critic_optimizer = optim.Adam(mask_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))
    impu_critic_optimizer = optim.Adam(impu_critic.parameters(),
                                       lr=imputer_lrate,
                                       betas=(.5, .9))

    update_data_critic = CriticUpdater(data_critic, data_critic_optimizer, eps,
                                       ones, gp_lambda)
    update_mask_critic = CriticUpdater(mask_critic, mask_critic_optimizer, eps,
                                       ones, gp_lambda)
    update_impu_critic = CriticUpdater(impu_critic, impu_critic_optimizer, eps,
                                       ones, gp_lambda)

    start_epoch = 0
    critic_updates = 0
    log = defaultdict(list)

    if args.resume:
        data_gen.load_state_dict(checkpoint['data_gen'])
        mask_gen.load_state_dict(checkpoint['mask_gen'])
        imputer.load_state_dict(checkpoint['imputer'])
        data_critic.load_state_dict(checkpoint['data_critic'])
        mask_critic.load_state_dict(checkpoint['mask_critic'])
        impu_critic.load_state_dict(checkpoint['impu_critic'])
        data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
        mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
        imputer_optimizer.load_state_dict(checkpoint['imputer_opt'])
        data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
        mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
        impu_critic_optimizer.load_state_dict(checkpoint['impu_critic_opt'])
        start_epoch = checkpoint['epoch']
        critic_updates = checkpoint['critic_updates']
        log = checkpoint['log']
    elif args.pretrain:
        pretrain = torch.load(args.pretrain, map_location='cpu')
        data_gen.load_state_dict(pretrain['data_gen'])
        mask_gen.load_state_dict(pretrain['mask_gen'])
        data_critic.load_state_dict(pretrain['data_critic'])
        mask_critic.load_state_dict(pretrain['mask_critic'])
        if 'imputer' in pretrain:
            imputer.load_state_dict(pretrain['imputer'])
            impu_critic.load_state_dict(pretrain['impu_critic'])

    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)

    def save_model(path, epoch, critic_updates=0):
        torch.save(
            {
                'data_gen': data_gen.state_dict(),
                'mask_gen': mask_gen.state_dict(),
                'imputer': imputer.state_dict(),
                'data_critic': data_critic.state_dict(),
                'mask_critic': mask_critic.state_dict(),
                'impu_critic': impu_critic.state_dict(),
                'data_gen_opt': data_gen_optimizer.state_dict(),
                'mask_gen_opt': mask_gen_optimizer.state_dict(),
                'imputer_opt': imputer_optimizer.state_dict(),
                'data_critic_opt': data_critic_optimizer.state_dict(),
                'mask_critic_opt': mask_critic_optimizer.state_dict(),
                'impu_critic_opt': impu_critic_optimizer.state_dict(),
                'epoch': epoch + 1,
                'critic_updates': critic_updates,
                'log': log,
                'args': args,
            }, str(path))

    sns.set()
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        sum_data_loss, sum_mask_loss, sum_impu_loss = 0, 0, 0
        for real_data, real_mask, _, index in data_loader:
            # Assume real_data and real_mask have the same number of channels.
            # Could be modified to handle multi-channel images and
            # single-channel masks.
            real_mask = real_mask.float()[:, None]

            real_data = real_data.to(device)
            real_mask = real_mask.to(device)

            masked_real_data = mask_data(real_data, real_mask, tau)

            # Update discriminators' parameters
            data_noise.normal_()
            fake_data = data_gen(data_noise)

            impu_noise.uniform_()
            imputed_data = imputer(real_data, real_mask, impu_noise)
            masked_imputed_data = mask_data(real_data, real_mask, imputed_data)

            if update_all_networks:
                mask_noise.normal_()
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)
                update_data_critic(masked_real_data, masked_fake_data)
                update_mask_critic(real_mask, fake_mask)

                sum_data_loss += update_data_critic.loss_value
                sum_mask_loss += update_mask_critic.loss_value

            update_impu_critic(fake_data, masked_imputed_data)
            sum_impu_loss += update_impu_critic.loss_value

            critic_updates += 1

            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                if update_all_networks:
                    for p in data_critic.parameters():
                        p.requires_grad_(False)
                    for p in mask_critic.parameters():
                        p.requires_grad_(False)
                for p in impu_critic.parameters():
                    p.requires_grad_(False)

                impu_noise.uniform_()
                imputed_data = imputer(real_data, real_mask, impu_noise)
                masked_imputed_data = mask_data(real_data, real_mask,
                                                imputed_data)
                impu_loss = -impu_critic(masked_imputed_data).mean()

                if update_all_networks:
                    data_noise.normal_()
                    fake_data = data_gen(data_noise)
                    mask_noise.normal_()
                    fake_mask = mask_gen(mask_noise)
                    masked_fake_data = mask_data(fake_data, fake_mask, tau)
                    data_loss = -data_critic(masked_fake_data).mean()
                    mask_loss = -mask_critic(fake_mask).mean()

                    mask_gen.zero_grad()
                    (mask_loss + data_loss * alpha).backward(retain_graph=True)
                    mask_gen_optimizer.step()

                    data_noise.normal_()
                    fake_data = data_gen(data_noise)
                    mask_noise.normal_()
                    fake_mask = mask_gen(mask_noise)
                    masked_fake_data = mask_data(fake_data, fake_mask, tau)
                    data_loss = -data_critic(masked_fake_data).mean()

                    data_gen.zero_grad()
                    (data_loss + impu_loss * beta).backward(retain_graph=True)
                    data_gen_optimizer.step()

                imputer.zero_grad()
                if gamma > 0:
                    imputer_mismatch_loss = mask_norm(
                        (imputed_data - real_data)**2, real_mask)
                    (impu_loss + imputer_mismatch_loss * gamma).backward()
                else:
                    impu_loss.backward()
                imputer_optimizer.step()

                if update_all_networks:
                    for p in data_critic.parameters():
                        p.requires_grad_(True)
                    for p in mask_critic.parameters():
                        p.requires_grad_(True)
                for p in impu_critic.parameters():
                    p.requires_grad_(True)

        if update_all_networks:
            mean_data_loss = sum_data_loss / n_batch
            mean_mask_loss = sum_mask_loss / n_batch
            log['data loss', 'data_loss'].append(mean_data_loss)
            log['mask loss', 'mask_loss'].append(mean_mask_loss)
        mean_impu_loss = sum_impu_loss / n_batch
        log['imputer loss', 'impu_loss'].append(mean_impu_loss)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if update_all_networks:
                print('[{:4}] {:12.4f} {:12.4f} {:12.4f}'.format(
                    epoch, mean_data_loss, mean_mask_loss, mean_impu_loss))
            else:
                print('[{:4}] {:12.4f}'.format(epoch, mean_impu_loss))

            filename = f'{epoch:04d}.png'
            with torch.no_grad():
                data_gen.eval()
                mask_gen.eval()
                imputer.eval()

                data_noise.normal_()
                mask_noise.normal_()

                data_samples = data_gen(data_noise)
                plot_samples(data_samples, str(gen_data_dir / filename))

                mask_samples = mask_gen(mask_noise)
                plot_samples(mask_samples, str(gen_mask_dir / filename))

                # Plot imputation results
                impu_noise.uniform_()
                imputed_data = imputer(real_data, real_mask, impu_noise)
                imputed_data = mask_data(real_data, real_mask, imputed_data)
                if hasattr(data, 'mask_info'):
                    bbox = [data.mask_info[idx] for idx in index]
                else:
                    bbox = None
                plot_grid(imputed_data,
                          bbox,
                          gap=2,
                          save_file=str(impute_dir / filename))

                data_gen.train()
                mask_gen.train()
                imputer.train()

        for (name, shortname), trace in log.items():
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.plot(trace)
            ax.set_ylabel(name)
            ax.set_xlabel('epoch')
            fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
            plt.close(fig)

        if save_model_interval > 0 and (epoch + 1) % save_model_interval == 0:
            save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)

        epoch_end = time.time()
        time_elapsed = epoch_end - start
        epoch_time = epoch_end - epoch_start
        epoch_start = epoch_end
        with (log_dir / 'epoch-time.txt').open('a') as f:
            print(epoch, epoch_time, time_elapsed, file=f)
        save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)

    print(output_dir)
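
Examples #1 and #5 on this page (and Example #7 below) call a mask_data helper that is not shown here. A minimal sketch of what it could look like, assuming it keeps the observed entries and fills the missing ones with either the constant tau or a tensor of imputed values:

def mask_data(data, mask, tau_or_fill=0):
    # Hypothetical reconstruction: mask == 1 marks observed entries;
    # missing entries are replaced by a scalar tau or by imputed values.
    return mask * data + (1 - mask) * tau_or_fill

Under that reading, mask_data(real_data, real_mask, tau) blanks out the missing pixels with tau, while mask_data(real_data, real_mask, imputed_data) fills them with the imputer's output.
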
Example #2
            param_str = "SVM"

        # --------- CVPR model ----------
        elif model_type in ["tCNN", "ED-TCN", "DilatedTCN", "TDNN", "LSTM"]:
            # Go from y_t = {1...C} to one-hot vector (e.g. y_t = [0, 0, 1, 0])
            Y_train = [np_utils.to_categorical(y, n_classes) for y in y_train]
            Y_test = [np_utils.to_categorical(y, n_classes) for y in y_test]

            # In order to process batches simultaneously, all data needs to be the same length,
            # so pad everything to the same length and mask out the ends of each sequence.
            n_layers = len(n_nodes)
            max_len = max(np.max(train_lengths), np.max(test_lengths))
            max_len = int(np.ceil(max_len / (2**n_layers)))*2**n_layers
            print("Max length:", max_len)

            X_train_m, Y_train_, M_train = utils.mask_data(X_train, Y_train, max_len, mask_value=-1)
            X_test_m, Y_test_, M_test = utils.mask_data(X_test, Y_test, max_len, mask_value=-1)

            if model_type == "tCNN":
                model, param_str = tf_models.temporal_convs_linear(n_nodes[0], conv, n_classes, n_feat, 
                                                    max_len, causal=causal, return_param_str=True)
            elif model_type == "ED-TCN":
                model, param_str = tf_models.ED_TCN(n_nodes, conv, n_classes, n_feat, max_len, causal=causal, 
                                        activation='norm_relu', return_param_str=True) 
                # model, param_str = tf_models.ED_TCN_atrous(n_nodes, conv, n_classes, n_feat, max_len, 
                                    # causal=causal, activation='norm_relu', return_param_str=True)                 
            elif model_type == "TDNN":
                model, param_str = tf_models.TimeDelayNeuralNetwork(n_nodes, conv, n_classes, n_feat, max_len, 
                                   causal=causal, activation='tanh', return_param_str=True)
            elif model_type == "DilatedTCN":
                model, param_str = tf_models.Dilated_TCN(n_feat, n_classes, n_nodes[0], L, B, max_len=max_len, 
Example #3
n_classes = len(actions_dict)

train_lengths = [x.shape[0] for x in X_train]
test_lengths = [x.shape[0] for x in X_test]
n_train = len(X_train)
n_test = len(X_test)

n_feat = 400
n_layers = len(n_nodes)
max_len = max(np.max(train_lengths), np.max(test_lengths))
max_len = int(np.ceil(max_len / (2**n_layers)))*2**n_layers
print("Max length:", max_len)

# pdb.set_trace()

X_train_m, Y_train_, M_train = utils.mask_data(X_train, y_train, max_len, mask_value=0)
X_test_m, M_test = utils.mask_test_data(X_test, max_len, mask_value=0)

model, param_str = tf_models.ED_TCN(n_nodes, conv, n_classes, n_feat, max_len, causal=causal, 
                            activation='norm_relu', return_param_str=True)

#     model, param_str = tf_models.temporal_convs_linear(n_nodes[0], conv, n_classes, n_feat, 
#                                         max_len, causal=causal, return_param_str=True)
# elif model_type == "ED-TCN":
#     model, param_str = tf_models.ED_TCN(n_nodes, conv, n_classes, n_feat, max_len, causal=causal, 
#                             activation='norm_relu', return_param_str=True) 
#     # model, param_str = tf_models.ED_TCN_atrous(n_nodes, conv, n_classes, n_feat, max_len, 
#                         # causal=causal, activation='norm_relu', return_param_str=True)                 
# elif model_type == "TDNN":
#     model, param_str = tf_models.TimeDelayNeuralNetwork(n_nodes, conv, n_classes, n_feat, max_len, 
#                        causal=causal, activation='tanh', return_param_str=True)
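
In Examples #2 and #3, utils.mask_data and utils.mask_test_data are external helpers that pad variable-length sequences to max_len so batches can be processed together, returning a mask that marks the valid time steps. A minimal sketch of that padding step, written as a hypothetical stand-in (pad_and_mask is not the original function; it assumes one-hot labels as produced by to_categorical in Example #2):

import numpy as np

def pad_and_mask(X_list, Y_list, max_len, mask_value=0):
    # Hypothetical stand-in for utils.mask_data: pad each (T_i, n_feat)
    # sequence and its (T_i, n_classes) labels to max_len and return a
    # boolean mask that is True on the un-padded time steps.
    n_feat = X_list[0].shape[1]
    n_classes = Y_list[0].shape[1]
    X = np.full((len(X_list), max_len, n_feat), mask_value, dtype=np.float32)
    Y = np.zeros((len(Y_list), max_len, n_classes), dtype=np.float32)
    M = np.zeros((len(X_list), max_len), dtype=bool)
    for i, (x, y) in enumerate(zip(X_list, Y_list)):
        X[i, :len(x)] = x
        Y[i, :len(y)] = y
        M[i, :len(x)] = True
    return X, Y, M
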
Example #4
                    type=int,
                    default=10,
                    required=False,
                    help="Random seed")
args = parser.parse_args()

if __name__ == '__main__':

    print('Read and preprocess')
    X_train, X_test, y_train, y_test = read_preprocess(args.dataset)

    print('Train Model')
    model = train_model(args.dataset, X_train, y_train, _seed=args.seed)

    print('Mask data')
    X_test_masked, mask_data_filling = mask_data(X_test, X_train.shape[1],
                                                 args.seed)

    if args.method == 'TNBQ':
        print('The Next Best Question')
        K = int(args.parameters[0])
        res_lst = the_next_best_question(X_train, X_test_masked, y_test, model,
                                         X_train.shape[1], mask_data_filling,
                                         K)
        save_results(res_lst, "The Next Best Question", args.parameters,
                     args.seed, args.dataset.title())

    if args.method.lower() == 'global':
        print('Global Feature Importance Acquisition')
        res_lst = global_shap_baseline(X_train, X_test_masked, y_test, model,
                                       X_train.shape[1], mask_data_filling)
        save_results(res_lst, "Global Feature Importance Acquisition",
Example #5
def misgan(args,
           data_gen,
           mask_gen,
           data_critic,
           mask_critic,
           data,
           output_dir,
           checkpoint=None):
    n_critic = args.n_critic
    gp_lambda = args.gp_lambda
    batch_size = args.batch_size
    nz = args.n_latent
    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval
    alpha = args.alpha
    tau = args.tau

    gen_data_dir = mkdir(output_dir / 'img')
    gen_mask_dir = mkdir(output_dir / 'mask')
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    data_loader = DataLoader(data,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True)
    n_batch = len(data_loader)

    data_noise = torch.FloatTensor(batch_size, nz).to(device)
    mask_noise = torch.FloatTensor(batch_size, nz).to(device)

    # Interpolation coefficient
    eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)

    # For computing gradient penalty
    ones = torch.ones(batch_size).to(device)

    lrate = 1e-4
    # lrate = 1e-5
    data_gen_optimizer = optim.Adam(data_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))
    mask_gen_optimizer = optim.Adam(mask_gen.parameters(),
                                    lr=lrate,
                                    betas=(.5, .9))

    data_critic_optimizer = optim.Adam(data_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))
    mask_critic_optimizer = optim.Adam(mask_critic.parameters(),
                                       lr=lrate,
                                       betas=(.5, .9))

    update_data_critic = CriticUpdater(data_critic, data_critic_optimizer, eps,
                                       ones, gp_lambda)
    update_mask_critic = CriticUpdater(mask_critic, mask_critic_optimizer, eps,
                                       ones, gp_lambda)

    start_epoch = 0
    critic_updates = 0
    log = defaultdict(list)

    if checkpoint:
        data_gen.load_state_dict(checkpoint['data_gen'])
        mask_gen.load_state_dict(checkpoint['mask_gen'])
        data_critic.load_state_dict(checkpoint['data_critic'])
        mask_critic.load_state_dict(checkpoint['mask_critic'])
        data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
        mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
        data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
        mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
        start_epoch = checkpoint['epoch']
        critic_updates = checkpoint['critic_updates']
        log = checkpoint['log']

    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)

    def save_model(path, epoch, critic_updates=0):
        torch.save(
            {
                'data_gen': data_gen.state_dict(),
                'mask_gen': mask_gen.state_dict(),
                'data_critic': data_critic.state_dict(),
                'mask_critic': mask_critic.state_dict(),
                'data_gen_opt': data_gen_optimizer.state_dict(),
                'mask_gen_opt': mask_gen_optimizer.state_dict(),
                'data_critic_opt': data_critic_optimizer.state_dict(),
                'mask_critic_opt': mask_critic_optimizer.state_dict(),
                'epoch': epoch + 1,
                'critic_updates': critic_updates,
                'log': log,
                'args': args,
            }, str(path))

    sns.set()

    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        sum_data_loss, sum_mask_loss = 0, 0
        for real_data, real_mask, _, _ in data_loader:
            # Assume real_data and mask have the same number of channels.
            # Could be modified to handle multi-channel images and
            # single-channel masks.
            real_mask = real_mask.float()[:, None]

            real_data = real_data.to(device)
            real_mask = real_mask.to(device)

            masked_real_data = mask_data(real_data, real_mask, tau)

            # Update discriminators' parameters
            data_noise.normal_()
            mask_noise.normal_()

            fake_data = data_gen(data_noise)
            fake_mask = mask_gen(mask_noise)

            masked_fake_data = mask_data(fake_data, fake_mask, tau)

            update_data_critic(masked_real_data, masked_fake_data)
            update_mask_critic(real_mask, fake_mask)

            sum_data_loss += update_data_critic.loss_value
            sum_mask_loss += update_mask_critic.loss_value

            critic_updates += 1

            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                for p in data_critic.parameters():
                    p.requires_grad_(False)
                for p in mask_critic.parameters():
                    p.requires_grad_(False)

                data_noise.normal_()
                mask_noise.normal_()

                fake_data = data_gen(data_noise)
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)

                data_loss = -data_critic(masked_fake_data).mean()
                data_gen.zero_grad()
                data_loss.backward()
                data_gen_optimizer.step()

                data_noise.normal_()
                mask_noise.normal_()

                fake_data = data_gen(data_noise)
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)

                data_loss = -data_critic(masked_fake_data).mean()
                mask_loss = -mask_critic(fake_mask).mean()
                mask_gen.zero_grad()
                (mask_loss + data_loss * alpha).backward()
                mask_gen_optimizer.step()

                for p in data_critic.parameters():
                    p.requires_grad_(True)
                for p in mask_critic.parameters():
                    p.requires_grad_(True)

        mean_data_loss = sum_data_loss / n_batch
        mean_mask_loss = sum_mask_loss / n_batch
        log['data loss', 'data_loss'].append(mean_data_loss)
        log['mask loss', 'mask_loss'].append(mean_mask_loss)

        for (name, shortname), trace in log.items():
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.plot(trace)
            ax.set_ylabel(name)
            ax.set_xlabel('epoch')
            fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
            plt.close(fig)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            print(f'[{epoch:4}] {mean_data_loss:12.4f} {mean_mask_loss:12.4f}')

            filename = f'{epoch:04d}.png'

            data_gen.eval()
            mask_gen.eval()

            with torch.no_grad():
                data_noise.normal_()
                mask_noise.normal_()

                data_samples = data_gen(data_noise)
                plot_samples(data_samples, str(gen_data_dir / filename))

                mask_samples = mask_gen(mask_noise)
                plot_samples(mask_samples, str(gen_mask_dir / filename))

            data_gen.train()
            mask_gen.train()

        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)

        epoch_end = time.time()
        time_elapsed = epoch_end - start
        epoch_time = epoch_end - epoch_start
        epoch_start = epoch_end
        with (log_dir / 'time.txt').open('a') as f:
            print(epoch, epoch_time, time_elapsed, file=f)
        save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)

    print(output_dir)
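
Both MisGAN examples construct CriticUpdater objects from a critic, its optimizer, the interpolation coefficient eps, a ones tensor, and gp_lambda, which suggests a WGAN-GP style critic step. A minimal sketch under that assumption (not the original class):

from torch import autograd

class CriticUpdater:
    # Hypothetical WGAN-GP critic update matching how the examples call it:
    # calling the object with (real, fake) takes one optimizer step and
    # stores the critic loss in .loss_value.
    def __init__(self, critic, optimizer, eps, ones, gp_lambda):
        self.critic, self.optimizer = critic, optimizer
        self.eps, self.ones, self.gp_lambda = eps, ones, gp_lambda

    def __call__(self, real, fake):
        real, fake = real.detach(), fake.detach()
        self.eps.uniform_(0, 1)
        interp = (self.eps * real + (1 - self.eps) * fake).requires_grad_()
        grad = autograd.grad(self.critic(interp).view(-1), interp,
                             grad_outputs=self.ones, create_graph=True)[0]
        grad_penalty = ((grad.view(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
        loss = (self.critic(fake).mean() - self.critic(real).mean()
                + self.gp_lambda * grad_penalty)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.loss_value = loss.item()
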
Example #6
def get_bold_for_condition(dir_input, num_runs, option_zscore=0):
    """This function extracts the BOLD signal for three conditions.
    option_zscore = 0 => no z-scoring
    option_zscore = 1 => z-score the data

    Returns: BOLD values for all conditions for each run,
    and a mean value for the entire run.
    """
    from utils import shift_timing, mask_data, scale_data
    # Initialize arrays
    stim_label = []
    bold_A = []
    bold_B = []
    bold_C = []
    bold_fix = []
    bold_mean_all = []
    TR_shift_size = 2  # Number of TRs to shift the extraction of the BOLD signal.

    maskdir = dir_input
    masks = ['ROI_Cool']

    ### Extract the BOLD Signal for the conditions A, B, C
    ###

    print("Processing Start ...")
    maskfile = (maskdir + "%s.nii.gz" % (masks[0]))
    mask = nib.load(maskfile)
    print("Loaded Mask")
    print(mask.shape)

    for run in range(1, num_runs + 1):
        epi_in = (dir_input + "lab1_r0%s.nii.gz" % (run))
        stim_label = np.load(dir_input + 'labels_r0%s.npy' % (run))

        # Haemodynamic shift
        label_TR_shifted = shift_timing(stim_label, TR_shift_size)

        # Get labels for conditions for A, B, C, and baseline fixation.
        A = np.squeeze(np.argwhere(label_TR_shifted == 1))
        B = np.squeeze(np.argwhere(label_TR_shifted == 2))
        C = np.squeeze(np.argwhere(label_TR_shifted == 3))

        fixation = np.squeeze(np.argwhere(label_TR_shifted == 0))
        epi_data = nib.load(epi_in)
        epi_mask_data = mask_data(epi_data, mask)

        if option_zscore == 1:
            epi_maskdata_zscore = scale_data(epi_mask_data)
            epi_mask_data = epi_maskdata_zscore

        if run == 1:
            bold_A = epi_mask_data[A]
            bold_B = epi_mask_data[B]
            bold_C = epi_mask_data[C]
            bold_fix = epi_mask_data[fixation]
            bold_data_all = epi_mask_data
        else:
            bold_A = np.vstack([bold_A, epi_mask_data[A]])
            bold_B = np.vstack([bold_B, epi_mask_data[B]])
            bold_C = np.vstack([bold_C, epi_mask_data[C]])
            bold_fix = np.vstack([bold_fix, epi_mask_data[fixation]])
            bold_data_all = np.vstack([bold_data_all, epi_mask_data])
        bold_mean_all.append(np.mean(epi_mask_data))
    print("Processing Completed")
    return bold_data_all, bold_mean_all, bold_A, bold_B, bold_C, bold_fix, label_TR_shifted
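
The mask_data imported in Example #6 is a different helper with the same name: it applies an ROI mask to a 4-D EPI image and returns a time-by-voxel matrix, which is why epi_mask_data[A] above selects the TRs belonging to condition A. A minimal sketch of that behaviour, assuming nibabel images as inputs (the real utils module may differ):

import numpy as np

def mask_data(epi_img, mask_img):
    # Hypothetical reconstruction: keep only the voxels inside the ROI
    # mask and return a (n_TRs, n_voxels) matrix.
    epi = epi_img.get_fdata()         # shape (x, y, z, t)
    roi = mask_img.get_fdata() > 0    # shape (x, y, z), boolean ROI
    return epi[roi].T                 # shape (t, n_voxels_in_roi)
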
Example #7
def Train(train_loader, model_name, optimizer, node_num, rho_max, epochs, lr,
          batch_size, tau, n_critic, lambda_sparse, lambda_mmd, h_loss, rho,
          alpha):

    #---------------network selection-------------------
    Gen = Linear_Generator(node_num)
    Disc = Discriminator(node_num, model_name)

    #-----------------optimizer-------------------------
    if optimizer == 'ADAM':
        Gen_optimizer = torch.optim.Adam(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.Adam(Disc.parameters(), lr=lr)
    elif optimizer == 'RMS':
        Gen_optimizer = torch.optim.RMSprop(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.RMSprop(Disc.parameters(), lr=lr)
    elif optimizer == 'LBFGS':
        Gen_optimizer = torch.optim.LBFGS(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.LBFGS(Disc.parameters(), lr=lr)
    elif optimizer == 'SGD':
        Gen_optimizer = torch.optim.SGD(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.SGD(Disc.parameters(), lr=lr)
    else:
        raise NotImplementedError

    if model_name == 'GAN':
        BCE_loss = nn.BCELoss()
        real_label = Variable(torch.Tensor(batch_size, 1).fill_(1.0),
                              requires_grad=False)
        fake_label = Variable(torch.Tensor(batch_size, 1).fill_(0.0),
                              requires_grad=False)

    critic_update = 0
    sigma_list = [1, 2, 4, 8, 16]

    while rho < rho_max:
        for _ in range(epochs):
            for sample, mask in train_loader:
                masked_real_sample = mask_data(sample, mask, tau)
                input_noise = Variable(
                    torch.Tensor(np.random.normal(0, 1,
                                                  (batch_size, node_num))))
                generated_sample = Gen(input_noise)
                h_matrix = Gen.adj_A.clone()
                shuffle_mask = mask[torch.randperm(batch_size)]
                masked_generated_sample = mask_data(generated_sample,
                                                    shuffle_mask, tau)
                masked_fake_sample = masked_generated_sample.detach()

                Disc_optimizer.zero_grad()
                if model_name == 'GAN':
                    Real_sample_loss = BCE_loss(Disc(masked_real_sample),
                                                real_label)
                    Fake_sample_loss = BCE_loss(Disc(masked_fake_sample),
                                                fake_label)
                    Disc_loss = Real_sample_loss + Fake_sample_loss
                elif model_name == 'WGAN':
                    Disc_loss = -torch.mean(
                        Disc(masked_real_sample)) + torch.mean(
                            Disc(masked_fake_sample))
                else:
                    raise NotImplementedError

                Disc_loss.backward()
                Disc_optimizer.step()
                critic_update += 1

                if critic_update == n_critic:

                    critic_update = 0
                    for para in Disc.parameters():
                        para.requires_grad_(False)

                    Gen_optimizer.zero_grad()

                    if model_name == 'GAN':
                        Adv_Gen_loss = BCE_loss(Disc(masked_generated_sample),
                                                real_label)
                    elif model_name == 'WGAN':
                        Adv_Gen_loss = -torch.mean(
                            Disc(masked_generated_sample))
                    else:
                        raise NotImplementedError

                    mmd_loss = lambda_mmd * F.relu(
                        mix_rbf_mmd2(masked_generated_sample,
                                     masked_real_sample, sigma_list))
                    h_loss_temp = acyclic_loss(h_matrix, node_num)
                    sparse_loss = lambda_sparse * torch.sum(
                        torch.abs(h_matrix))
                    penalty_loss = alpha * h_loss_temp + 0.5 * rho * h_loss_temp * h_loss_temp

                    Total_Gen_loss = Adv_Gen_loss + mmd_loss + penalty_loss + sparse_loss
                    Total_Gen_loss.backward()
                    Gen_optimizer.step()

                    for para in Disc.parameters():
                        para.requires_grad_(True)

        h_matrix_d = Gen.adj_A.detach()
        h_new = acyclic_loss(h_matrix_d, node_num)
        print(f'-------inner optimization ({epochs} epochs finished)---------')
        print('The current h_loss is', h_new, 'Rho is', rho)
        print('weight max', torch.max(h_matrix_d), 'weight min', torch.min(h_matrix_d))
        scale = Gen.scale.detach()
        print('scale_abs max', torch.max(abs(scale)), 'scale_abs min',
              torch.min(abs(scale)))
        if h_new > 0.25 * h_loss:
            rho *= 10
        else:
            break
    alpha += rho * h_new
    return rho, alpha, h_new, h_matrix_d
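
Example #7 drives the generator's adjacency matrix towards a DAG through acyclic_loss plus an augmented-Lagrangian penalty (alpha * h + 0.5 * rho * h^2), increasing rho whenever the constraint does not shrink enough. A common choice for such a constraint is the NOTEARS formulation h(A) = tr(exp(A * A)) - d, which is zero exactly when A is acyclic; a minimal sketch assuming acyclic_loss follows it:

import torch

def acyclic_loss(adj, node_num):
    # Hypothetical NOTEARS-style acyclicity penalty:
    # h(A) = tr(exp(A * A)) - d, zero iff the weighted adjacency
    # matrix encodes a directed acyclic graph.
    return torch.trace(torch.matrix_exp(adj * adj)) - node_num
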