def misgan_impute(args, data_gen, mask_gen, imputer,
                  data_critic, mask_critic, impu_critic,
                  data, output_dir, checkpoint=None):
    n_critic = args.n_critic
    gp_lambda = args.gp_lambda
    batch_size = args.batch_size
    nz = args.n_latent
    epochs = args.epoch
    plot_interval = args.plot_interval
    save_model_interval = args.save_interval
    alpha = args.alpha
    beta = args.beta
    gamma = args.gamma
    tau = args.tau
    update_all_networks = not args.imputeronly

    gen_data_dir = mkdir(output_dir / 'img')
    gen_mask_dir = mkdir(output_dir / 'mask')
    impute_dir = mkdir(output_dir / 'impute')
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,
                             drop_last=True, num_workers=args.workers)
    n_batch = len(data_loader)
    data_shape = data[0][0].shape

    data_noise = torch.FloatTensor(batch_size, nz).to(device)
    mask_noise = torch.FloatTensor(batch_size, nz).to(device)
    impu_noise = torch.FloatTensor(batch_size, *data_shape).to(device)

    # Interpolation coefficient
    eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)

    # For computing gradient penalty
    ones = torch.ones(batch_size).to(device)

    lrate = 1e-4
    imputer_lrate = 2e-4
    data_gen_optimizer = optim.Adam(
        data_gen.parameters(), lr=lrate, betas=(.5, .9))
    mask_gen_optimizer = optim.Adam(
        mask_gen.parameters(), lr=lrate, betas=(.5, .9))
    imputer_optimizer = optim.Adam(
        imputer.parameters(), lr=imputer_lrate, betas=(.5, .9))
    data_critic_optimizer = optim.Adam(
        data_critic.parameters(), lr=lrate, betas=(.5, .9))
    mask_critic_optimizer = optim.Adam(
        mask_critic.parameters(), lr=lrate, betas=(.5, .9))
    impu_critic_optimizer = optim.Adam(
        impu_critic.parameters(), lr=imputer_lrate, betas=(.5, .9))

    update_data_critic = CriticUpdater(
        data_critic, data_critic_optimizer, eps, ones, gp_lambda)
    update_mask_critic = CriticUpdater(
        mask_critic, mask_critic_optimizer, eps, ones, gp_lambda)
    update_impu_critic = CriticUpdater(
        impu_critic, impu_critic_optimizer, eps, ones, gp_lambda)

    start_epoch = 0
    critic_updates = 0
    log = defaultdict(list)

    if args.resume:
        data_gen.load_state_dict(checkpoint['data_gen'])
        mask_gen.load_state_dict(checkpoint['mask_gen'])
        imputer.load_state_dict(checkpoint['imputer'])
        data_critic.load_state_dict(checkpoint['data_critic'])
        mask_critic.load_state_dict(checkpoint['mask_critic'])
        impu_critic.load_state_dict(checkpoint['impu_critic'])
        data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
        mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
        imputer_optimizer.load_state_dict(checkpoint['imputer_opt'])
        data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
        mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
        impu_critic_optimizer.load_state_dict(checkpoint['impu_critic_opt'])
        start_epoch = checkpoint['epoch']
        critic_updates = checkpoint['critic_updates']
        log = checkpoint['log']
    elif args.pretrain:
        pretrain = torch.load(args.pretrain, map_location='cpu')
        data_gen.load_state_dict(pretrain['data_gen'])
        mask_gen.load_state_dict(pretrain['mask_gen'])
        data_critic.load_state_dict(pretrain['data_critic'])
        mask_critic.load_state_dict(pretrain['mask_critic'])
        if 'imputer' in pretrain:
            imputer.load_state_dict(pretrain['imputer'])
            impu_critic.load_state_dict(pretrain['impu_critic'])

    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)

    def save_model(path, epoch, critic_updates=0):
        torch.save(
            {
                'data_gen': data_gen.state_dict(),
                'mask_gen': mask_gen.state_dict(),
                'imputer': imputer.state_dict(),
                'data_critic': data_critic.state_dict(),
                'mask_critic': mask_critic.state_dict(),
                'impu_critic': impu_critic.state_dict(),
                'data_gen_opt': data_gen_optimizer.state_dict(),
                'mask_gen_opt': mask_gen_optimizer.state_dict(),
                'imputer_opt': imputer_optimizer.state_dict(),
                'data_critic_opt': data_critic_optimizer.state_dict(),
                'mask_critic_opt': mask_critic_optimizer.state_dict(),
                'impu_critic_opt': impu_critic_optimizer.state_dict(),
                'epoch': epoch + 1,
                'critic_updates': critic_updates,
                'log': log,
                'args': args,
            }, str(path))

    sns.set()
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        sum_data_loss, sum_mask_loss, sum_impu_loss = 0, 0, 0
        for real_data, real_mask, _, index in data_loader:
            # Assume real_data and real_mask have the same number of channels.
            # Could be modified to handle multi-channel images and
            # single-channel masks.
            real_mask = real_mask.float()[:, None]

            real_data = real_data.to(device)
            real_mask = real_mask.to(device)

            masked_real_data = mask_data(real_data, real_mask, tau)

            # Update discriminators' parameters
            data_noise.normal_()
            fake_data = data_gen(data_noise)

            impu_noise.uniform_()
            imputed_data = imputer(real_data, real_mask, impu_noise)
            masked_imputed_data = mask_data(real_data, real_mask, imputed_data)

            if update_all_networks:
                mask_noise.normal_()
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)

                update_data_critic(masked_real_data, masked_fake_data)
                update_mask_critic(real_mask, fake_mask)

                sum_data_loss += update_data_critic.loss_value
                sum_mask_loss += update_mask_critic.loss_value

            update_impu_critic(fake_data, masked_imputed_data)
            sum_impu_loss += update_impu_critic.loss_value

            critic_updates += 1

            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                if update_all_networks:
                    for p in data_critic.parameters():
                        p.requires_grad_(False)
                    for p in mask_critic.parameters():
                        p.requires_grad_(False)
                for p in impu_critic.parameters():
                    p.requires_grad_(False)

                impu_noise.uniform_()
                imputed_data = imputer(real_data, real_mask, impu_noise)
                masked_imputed_data = mask_data(real_data, real_mask,
                                                imputed_data)
                impu_loss = -impu_critic(masked_imputed_data).mean()

                if update_all_networks:
                    data_noise.normal_()
                    fake_data = data_gen(data_noise)
                    mask_noise.normal_()
                    fake_mask = mask_gen(mask_noise)
                    masked_fake_data = mask_data(fake_data, fake_mask, tau)
                    data_loss = -data_critic(masked_fake_data).mean()
                    mask_loss = -mask_critic(fake_mask).mean()

                    mask_gen.zero_grad()
                    (mask_loss + data_loss * alpha).backward(retain_graph=True)
                    mask_gen_optimizer.step()

                    data_noise.normal_()
                    fake_data = data_gen(data_noise)
                    mask_noise.normal_()
                    fake_mask = mask_gen(mask_noise)
                    masked_fake_data = mask_data(fake_data, fake_mask, tau)
                    data_loss = -data_critic(masked_fake_data).mean()

                    data_gen.zero_grad()
                    (data_loss + impu_loss * beta).backward(retain_graph=True)
                    data_gen_optimizer.step()

                imputer.zero_grad()
                if gamma > 0:
                    imputer_mismatch_loss = mask_norm(
                        (imputed_data - real_data)**2, real_mask)
                    (impu_loss + imputer_mismatch_loss * gamma).backward()
                else:
                    impu_loss.backward()
                imputer_optimizer.step()

                if update_all_networks:
                    for p in data_critic.parameters():
                        p.requires_grad_(True)
                    for p in mask_critic.parameters():
                        p.requires_grad_(True)
                for p in impu_critic.parameters():
                    p.requires_grad_(True)

        if update_all_networks:
            mean_data_loss = sum_data_loss / n_batch
            mean_mask_loss = sum_mask_loss / n_batch
            log['data loss', 'data_loss'].append(mean_data_loss)
            log['mask loss', 'mask_loss'].append(mean_mask_loss)
        mean_impu_loss = sum_impu_loss / n_batch
        log['imputer loss', 'impu_loss'].append(mean_impu_loss)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if update_all_networks:
                print('[{:4}] {:12.4f} {:12.4f} {:12.4f}'.format(
                    epoch, mean_data_loss, mean_mask_loss, mean_impu_loss))
            else:
                print('[{:4}] {:12.4f}'.format(epoch, mean_impu_loss))

            filename = f'{epoch:04d}.png'
            with torch.no_grad():
                data_gen.eval()
                mask_gen.eval()
                imputer.eval()

                data_noise.normal_()
                mask_noise.normal_()

                data_samples = data_gen(data_noise)
                plot_samples(data_samples, str(gen_data_dir / filename))

                mask_samples = mask_gen(mask_noise)
                plot_samples(mask_samples, str(gen_mask_dir / filename))

                # Plot imputation results
                impu_noise.uniform_()
                imputed_data = imputer(real_data, real_mask, impu_noise)
                imputed_data = mask_data(real_data, real_mask, imputed_data)
                if hasattr(data, 'mask_info'):
                    bbox = [data.mask_info[idx] for idx in index]
                else:
                    bbox = None
                plot_grid(imputed_data, bbox, gap=2,
                          save_file=str(impute_dir / filename))

            data_gen.train()
            mask_gen.train()
            imputer.train()

        for (name, shortname), trace in log.items():
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.plot(trace)
            ax.set_ylabel(name)
            ax.set_xlabel('epoch')
            fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
            plt.close(fig)

        if save_model_interval > 0 and (epoch + 1) % save_model_interval == 0:
            save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)

        epoch_end = time.time()
        time_elapsed = epoch_end - start
        epoch_time = epoch_end - epoch_start
        epoch_start = epoch_end
        with (log_dir / 'epoch-time.txt').open('a') as f:
            print(epoch, epoch_time, time_elapsed, file=f)

    save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)

    print(output_dir)
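
# The training loops in this file drive their Wasserstein critics through
# CriticUpdater, which is defined elsewhere in the codebase. The class below is
# only an illustrative sketch of the standard WGAN-GP critic step it is assumed
# to perform with the pre-allocated `eps`, `ones`, and `gp_lambda` arguments;
# the real implementation may differ in details.
class CriticUpdaterSketch:
    def __init__(self, critic, optimizer, eps, ones, gp_lambda):
        self.critic = critic
        self.optimizer = optimizer
        self.eps = eps              # reused buffer of interpolation coefficients
        self.ones = ones            # grad_outputs for the per-sample critic scores
        self.gp_lambda = gp_lambda
        self.loss_value = 0.0

    def __call__(self, real, fake):
        # Interpolate between real and fake samples for the gradient penalty.
        self.eps.uniform_(0, 1)
        interp = (self.eps * real.detach()
                  + (1 - self.eps) * fake.detach()).requires_grad_()
        grad = torch.autograd.grad(self.critic(interp).view(-1), interp,
                                   grad_outputs=self.ones,
                                   create_graph=True)[0]
        grad_penalty = ((grad.reshape(grad.size(0), -1).norm(2, dim=1) - 1)**2).mean()

        # Standard WGAN critic loss plus the gradient penalty.
        loss = (self.critic(fake.detach()).mean()
                - self.critic(real.detach()).mean()
                + self.gp_lambda * grad_penalty)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.loss_value = loss.item()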
    param_str = "SVM"

# --------- CVPR model ----------
elif model_type in ["tCNN", "ED-TCN", "DilatedTCN", "TDNN", "LSTM"]:
    # Go from y_t = {1...C} to one-hot vector (e.g. y_t = [0, 0, 1, 0])
    Y_train = [np_utils.to_categorical(y, n_classes) for y in y_train]
    Y_test = [np_utils.to_categorical(y, n_classes) for y in y_test]

    # In order to process batches simultaneously, all sequences need to be
    # the same length, so pad them to max_len and mask out the padded ends.
    n_layers = len(n_nodes)
    max_len = max(np.max(train_lengths), np.max(test_lengths))
    max_len = int(np.ceil(max_len / (2**n_layers))) * 2**n_layers
    print("Max length:", max_len)

    X_train_m, Y_train_, M_train = utils.mask_data(X_train, Y_train, max_len,
                                                   mask_value=-1)
    X_test_m, Y_test_, M_test = utils.mask_data(X_test, Y_test, max_len,
                                                mask_value=-1)

    if model_type == "tCNN":
        model, param_str = tf_models.temporal_convs_linear(
            n_nodes[0], conv, n_classes, n_feat, max_len,
            causal=causal, return_param_str=True)
    elif model_type == "ED-TCN":
        model, param_str = tf_models.ED_TCN(
            n_nodes, conv, n_classes, n_feat, max_len,
            causal=causal, activation='norm_relu', return_param_str=True)
        # model, param_str = tf_models.ED_TCN_atrous(
        #     n_nodes, conv, n_classes, n_feat, max_len,
        #     causal=causal, activation='norm_relu', return_param_str=True)
    elif model_type == "TDNN":
        model, param_str = tf_models.TimeDelayNeuralNetwork(
            n_nodes, conv, n_classes, n_feat, max_len,
            causal=causal, activation='tanh', return_param_str=True)
    elif model_type == "DilatedTCN":
        model, param_str = tf_models.Dilated_TCN(
            n_feat, n_classes, n_nodes[0], L, B, max_len=max_len,
n_classes = len(actions_dict)
train_lengths = [x.shape[0] for x in X_train]
test_lengths = [x.shape[0] for x in X_test]
n_train = len(X_train)
n_test = len(X_test)
n_feat = 400

n_layers = len(n_nodes)
max_len = max(np.max(train_lengths), np.max(test_lengths))
max_len = int(np.ceil(max_len / (2**n_layers))) * 2**n_layers
print("Max length:", max_len)
# pdb.set_trace()

X_train_m, Y_train_, M_train = utils.mask_data(X_train, y_train, max_len,
                                               mask_value=0)
X_test_m, M_test = utils.mask_test_data(X_test, max_len, mask_value=0)

model, param_str = tf_models.ED_TCN(n_nodes, conv, n_classes, n_feat, max_len,
                                    causal=causal, activation='norm_relu',
                                    return_param_str=True)
# model, param_str = tf_models.temporal_convs_linear(
#     n_nodes[0], conv, n_classes, n_feat, max_len,
#     causal=causal, return_param_str=True)
# elif model_type == "ED-TCN":
#     model, param_str = tf_models.ED_TCN(
#         n_nodes, conv, n_classes, n_feat, max_len,
#         causal=causal, activation='norm_relu', return_param_str=True)
#     # model, param_str = tf_models.ED_TCN_atrous(
#     #     n_nodes, conv, n_classes, n_feat, max_len,
#     #     causal=causal, activation='norm_relu', return_param_str=True)
# elif model_type == "TDNN":
#     model, param_str = tf_models.TimeDelayNeuralNetwork(
#         n_nodes, conv, n_classes, n_feat, max_len,
#         causal=causal, activation='tanh', return_param_str=True)
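
# utils.mask_data and utils.mask_test_data come from the TCN codebase and are
# not shown here. The helper below is only a rough sketch of the padding
# behaviour they are assumed to implement: right-pad every (T_i, d) sequence to
# max_len with mask_value and return a mask M marking which frames are real.
# The name and the exact return layout are assumptions, not the library's API.
def pad_and_mask_sketch(X, max_len, mask_value=0):
    n_seq, n_dim = len(X), X[0].shape[1]
    X_padded = np.full((n_seq, max_len, n_dim), float(mask_value))
    M = np.zeros((n_seq, max_len, 1))
    for i, x in enumerate(X):
        X_padded[i, :x.shape[0]] = x
        M[i, :x.shape[0]] = 1          # 1 = real frame, 0 = padding
    return X_padded, M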
                    type=int, default=10, required=False, help="Random seed")
args = parser.parse_args()

if __name__ == '__main__':
    print('Read and preprocess')
    X_train, X_test, y_train, y_test = read_preprocess(args.dataset)

    print('Train Model')
    model = train_model(args.dataset, X_train, y_train, _seed=args.seed)

    print('Mask data')
    X_test_masked, mask_data_filling = mask_data(X_test, X_train.shape[1],
                                                 args.seed)

    if args.method == 'TNBQ':
        print('The Next Best Question')
        K = int(args.parameters[0])
        res_lst = the_next_best_question(X_train, X_test_masked, y_test, model,
                                         X_train.shape[1], mask_data_filling, K)
        save_results(res_lst, "The Next Best Question", args.parameters,
                     args.seed, args.dataset.title())

    if args.method.lower() == 'global':
        print('Global Feature Importance Acquisition')
        res_lst = global_shap_baseline(X_train, X_test_masked, y_test, model,
                                       X_train.shape[1], mask_data_filling)
        save_results(res_lst, "Global Feature Importance Acquisition",
def misgan(args, data_gen, mask_gen, data_critic, mask_critic, data,
           output_dir, checkpoint=None):
    n_critic = args.n_critic
    gp_lambda = args.gp_lambda
    batch_size = args.batch_size
    nz = args.n_latent
    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval
    alpha = args.alpha
    tau = args.tau

    gen_data_dir = mkdir(output_dir / 'img')
    gen_mask_dir = mkdir(output_dir / 'mask')
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,
                             drop_last=True)
    n_batch = len(data_loader)

    data_noise = torch.FloatTensor(batch_size, nz).to(device)
    mask_noise = torch.FloatTensor(batch_size, nz).to(device)

    # Interpolation coefficient
    eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)

    # For computing gradient penalty
    ones = torch.ones(batch_size).to(device)

    lrate = 1e-4
    # lrate = 1e-5
    data_gen_optimizer = optim.Adam(
        data_gen.parameters(), lr=lrate, betas=(.5, .9))
    mask_gen_optimizer = optim.Adam(
        mask_gen.parameters(), lr=lrate, betas=(.5, .9))
    data_critic_optimizer = optim.Adam(
        data_critic.parameters(), lr=lrate, betas=(.5, .9))
    mask_critic_optimizer = optim.Adam(
        mask_critic.parameters(), lr=lrate, betas=(.5, .9))

    update_data_critic = CriticUpdater(
        data_critic, data_critic_optimizer, eps, ones, gp_lambda)
    update_mask_critic = CriticUpdater(
        mask_critic, mask_critic_optimizer, eps, ones, gp_lambda)

    start_epoch = 0
    critic_updates = 0
    log = defaultdict(list)

    if checkpoint:
        data_gen.load_state_dict(checkpoint['data_gen'])
        mask_gen.load_state_dict(checkpoint['mask_gen'])
        data_critic.load_state_dict(checkpoint['data_critic'])
        mask_critic.load_state_dict(checkpoint['mask_critic'])
        data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
        mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
        data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
        mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
        start_epoch = checkpoint['epoch']
        critic_updates = checkpoint['critic_updates']
        log = checkpoint['log']

    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)

    def save_model(path, epoch, critic_updates=0):
        torch.save(
            {
                'data_gen': data_gen.state_dict(),
                'mask_gen': mask_gen.state_dict(),
                'data_critic': data_critic.state_dict(),
                'mask_critic': mask_critic.state_dict(),
                'data_gen_opt': data_gen_optimizer.state_dict(),
                'mask_gen_opt': mask_gen_optimizer.state_dict(),
                'data_critic_opt': data_critic_optimizer.state_dict(),
                'mask_critic_opt': mask_critic_optimizer.state_dict(),
                'epoch': epoch + 1,
                'critic_updates': critic_updates,
                'log': log,
                'args': args,
            }, str(path))

    sns.set()
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        sum_data_loss, sum_mask_loss = 0, 0
        for real_data, real_mask, _, _ in data_loader:
            # Assume real_data and real_mask have the same number of channels.
            # Could be modified to handle multi-channel images and
            # single-channel masks.
            real_mask = real_mask.float()[:, None]

            real_data = real_data.to(device)
            real_mask = real_mask.to(device)

            masked_real_data = mask_data(real_data, real_mask, tau)

            # Update discriminators' parameters
            data_noise.normal_()
            mask_noise.normal_()

            fake_data = data_gen(data_noise)
            fake_mask = mask_gen(mask_noise)
            masked_fake_data = mask_data(fake_data, fake_mask, tau)

            update_data_critic(masked_real_data, masked_fake_data)
            update_mask_critic(real_mask, fake_mask)

            sum_data_loss += update_data_critic.loss_value
            sum_mask_loss += update_mask_critic.loss_value

            critic_updates += 1

            if critic_updates == n_critic:
                critic_updates = 0

                # Update generators' parameters
                for p in data_critic.parameters():
                    p.requires_grad_(False)
                for p in mask_critic.parameters():
                    p.requires_grad_(False)

                data_noise.normal_()
                mask_noise.normal_()

                fake_data = data_gen(data_noise)
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)
                data_loss = -data_critic(masked_fake_data).mean()

                data_gen.zero_grad()
                data_loss.backward()
                data_gen_optimizer.step()

                data_noise.normal_()
                mask_noise.normal_()

                fake_data = data_gen(data_noise)
                fake_mask = mask_gen(mask_noise)
                masked_fake_data = mask_data(fake_data, fake_mask, tau)
                data_loss = -data_critic(masked_fake_data).mean()
                mask_loss = -mask_critic(fake_mask).mean()

                mask_gen.zero_grad()
                (mask_loss + data_loss * alpha).backward()
                mask_gen_optimizer.step()

                for p in data_critic.parameters():
                    p.requires_grad_(True)
                for p in mask_critic.parameters():
                    p.requires_grad_(True)

        mean_data_loss = sum_data_loss / n_batch
        mean_mask_loss = sum_mask_loss / n_batch
        log['data loss', 'data_loss'].append(mean_data_loss)
        log['mask loss', 'mask_loss'].append(mean_mask_loss)

        for (name, shortname), trace in log.items():
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.plot(trace)
            ax.set_ylabel(name)
            ax.set_xlabel('epoch')
            fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
            plt.close(fig)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            print(f'[{epoch:4}] {mean_data_loss:12.4f} {mean_mask_loss:12.4f}')

            filename = f'{epoch:04d}.png'

            data_gen.eval()
            mask_gen.eval()

            with torch.no_grad():
                data_noise.normal_()
                mask_noise.normal_()

                data_samples = data_gen(data_noise)
                plot_samples(data_samples, str(gen_data_dir / filename))

                mask_samples = mask_gen(mask_noise)
                plot_samples(mask_samples, str(gen_mask_dir / filename))

            data_gen.train()
            mask_gen.train()

        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)

        epoch_end = time.time()
        time_elapsed = epoch_end - start
        epoch_time = epoch_end - epoch_start
        epoch_start = epoch_end
        with (log_dir / 'time.txt').open('a') as f:
            print(epoch, epoch_time, time_elapsed, file=f)

    save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)

    print(output_dir)
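
# mask_data above is imported from elsewhere in this codebase. The one-liner
# below sketches what it is assumed to compute, following the MisGAN-style
# masking operator: observed entries are kept and missing entries are filled
# either with the constant tau or with per-entry values such as the imputer
# output. The name is hypothetical; the real helper may differ.
def mask_data_sketch(data, mask, tau_or_fill):
    return mask * data + (1 - mask) * tau_or_fill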
def get_bold_for_condition(dir_input, num_runs, option_zscore=0):
    """Extract the BOLD signal for three conditions.

    option_zscore = 0 => no z-scoring
    option_zscore = 1 => z-score the data

    Returns the BOLD values for each condition across all runs, plus the mean
    signal of every run.
    """
    from utils import shift_timing, mask_data, scale_data

    # Initialize arrays
    stim_label = []
    bold_A = []
    bold_B = []
    bold_C = []
    bold_fix = []
    bold_mean_all = []

    TR_shift_size = 2  # Number of TRs to shift the extraction of the BOLD signal.
    maskdir = dir_input
    masks = ['ROI_Cool']

    # Extract the BOLD signal for the conditions A, B, C
    print("Processing Start ...")
    maskfile = (maskdir + "%s.nii.gz" % (masks[0]))
    mask = nib.load(maskfile)
    print("Loaded Mask")
    print(mask.shape)

    for run in range(1, num_runs + 1):
        epi_in = (dir_input + "lab1_r0%s.nii.gz" % (run))
        stim_label = np.load(dir_input + 'labels_r0%s.npy' % (run))

        # Haemodynamic shift
        label_TR_shifted = shift_timing(stim_label, TR_shift_size)

        # Get labels for conditions A, B, C, and baseline fixation.
        A = np.squeeze(np.argwhere(label_TR_shifted == 1))
        B = np.squeeze(np.argwhere(label_TR_shifted == 2))
        C = np.squeeze(np.argwhere(label_TR_shifted == 3))
        fixation = np.squeeze(np.argwhere(label_TR_shifted == 0))

        epi_data = nib.load(epi_in)
        epi_mask_data = mask_data(epi_data, mask)

        if option_zscore == 1:
            epi_maskdata_zscore = scale_data(epi_mask_data)
            epi_mask_data = epi_maskdata_zscore

        if run == 1:
            bold_A = epi_mask_data[A]
            bold_B = epi_mask_data[B]
            bold_C = epi_mask_data[C]
            bold_fix = epi_mask_data[fixation]
            bold_data_all = epi_mask_data
        else:
            bold_A = np.vstack([bold_A, epi_mask_data[A]])
            bold_B = np.vstack([bold_B, epi_mask_data[B]])
            bold_C = np.vstack([bold_C, epi_mask_data[C]])
            bold_fix = np.vstack([bold_fix, epi_mask_data[fixation]])
            bold_data_all = np.vstack([bold_data_all, epi_mask_data])

        bold_mean_all.append(np.mean(epi_mask_data))

    print("Processing Completed")
    return (bold_data_all, bold_mean_all, bold_A, bold_B, bold_C, bold_fix,
            label_TR_shifted)
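
# shift_timing, mask_data, and scale_data are imported from a local utils module
# that is not shown. The sketches below illustrate what they are assumed to do
# (the real helpers may differ): shift_timing delays the label vector by
# TR_shift_size TRs to account for hemodynamic lag, and mask_data extracts the
# in-ROI voxel time series as a (TRs x voxels) array.
def shift_timing_sketch(label_TR, TR_shift_size):
    shifted = np.concatenate((np.zeros(TR_shift_size), label_TR))
    return shifted[:len(label_TR)]           # keep the original number of TRs

def mask_data_sketch(epi_img, mask_img):
    from nilearn.input_data import NiftiMasker
    masker = NiftiMasker(mask_img=mask_img)   # assumes a binary ROI mask
    return masker.fit_transform(epi_img)      # shape: (TRs, voxels in ROI)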
def Train(train_loader, model_name, optimizer, node_num, rho_max, epochs, lr,
          batch_size, tau, n_critic, lambda_sparse, lambda_mmd, h_loss, rho,
          alpha):
    # --------------- network selection -------------------
    Gen = Linear_Generator(node_num)
    Disc = Discriminator(node_num, model_name)

    # ----------------- optimizer -------------------------
    if optimizer == 'ADAM':
        Gen_optimizer = torch.optim.Adam(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.Adam(Disc.parameters(), lr=lr)
    elif optimizer == 'RMS':
        Gen_optimizer = torch.optim.RMSprop(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.RMSprop(Disc.parameters(), lr=lr)
    elif optimizer == 'LBFGS':
        Gen_optimizer = torch.optim.LBFGS(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.LBFGS(Disc.parameters(), lr=lr)
    elif optimizer == 'SGD':
        Gen_optimizer = torch.optim.SGD(Gen.parameters(), lr=lr)
        Disc_optimizer = torch.optim.SGD(Disc.parameters(), lr=lr)
    else:
        raise NotImplementedError

    if model_name == 'GAN':
        BCE_loss = nn.BCELoss()
        real_label = Variable(torch.Tensor(batch_size, 1).fill_(1.0),
                              requires_grad=False)
        fake_label = Variable(torch.Tensor(batch_size, 1).fill_(0.0),
                              requires_grad=False)

    critic_update = 0
    sigma_list = [1, 2, 4, 8, 16]

    while rho < rho_max:
        for _ in range(epochs):
            for sample, mask in train_loader:
                masked_real_sample = mask_data(sample, mask, tau)

                input_noise = Variable(torch.Tensor(
                    np.random.normal(0, 1, (batch_size, node_num))))
                generated_sample = Gen(input_noise)
                h_matrix = Gen.adj_A.clone()

                shuffle_mask = mask[torch.randperm(batch_size)]
                masked_generated_sample = mask_data(generated_sample,
                                                    shuffle_mask, tau)
                masked_fake_sample = masked_generated_sample.detach()

                Disc_optimizer.zero_grad()
                if model_name == 'GAN':
                    Real_sample_loss = BCE_loss(Disc(masked_real_sample),
                                                real_label)
                    Fake_sample_loss = BCE_loss(Disc(masked_fake_sample),
                                                fake_label)
                    Disc_loss = Real_sample_loss + Fake_sample_loss
                elif model_name == 'WGAN':
                    Disc_loss = (-torch.mean(Disc(masked_real_sample))
                                 + torch.mean(Disc(masked_fake_sample)))
                else:
                    raise NotImplementedError
                Disc_loss.backward()
                Disc_optimizer.step()

                critic_update += 1
                if critic_update == n_critic:
                    critic_update = 0
                    for para in Disc.parameters():
                        para.requires_grad_(False)

                    Gen_optimizer.zero_grad()
                    if model_name == 'GAN':
                        Adv_Gen_loss = BCE_loss(Disc(masked_generated_sample),
                                                real_label)
                    elif model_name == 'WGAN':
                        Adv_Gen_loss = -torch.mean(Disc(masked_generated_sample))
                    else:
                        raise NotImplementedError

                    mmd_loss = lambda_mmd * F.relu(mix_rbf_mmd2(
                        masked_generated_sample, masked_real_sample, sigma_list))
                    h_loss_temp = acyclic_loss(h_matrix, node_num)
                    sparse_loss = lambda_sparse * torch.sum(torch.abs(h_matrix))
                    penalty_loss = (alpha * h_loss_temp
                                    + 0.5 * rho * h_loss_temp * h_loss_temp)

                    Total_Gen_loss = (Adv_Gen_loss + mmd_loss + penalty_loss
                                      + sparse_loss)
                    Total_Gen_loss.backward()
                    Gen_optimizer.step()

                    for para in Disc.parameters():
                        para.requires_grad_(True)

        h_matrix_d = Gen.adj_A.detach()
        h_new = acyclic_loss(h_matrix_d, node_num)
        print('------- inner optimization (300 epochs) finished ---------')
        print('The current h_loss is', h_new, 'Rho is', rho)
        print('weight max/min', torch.max(h_matrix_d), torch.min(h_matrix_d))
        scale = Gen.scale.detach()
        print('scale_abs max', torch.max(abs(scale)),
              'scale_abs min', torch.min(abs(scale)))

        if h_new > 0.25 * h_loss:
            rho *= 10
        else:
            break

    alpha += rho * h_new
    return rho, alpha, h_new, h_matrix_d
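
# acyclic_loss, mix_rbf_mmd2, Linear_Generator, and Discriminator are defined
# elsewhere in this project. The function below is only a sketch of the
# NOTEARS-style acyclicity constraint that acyclic_loss is assumed to compute:
# h(A) = tr(exp(A ∘ A)) - d, which equals zero exactly when the weighted
# adjacency matrix A corresponds to a DAG.
def acyclic_loss_sketch(adj, node_num):
    # The Hadamard square keeps the constraint differentiable and non-negative.
    return torch.trace(torch.matrix_exp(adj * adj)) - node_num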