def train_gan(generator, discriminator, image_loader, num_epochs, batch_size,
              cuda=True, g_lr=1e-3, d_lr=1e-3, filename_prefix="results",
              save_gen_images=False):
    """Train a GAN pair with a WGAN-style objective.

    Args:
        generator, discriminator: torch modules to train.
        image_loader: iterable yielding (images, labels) batches.
        num_epochs: number of passes over ``image_loader``.
        batch_size: expected batch size; ragged final batches are skipped.
        cuda: move models and tensors to the GPU when True.
        g_lr, d_lr: learning rates for the G and D optimizers.
        filename_prefix: prefix used when saving sample images.
        save_gen_images: when True, dump generator samples every iteration.

    Returns:
        (discriminator, generator) after training.
    """
    if cuda:
        dtype = torch.cuda.FloatTensor
        generator.cuda()
        discriminator.cuda()
    else:
        dtype = torch.FloatTensor
    iters = 0
    d_optimizer = create_optimizer(discriminator, lr=d_lr, betas=(.5, .999))
    g_optimizer = create_optimizer(generator, lr=g_lr, betas=(.5, .999))
    for epoch in range(num_epochs):
        for x, _ in image_loader:
            # Skip ragged final batches so noise/label shapes always match.
            if x.shape[0] != batch_size:
                continue
            real_data = x.type(dtype)

            # ---- Generator step: maximize D(G(z)) ----
            g_optimizer.zero_grad()
            z = generate_noise(batch_size).type(dtype)
            fake_images = generator(z)
            g_result = discriminator(fake_images).squeeze()
            # BUG FIX: the original minimized +mean(D(G(z))), which drives G
            # *away* from fooling the critic. With the critic loss below
            # (maximize D(real) - D(fake)), the generator loss is -mean(D(G(z))).
            g_cost = -torch.mean(g_result)
            g_cost.backward()
            g_optimizer.step()

            # ---- Critic step: maximize D(real) - D(fake) ----
            # (The original also computed BCE losses here but never used them;
            # that dead computation is removed.)
            d_optimizer.zero_grad()
            z = generate_noise(batch_size).type(dtype)
            fake_images = generator(z)
            # detach(): do not backprop the critic loss into the generator.
            d_spred_fake = discriminator(fake_images.detach()).squeeze()
            d_spred_real = discriminator(real_data).squeeze()
            d_cost = -torch.mean(d_spred_real - d_spred_fake)
            d_cost.backward()
            d_optimizer.step()
            iters += 1

            # BUG FIX: the original tested `if save_images:` — that is the
            # helper *function* (always truthy), not the boolean parameter.
            if save_gen_images:
                save_images(generator, epoch, iters, filename_prefix)
            print("Epoch", epoch, "Iter", iters)
            print("d_cost", d_cost)
            print("g_cost", g_cost)
            print("Inception Score", get_inception_score(generator))
    return discriminator, generator
def calc_losses_sagan(batch_real, generator, discriminator, optim_generator, optim_discr, device, noise_size):
    """Run one SAGAN training step (discriminator then generator).

    Calculates losses according to https://arxiv.org/abs/1805.08318 using the
    hinge formulation for the discriminator and the mean critic output for
    the generator.

    Args:
        batch_real: batch of real images (first dim is the batch size).
        generator, discriminator: the two networks being trained.
        optim_generator, optim_discr: their respective optimizers.
        device: torch device for noise/generated tensors.
        noise_size: latent dimension passed to utils.generate_noise.

    Returns:
        (generator_loss, discriminator_loss, real_accuracy_pct, fake_accuracy_pct)
    """
    # --- Discriminator step: hinge loss on real and fake batches ---
    discriminator.train()
    generator.train()
    optim_discr.zero_grad()
    noise_for_discr = utils.generate_noise(batch_real.size()[0], noise_size).to(device)
    # NOTE(review): fake_for_discr is NOT detached, so total_discr_loss.backward()
    # also populates generator gradients. Harmless only because
    # optim_generator.zero_grad() runs before the generator step — confirm.
    fake_for_discr = generator(noise_for_discr).to(device)
    discr_fake_pred = discriminator(fake_for_discr)
    # Hinge loss: fakes are pushed below -1, reals above +1.
    disc_fake_loss = relu(1.0 + discr_fake_pred).mean()
    disc_real_pred = discriminator(batch_real)
    disc_real_loss = relu(1.0 - disc_real_pred).mean()
    # Total discriminator loss is the loss on fake + loss on real.
    total_discr_loss = disc_fake_loss + disc_real_loss
    total_discr_loss.backward()
    optim_discr.step()

    # --- Generator step ---
    generator.train()
    #discriminator.eval()
    optim_generator.zero_grad()
    noise_gener = utils.generate_noise(batch_real.size()[0], noise_size).to(device)
    fake = generator(noise_gener).to(device)
    discr_fake_pred_for_gener = discriminator(fake)
    # Maximize the discriminator's mean score on fakes; multiply by -1 so
    # that gradient descent performs gradient ascent.
    gener_loss = -discr_fake_pred_for_gener.mean()
    # gener_loss = criterion(discr_fake_pred_for_gener, torch.ones_like(discr_fake_pred_for_gener))
    gener_loss.backward()
    optim_generator.step()

    # --- Discriminator accuracies (sign of the hinge score as the decision) ---
    batch_correct_discr_real = (disc_real_pred.detach().cpu() > 0).float().numpy().sum()
    batch_correct_discr_fake = (discr_fake_pred.detach().cpu() < 0).float().numpy().sum()
    batch_accuracy_real = batch_correct_discr_real * 100 / batch_real.size(0)
    batch_accuracy_fake = batch_correct_discr_fake * 100 / batch_real.size(0)
    return gener_loss.item(), total_discr_loss.item(), batch_accuracy_real, batch_accuracy_fake
def train_one_iter_D(self, images):
    """Run one discriminator update on a batch of real images.

    Computes the adversarial loss on reals (target 1) and on generated
    images (target 0), backprops their sum, steps optimizerD, and updates
    the running D(x) / D(G(z)) / loss meters.

    Args:
        images: batch of real images already on self.device.
    """
    self.optimizerD.zero_grad()
    # get noise (z)
    noise = generate_noise(self.batch_size, self.z_dim, type='normal')
    noise = noise.to(self.device)
    # Targets: ones for real images, zeros for generated ones.
    real_label = torch.ones(self.batch_size).to(self.device)
    fake_label = torch.zeros(self.batch_size).to(self.device)
    # forward and backward
    fake_images = self.generator(noise)
    # BUG FIX: detach the generated images so the discriminator loss does
    # not backprop into (and waste a backward pass through) the generator.
    fake_pred, _ = self.discriminator(fake_images.detach())
    fake_loss = self.adv_loss(fake_pred, fake_label)
    real_pred, _ = self.discriminator(images)
    real_loss = self.adv_loss(real_pred, real_label)
    # accumulate gradient and update
    D_loss = fake_loss + real_loss
    D_loss.backward()
    self.optimizerD.step()
    # Logging: mean discriminator outputs on the real and fake batches.
    D_x = real_pred.mean(0).item()
    D_G_z = fake_pred.mean(0).item()
    self.D_x.update(D_x, self.batch_size)
    self.D_G_z.update(D_G_z, self.batch_size)
    self.D_loss.update(D_loss, self.batch_size)
def __init__(self, loss_type, netD, netG, device, train_ds, val_ds, lr_D = 0.0002, lr_G = 0.0002, rec_weight = 10, ds_weight = 8, use_rec_feature = False, resample = True, weight_clip = None, use_gradient_penalty = False, loss_interval = 50, image_interval = 50, save_img_dir = 'saved_images/'):
    """Set up the DSGAN trainer: networks, losses, optimizers, output dirs."""
    # Networks and datasets.
    self.netD, self.netG = netD, netG
    self.train_ds, self.val_ds = train_ds, val_ds
    self.lr_D, self.lr_G = lr_D, lr_G
    self.device = device
    # Training options.
    self.resample = resample
    self.weight_clip = weight_clip
    self.use_gradient_penalty = use_gradient_penalty
    self.rec_weight = rec_weight
    self.use_rec_feature = use_rec_feature
    self.ds_weight = ds_weight
    # Latent dimension comes from the generator; keep a small fixed probe
    # noise so progress images are comparable across iterations.
    self.nz = self.netG.nz
    self.fixed_noise = generate_noise(3, self.nz, self.device)
    # GAN loss selection plus the diversity-sensitive loss.
    self.loss_type = loss_type
    self.require_type = get_require_type(self.loss_type)
    self.loss = get_gan_loss(self.device, self.loss_type)
    self.ds_loss = DSGAN_Loss(self.device, self.nz)
    # Adam with betas (0, 0.9) — the usual WGAN-style configuration.
    self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (0, 0.9))
    self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (0, 0.9))
    # Reporting cadence and image output location.
    self.loss_interval = loss_interval
    self.image_interval = image_interval
    self.save_cnt = 0
    self.save_img_dir = save_img_dir
    if not os.path.exists(self.save_img_dir):
        os.makedirs(self.save_img_dir)
def forward(self, noise_init, noise_amp, mode='rand'):
    """Run the multi-scale generator body over an initial noise volume.

    Args:
        noise_init: initial latent noise tensor, padded with ``self.p3d``
            before the first block (assumes a 3-D spatial layout given the
            three trailing shape dims used below — TODO confirm).
        noise_amp: per-scale noise amplitudes, indexed by block position.
        mode: 'rand' injects fresh noise at every scale; any other value
            runs the deterministic "reconstruction" path with no noise.

    Returns:
        tanh-squashed output of the finest scale.
    """
    # Coarsest scale: pad the input noise and run the first block.
    x_prev_out = self.body[0](F.pad(noise_init, self.p3d))

    for idx, block in enumerate(self.body[1:], 1):
        x_prev_out = torch.tanh(x_prev_out)
        # Upscale
        x_prev_out_up = utils.upscale(x_prev_out, idx, self.opt)
        # Add noise if "random" sampling, else, add no noise is "reconstruction" mode
        if mode == 'rand':
            # Interpolate to the upscaled size plus a margin of
            # (num_layer + 2) * 2 per spatial dim — presumably the border
            # consumed by the block's unpadded convolutions; verify.
            x_prev_out_up_2 = utils.interpolate_3D(
                x_prev_out,
                size=[
                    x_prev_out_up.shape[-3] + (self.opt.num_layer + 2) * 2,
                    x_prev_out_up.shape[-2] + (self.opt.num_layer + 2) * 2,
                    x_prev_out_up.shape[-1] + (self.opt.num_layer + 2) * 2
                ])
            noise = utils.generate_noise(ref=x_prev_out_up_2)
            x_prev = block(x_prev_out_up_2 + noise * noise_amp[idx])
        else:
            x_prev = block(F.pad(x_prev_out_up, self.p3d))
        # Residual connection across scales.
        x_prev_out = x_prev + x_prev_out_up

    out = torch.tanh(x_prev_out)
    return out
def __init__(self, opts, data, weights):
    """Build the VAE graph inside a fresh TensorFlow session.

    Args:
        opts: configuration dict (batch size, architecture, ...).
        data: dataset object providing the training points.
        weights: per-point sampling weights (copied defensively).
    """
    # Create a new session with session.graph = default graph
    self._session = tf.Session()
    self._trained = False
    self._data = data
    # Copy so a caller mutating its weights array cannot change ours.
    self._data_weights = np.copy(weights)
    # Latent noise sampled ones to apply decoder while training
    self._noise_for_plots = utils.generate_noise(opts, 500)
    # Placeholders (filled in by _build_model_internal).
    self._real_points_ph = None
    self._noise_ph = None
    # Main operations
    # FIX
    self._loss = None
    self._loss_reconstruct = None
    self._loss_kl = None
    self._generated = None
    self._reconstruct_x = None
    # Optimizers
    self.optim = None
    with self._session.as_default(), self._session.graph.as_default():
        # NOTE(review): logged at ERROR level though informational —
        # presumably to survive aggressive log filtering; confirm.
        logging.error('Building the graph...')
        self._build_model_internal(opts)
        # Make sure AdamOptimizer, if used in the Graph, is defined before
        # calling global_variables_initializer().
        init = tf.global_variables_initializer()
        self._session.run(init)
def __init__(self, netD, netG, device, train_dl, lr_D = 0.0002, lr_G = 0.0002, beta1 = 0.5, loss_interval = 50, image_interval = 50, snapshot_interval = None, save_img_dir = 'saved_images/', save_snapshot_dir = 'saved_snapshots', resample = False):
    """Set up a DCGAN-style trainer: networks, Adam optimizers, output dirs."""
    # Networks and data loader.
    self.netD, self.netG = netD, netG
    self.train_dl = train_dl
    self.lr_D, self.lr_G = lr_D, lr_G
    self.train_iteration_per_epoch = len(self.train_dl)
    self.device = device
    self.resample = resample
    self.special = None
    # One Adam optimizer per network, sharing beta1.
    self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (beta1, 0.999))
    self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (beta1, 0.999))
    # Label conventions for the BCE targets.
    self.real_label = 1
    self.fake_label = 0
    # Fixed probe noise so progress images are comparable across epochs.
    self.nz = self.netG.nz
    self.fixed_noise = generate_noise(16, self.nz, self.device)
    # Reporting cadence and loss history.
    self.loss_interval = loss_interval
    self.image_interval = image_interval
    self.snapshot_interval = snapshot_interval
    self.errD_records = []
    self.errG_records = []
    # Output locations, created on demand.
    self.save_cnt = 0
    self.save_img_dir = save_img_dir
    self.save_snapshot_dir = save_snapshot_dir
    for folder in (self.save_img_dir, self.save_snapshot_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)
def train_one_iter_G(self, images=None):
    """Run one generator update.

    Supports two objectives selected by ``self.loss_type``:
    'adversarial' uses the non-saturating loss (minimize -log(D(G(z)))),
    'feature_matching' matches discriminator feature statistics between
    fake and real batches.

    Args:
        images: batch of real images; required only for 'feature_matching'.

    Raises:
        ValueError: for loss_type 'both' (unimplemented) or any other
            unrecognized value.
    """
    self.optimizerG.zero_grad()
    # get noise (z)
    noise = generate_noise(self.batch_size, self.z_dim, type='normal')
    noise = noise.to(self.device)
    # get real labels ( modified loss, minimize -log(D(G(z)) )
    real_label = torch.ones(self.batch_size).to(self.device)
    fake_images = self.generator(noise)
    if self.loss_type == 'adversarial':
        # vanilla adversarial loss
        fake_pred, _ = self.discriminator(fake_images)
        G_loss = self.adv_loss(fake_pred, real_label)
        G_loss.backward()
    elif self.loss_type == 'feature_matching':
        _, fmaps_fake = self.discriminator(fake_images)
        _, fmaps_real = self.discriminator(images)
        # get feature map statistics of last layer
        fmap_fake = fmaps_fake[-1].mean(0)
        fmap_real = fmaps_real[-1].mean(0).detach()  # treat as constant
        G_loss = self.fm_loss(fmap_fake, fmap_real)
        G_loss.backward()
    elif self.loss_type == 'both':
        # BUG FIX: corrected the "los_type" typo in the error message.
        raise ValueError("loss_type = both is not yet implemented")
    else:
        raise ValueError("invalid loss_type.")
    self.optimizerG.step()
    self.G_loss.update(G_loss, self.batch_size)
def _sample_internal(self, opts, num):
    """Sample from the trained GAN model.

    Args:
        opts: configuration dict used throughout the model.
        num: number of samples to draw.

    Returns:
        Generated samples obtained by feeding fresh noise through the
        generator op via a batched session run (training flag off).
    """
    noise = utils.generate_noise(opts, num)
    sample = self._run_batch(opts, self._generated, self._noise_ph, noise, self._is_training_ph, False)
    return sample
def __init__(self, netD, netG, device, train_dl, lr_D=0.0002, lr_G=0.0002, n_critic=5, lambd=10, loss_interval=50, image_interval=50, snapshot_interval=None, save_img_dir='saved_images/', save_snapshot_dir='saved_snapshots', resample=None):
    """Set up a WGAN-GP trainer.

    Args:
        netD, netG: critic and generator networks.
        device: torch device used for noise and training tensors.
        train_dl: training DataLoader.
        lr_D, lr_G: Adam learning rates (betas fixed at (0, 0.9)).
        n_critic: critic updates per generator update.
        lambd: gradient-penalty coefficient.
        loss_interval, image_interval, snapshot_interval: reporting cadence.
        save_img_dir, save_snapshot_dir: output folders (created if missing).
        resample: must stay None — WGAN-GP always resamples noise itself.
    """
    self.netD = netD
    self.netG = netG
    self.train_dl = train_dl
    self.lr_D = lr_D
    self.lr_G = lr_G
    self.n_critic = n_critic
    self.lambd = lambd
    self.train_iteration_per_epoch = len(self.train_dl)
    self.device = device
    self.resample = resample
    self.special = None
    self.optimizerD = optim.Adam(self.netD.parameters(), lr=self.lr_D, betas=(0, 0.9))
    self.optimizerG = optim.Adam(self.netG.parameters(), lr=self.lr_G, betas=(0, 0.9))
    self.real_label = 1
    self.fake_label = 0
    self.nz = self.netG.nz
    # Fixed probe noise so progress images are comparable across epochs.
    self.fixed_noise = generate_noise(16, self.nz, self.device)
    self.loss_interval = loss_interval
    self.image_interval = image_interval
    self.snapshot_interval = snapshot_interval
    self.errD_records = []
    self.errG_records = []
    self.w_dist_records = []
    self.save_cnt = 0
    self.save_img_dir = save_img_dir
    self.save_snapshot_dir = save_snapshot_dir
    if (not os.path.exists(self.save_img_dir)):
        os.makedirs(self.save_img_dir)
    if (not os.path.exists(self.save_snapshot_dir)):
        os.makedirs(self.save_snapshot_dir)
    # BUG FIX: the original asserted `self.resample is not None`, which
    # *rejected* the default (None) and *accepted* an explicitly-passed
    # value — the exact opposite of what the message states.
    assert (
        self.resample is None
    ), "Resample parameter is unnecessary for wgan_gp because it already resamples by default."
def draw_concat(Gs, Zs, reals, NoiseAmp, in_s, mode, m_noise, m_image, opt):
    """Propagate an image up the SinGAN pyramid of already-trained generators.

    Args:
        Gs: list of trained generators, coarsest first.
        Zs: per-scale optimal noise maps (used in 'rec' mode).
        reals: per-scale real images (define each scale's spatial size).
        NoiseAmp: per-scale noise amplitudes.
        in_s: starting image at the coarsest scale.
        mode: 'rand' draws fresh noise per scale; 'rec' reuses the stored
            Z_opt maps for reconstruction.
        m_noise, m_image: padding modules for noise and images.
        opt: options namespace (ker_size, num_layer, nc_z, scale_factor, ...).

    Returns:
        The image produced at the second-finest scale, ready to feed the
        current (finest) generator.
    """
    G_z = in_s
    if len(Gs) > 0:
        if mode == 'rand':
            count = 0
            # Padding consumed by each generator's unpadded conv stack.
            pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
            if opt.mode == 'animation_train':
                pad_noise = 0
            for G, Z_opt, real_curr, real_next, noise_amp in zip(Gs, Zs, reals, reals[1:], NoiseAmp):
                if count == 0:
                    # Coarsest scale: single-channel noise expanded to RGB so
                    # all color channels share the same noise map.
                    z = functions.generate_noise([1, Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise], device=opt.device)
                    z = z.expand(1, 3, z.shape[2], z.shape[3])
                else:
                    z = functions.generate_noise([opt.nc_z, Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise], device=opt.device)
                z = m_noise(z)
                # Crop to the current scale's size, pad, add scaled noise,
                # run the generator, then upsample for the next scale.
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                z_in = noise_amp * z + G_z
                G_z = G(z_in.detach(), G_z)
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                count += 1
        if mode == 'rec':
            count = 0
            for G, Z_opt, real_curr, real_next, noise_amp in zip(Gs, Zs, reals, reals[1:], NoiseAmp):
                # Reconstruction path: reuse the stored optimal noise Z_opt
                # instead of sampling fresh noise.
                G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]]
                G_z = m_image(G_z)
                z_in = noise_amp * Z_opt + G_z
                G_z = G(z_in.detach(), G_z)
                G_z = imresize(G_z, 1 / opt.scale_factor, opt)
                G_z = G_z[:, :, 0:real_next.shape[2], 0:real_next.shape[3]]
                #if count != (len(Gs)-1):
                #    G_z = m_image(G_z)
                count += 1
    return G_z
def get_shapes_with_noise(initial_shape: np.array, noise_levels=None):
    """Generate noisy variants of a shape at several noise levels.

    Args:
        initial_shape: the clean input shape (numpy array).
        noise_levels: optional iterable of noise fractions. Defaults to the
            original hard-coded sweep 0.05, 0.10, ..., 0.50.

    Returns:
        List of (noised_shape, noise_level) tuples, one per level.
    """
    if noise_levels is None:
        # Original behavior: 5% to 50% in 5% steps.
        noise_levels = [i / 100 for i in range(5, 55, 5)]
    return [(generate_noise(initial_shape, level), level) for level in noise_levels]
def before_train(self):
    """One-time setup run before the training loop starts."""
    # set logging level
    logging.basicConfig(filename='train.log', level=logging.INFO)
    # initialize WandbLogger
    self.wandb_logger = WandbLogger(project=self.cfg.project)
    # initialize a set of noise vectors which will be used to
    # visualize generator's progress (fixed, so samples are comparable
    # across epochs)
    self.noise = generate_noise(16, self.z_dim).to(self.device)
    # Ensure the checkpoint/sample output folder exists.
    os.makedirs(self.cfg.save_folder, exist_ok=True)
def calc_inception_score(device, noise_size, generator, splits=10, eval_size=50000):
    """Compute the Inception Score of a generator.

    Based on the classification output of an ImageNet-pretrained
    Inception-v3 over ~``eval_size`` generated images, split into
    ``splits`` groups. Higher is better.

    Args:
        device: torch device for the inception model and the generator.
        noise_size: latent dimension fed to the generator.
        generator: model mapping noise -> images (left in train() mode
            on return, matching the original behavior).
        splits: number of groups the predictions are split into.
        eval_size: total number of images to generate and evaluate.

    Returns:
        Mean Inception Score across the splits (float).
    """
    batch_size = 32
    inception = inception_v3(pretrained=True)
    inception.eval()
    inception = inception.to(device)
    generator.eval()
    num_batches = eval_size // batch_size
    # Calculate inception predictions on generated images.
    # PERF FIX: accumulate batches in a list and concatenate once at the end,
    # instead of the original O(n^2) repeated np.concatenate per batch.
    fake_prediction_batches = []
    for batch in range(num_batches):
        noise = utils.generate_noise(batch_size, noise_size).to(device)
        fake_batch = generator(noise).to(device).detach()
        prepared_fake_batch = F.interpolate(fake_batch, INCEPTION_SIZE)
        inception_logits = inception(prepared_fake_batch)
        predictions = F.softmax(inception_logits, dim=-1).detach().cpu().numpy()
        fake_prediction_batches.append(predictions)
    all_fake_predictions = np.concatenate(fake_prediction_batches)
    # Need to have an equal number of splits so cut off end if not.
    eval_size_equal = (len(all_fake_predictions) // splits) * splits
    # Inception Score per split: exp(mean KL(p(y|x) || p(y))).
    # (FIX: removed the dead `split_scores = []` that was immediately
    # overwritten by a scalar.)
    all_scores = []
    for split in np.split(all_fake_predictions[:eval_size_equal], splits):
        prob_y = np.repeat(np.mean(split, axis=0, keepdims=True), len(split), axis=0)
        kl_div = stats.entropy(split, prob_y, axis=1)
        all_scores.append(np.exp(np.mean(kl_div)))
    inception_score = np.mean(all_scores)
    generator.train()
    return inception_score
def __init__(self, netD, netG, n_classes, device, train_dl, lr_D=0.0002, lr_G=0.0002, loss_interval=50, image_interval=50, snapshot_interval=None, save_img_dir='saved_images/', save_snapshot_dir='saved_snapshots'):
    """Set up a class-conditional GAN trainer using RMSprop optimizers."""
    # Networks, class count, and data.
    self.netD, self.netG = netD, netG
    self.n_classes = n_classes
    self.train_dl = train_dl
    self.lr_D, self.lr_G = lr_D, lr_G
    self.train_iteration_per_epoch = len(self.train_dl)
    self.device = device
    # RMSprop, one optimizer per network.
    self.optimizerD = optim.RMSprop(self.netD.parameters(), lr=self.lr_D)
    self.optimizerG = optim.RMSprop(self.netG.parameters(), lr=self.lr_G)
    # Label conventions.
    self.real_label = 1
    self.fake_label = 0
    # One fixed noise vector per class, paired with an identity matrix of
    # one-hot labels, so the same class grid is rendered every time.
    self.nz = self.netG.nz
    self.fixed_noise = generate_noise(self.n_classes, self.nz, self.device)
    self.fixed_one_hot_labels = torch.diagflat(torch.ones(self.n_classes)).to(self.device)
    # Reporting cadence and loss history.
    self.loss_interval = loss_interval
    self.image_interval = image_interval
    self.snapshot_interval = snapshot_interval
    self.errD_records = []
    self.errG_records = []
    # Output folders, created on demand.
    self.save_cnt = 0
    self.save_img_dir = save_img_dir
    self.save_snapshot_dir = save_snapshot_dir
    for folder in (self.save_img_dir, self.save_snapshot_dir):
        if not os.path.exists(folder):
            os.makedirs(folder)
def __init__(self, loss_type, netD, netG, device, train_dl, lr_D = 0.0002, lr_G = 0.0002, resample = True, weight_clip = None, use_gradient_penalty = False, loss_interval = 50, image_interval = 50, save_img_dir = 'saved_images/'):
    """Trainer supporting several GAN loss formulations selected by name."""
    self.loss_type = loss_type
    # Map from loss name to its implementation class.
    self.loss_dict = {'SGAN':SGAN, 'LSGAN':LSGAN, 'HINGEGAN':HINGEGAN, 'WGAN':WGAN, 'RASGAN':RASGAN, 'RALSGAN':RALSGAN, 'RAHINGEGAN':RAHINGEGAN, 'QPGAN':QPGAN}
    # require_type encodes what the d/g loss calls need:
    #   0 -> fake scores only (standard), 1 -> real+fake (relativistic),
    #   2 -> real+fake plus the images themselves (QP), -1 -> unknown.
    standard_losses = ('SGAN', 'LSGAN', 'HINGEGAN', 'WGAN')
    relativistic_losses = ('RASGAN', 'RALSGAN', 'RAHINGEGAN')
    if loss_type in standard_losses:
        self.require_type = 0
        self.loss = self.loss_dict[self.loss_type](device)
    elif loss_type in relativistic_losses:
        self.require_type = 1
        self.loss = self.loss_dict[self.loss_type](device)
    elif loss_type == 'QPGAN':
        self.require_type = 2
        self.loss = self.loss_dict[self.loss_type](device, 'L1')
    else:
        # Unknown loss type: flag it; self.loss is intentionally left unset.
        self.require_type = -1
    # Networks and data.
    self.netD, self.netG = netD, netG
    self.train_dl = train_dl
    self.lr_D, self.lr_G = lr_D, lr_G
    self.train_iteration_per_epoch = len(self.train_dl)
    self.device = device
    # Training options.
    self.resample = resample
    self.weight_clip = weight_clip
    self.use_gradient_penalty = use_gradient_penalty
    self.special = None
    # Adam with betas (0, 0.9) — the usual WGAN-style configuration.
    self.optimizerD = optim.Adam(self.netD.parameters(), lr = self.lr_D, betas = (0, 0.9))
    self.optimizerG = optim.Adam(self.netG.parameters(), lr = self.lr_G, betas = (0, 0.9))
    # Label conventions.
    self.real_label = 1
    self.fake_label = 0
    # Fixed probe noise for a 7x7 progress grid.
    self.nz = self.netG.nz
    self.fixed_noise = generate_noise(49, self.nz, self.device)
    # Reporting cadence, loss history, and image output location.
    self.loss_interval = loss_interval
    self.image_interval = image_interval
    self.errD_records = []
    self.errG_records = []
    self.save_cnt = 0
    self.save_img_dir = save_img_dir
    if not os.path.exists(self.save_img_dir):
        os.makedirs(self.save_img_dir)
def refinement_layers(self, start_idx, x_prev_out, noise_amp, mode):
    """Refine a coarse output through the remaining generator blocks.

    Args:
        start_idx: index of the first body block to run.
        x_prev_out: output of the previous (coarser) stage.
        noise_amp: per-scale noise amplitudes, indexed by idx + 1.
        mode: 'rand' adds fresh noise at each scale; anything else runs
            the deterministic "reconstruction" path.

    Returns:
        tanh-squashed output of the final refinement block.
    """
    for idx, block in enumerate(self.body[start_idx:], start_idx):
        # Stop gradients at the VAE/GAN boundary so refinement training
        # does not update the VAE levels below it.
        if self.opt.vae_levels == idx + 1:
            x_prev_out.detach_()

        # Upscale
        x_prev_out_up = utils.upscale(x_prev_out, idx + 1, self.opt)

        # Add noise if "random" sampling, else, add no noise is "reconstruction" mode
        if mode == 'rand':
            noise = utils.generate_noise(ref=x_prev_out_up)
            x_prev = block(x_prev_out_up + noise * noise_amp[idx + 1])
        else:
            x_prev = block(x_prev_out_up)

        # Residual connection across scales, squashed to [-1, 1].
        x_prev_out = torch.tanh(x_prev + x_prev_out_up)

    return x_prev_out
def inference(inp, filename):
    """Generate a grid of images conditioned on the label `inp`.

    Loads the saved checkpoint, encodes the label with the stored
    character-to-index dict, samples noise, and writes a 4-wide image grid
    to the inference output directory.

    NOTE(review): `filename` appears unused and the literal "(unknown)" in
    the output paths looks like redacted interpolation — confirm intent.
    """
    lm, gen = init_inference()
    checkpoint = torch.load(f"{config.OUT_DIR}/checkpoint.pt")
    # FIX: context manager guarantees the file is closed even if
    # pickle.load raises (the original left it open on error).
    with open(f"{config.BASE_DIR}/src/ctoi.txt", "rb") as ctoi_file:
        encoding_dict = pickle.load(ctoi_file)
    # print(
    #     f'Checkpoint Details:\n Trained for: {checkpoint["epoch"]} epochs, Final Generator loss: {checkpoint["gen_loss"]}, Log File: {checkpoint["log_file"]}'
    # )
    lm.load_state_dict(checkpoint["lm"])
    gen.load_state_dict(checkpoint["gen"])
    # Repeat the single label across the batch so the grid shows variations.
    test = preprocess_labels([inp] * config.BATCH_SIZE, encoding_dict)
    with torch.no_grad():
        zin = generate_noise(config.Z_LEN, config.BATCH_SIZE, device)
        gin = lm(test.to(device))
        gout = gen(zin, gin)
        tgrid = torchvision.utils.make_grid(gout.detach().cpu(), nrow=4)
        imshow(tgrid, f"{config.OUT_DIR}/inference/(unknown).png")
        print(f'Inference Finished. Check "out" directory for (unknown).png')
def inference_tb(inp, writer):
    """Generate a grid of images for label `inp` and log it to TensorBoard.

    Same pipeline as `inference`, but the grid is written to the given
    summary `writer` (tagged with the checkpoint epoch) instead of a file.
    """
    lm, gen = init_inference()
    checkpoint = torch.load(f"{config.OUT_DIR}/checkpoint.pt")
    # FIX: context manager guarantees the file is closed even if
    # pickle.load raises (the original left it open on error).
    with open(f"{config.BASE_DIR}/src/ctoi.txt", "rb") as ctoi_file:
        encoding_dict = pickle.load(ctoi_file)
    # print(
    #     f'Checkpoint Details:\n Trained for: {checkpoint["epoch"]} epochs, Final Generator loss: {checkpoint["gen_loss"]}, Log File: {checkpoint["log_file"]}'
    # )
    lm.load_state_dict(checkpoint["lm"])
    gen.load_state_dict(checkpoint["gen"])
    # Repeat the single label across the batch so the grid shows variations.
    test = preprocess_labels([inp] * config.BATCH_SIZE, encoding_dict)
    with torch.no_grad():
        # lm.eval()
        # gen.eval()
        zin = generate_noise(config.Z_LEN, config.BATCH_SIZE, device)
        gin = lm(test.to(device))
        gout = gen(zin, gin)
        tgrid = torchvision.utils.make_grid(gout.detach().cpu(), nrow=4)
        writer.add_image(str(checkpoint["epoch"]), tgrid)
def train(self, num_epoch):
    """Main training loop for the multi-loss trainer.

    For each batch: one discriminator update (with optional gradient
    penalty and weight clipping), then one generator update (optionally on
    resampled noise). Losses are recorded and sample grids saved at the
    configured intervals.

    Args:
        num_epoch: number of passes over self.train_dl.
    """
    for epoch in range(num_epoch):
        for i, data in enumerate(tqdm(self.train_dl)):
            # ---- Discriminator update ----
            self.netD.zero_grad()
            real_images = data[0].to(self.device)
            bs = real_images.size(0)
            noise = generate_noise(bs, self.nz, self.device)
            fake_images = self.netG(noise)
            c_xr = self.netD(real_images)
            c_xr = c_xr.view(-1)
            # detach: D update must not backprop into G.
            c_xf = self.netD(fake_images.detach())
            c_xf = c_xf.view(-1)
            # require_type selects the d_loss signature (see __init__).
            # NOTE(review): errD/errG are unbound if require_type == -1
            # (unknown loss type) — this would raise NameError; confirm
            # invalid types are rejected upstream.
            if(self.require_type == 0 or self.require_type == 1):
                errD = self.loss.d_loss(c_xr, c_xf)
            elif(self.require_type == 2):
                errD = self.loss.d_loss(c_xr, c_xf, real_images, fake_images)
            # use_gradient_penalty doubles as the GP coefficient.
            if(self.use_gradient_penalty != False):
                errD += self.use_gradient_penalty * self.gradient_penalty(real_images, fake_images)
            errD.backward()
            self.optimizerD.step()
            # Optional WGAN-style weight clipping.
            if(self.weight_clip != None):
                for param in self.netD.parameters():
                    param.data.clamp_(-self.weight_clip, self.weight_clip)

            # ---- Generator update ----
            self.netG.zero_grad()
            if(self.resample):
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
            if(self.require_type == 0):
                c_xf = self.netD(fake_images)
                c_xf = c_xf.view(-1)
                errG = self.loss.g_loss(c_xf)
            if(self.require_type == 1 or self.require_type == 2):
                # Relativistic/QP losses also need fresh real scores.
                c_xr = self.netD(real_images)  # (bs, 1, 1, 1)
                c_xr = c_xr.view(-1)  # (bs)
                c_xf = self.netD(fake_images)  # (bs, 1, 1, 1)
                c_xf = c_xf.view(-1)
                errG = self.loss.g_loss(c_xr, c_xf)
            errG.backward()
            self.optimizerG.step()

            # ---- Bookkeeping and periodic outputs ----
            self.errD_records.append(float(errD))
            self.errG_records.append(float(errG))
            if(i % self.loss_interval == 0):
                print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
                      %(epoch+1, num_epoch, i+1, self.train_iteration_per_epoch, errD, errG))
            if(i % self.image_interval == 0):
                if(self.special == None):
                    # Image models: 7x7 grid rendered from the fixed noise.
                    sample_images_list = get_sample_images_list('Unsupervised', (self.fixed_noise, self.netG))
                    plot_img = get_display_samples(sample_images_list, 7, 7)
                    cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
                    self.save_cnt += 1
                    cv2.imwrite(cur_file_name, plot_img)
                elif(self.special == 'Wave'):
                    # Audio models: spectrogram grid instead of images.
                    sample_audios_list = get_sample_images_list('Unsupervised_Audio', (self.fixed_noise, self.netG))
                    plot_fig = plot_multiple_spectrograms(sample_audios_list, 7, 7, freq = 16000)
                    cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()
def _train_internal(self, opts):
    """Train a VAE model.

    NOTE(review): Python 2 code (`xrange`, integer `/` division feeding
    `xrange`) — would fail under Python 3; confirm the target interpreter.

    Args:
        opts: configuration dict (batch_size, gan_epoch_num, decay_schedule,
            save_every_epoch, plot_every, early_stop, work_dir, ckpt_dir, ...).
    """
    batches_num = self._data.num_points / opts['batch_size']
    train_size = self._data.num_points
    num_plot = 320
    sample_prev = np.zeros([num_plot] + list(self._data.data_shape))
    l2s = []
    counter = 0
    decay = 1.
    # NOTE(review): progress is logged at ERROR level throughout —
    # presumably to survive log filtering; confirm.
    logging.error('Training VAE')
    for _epoch in xrange(opts["gan_epoch_num"]):
        # Manual step-wise learning-rate decay at fixed epochs.
        if opts['decay_schedule'] == "manual":
            if _epoch == 30:
                decay = decay / 2.
            if _epoch == 50:
                decay = decay / 5.
            if _epoch == 100:
                decay = decay / 10.
        # Periodic checkpointing.
        if _epoch > 0 and _epoch % opts['save_every_epoch'] == 0:
            os.path.join(opts['work_dir'], opts['ckpt_dir'])
            self._saver.save(self._session,
                             os.path.join(opts['work_dir'], opts['ckpt_dir'], 'trained-pot'),
                             global_step=counter)
        for _idx in xrange(batches_num):
            # logging.error('Step %d of %d' % (_idx, batches_num ) )
            # Weighted sampling without replacement using the data weights.
            data_ids = np.random.choice(train_size, opts['batch_size'],
                                        replace=False, p=self._data_weights)
            batch_images = self._data.data[data_ids].astype(np.float)
            batch_noise = utils.generate_noise(opts, opts['batch_size'])
            # One optimization step; fetch the component losses for logging.
            _, loss, loss_kl, loss_reconstruct = self._session.run(
                [self._optim, self._loss, self._loss_kl, self._loss_reconstruct],
                feed_dict={self._real_points_ph: batch_images,
                           self._noise_ph: batch_noise,
                           self._lr_decay_ph: decay,
                           self._is_training_ph: True})
            counter += 1
            if opts['verbose'] and counter % opts['plot_every'] == 0:
                debug_str = 'Epoch: %d/%d, batch:%d/%d' % (
                    _epoch+1, opts['gan_epoch_num'], _idx+1, batches_num)
                debug_str += ' [L=%.2g, Recon=%.2g, KLQ=%.2g]' % (
                    loss, loss_reconstruct, loss_kl)
                logging.error(debug_str)
            if opts['verbose'] and counter % opts['plot_every'] == 0:
                # Decode the fixed plotting noise and track drift between
                # successive sample sets via an L2 distance.
                metrics = Metrics()
                points_to_plot = self._run_batch(
                    opts, self._generated, self._noise_ph,
                    self._noise_for_plots[0:num_plot],
                    self._is_training_ph, False)
                l2s.append(np.sum((points_to_plot - sample_prev)**2))
                metrics.l2s = l2s[:]
                metrics.make_plots(opts, counter, None, points_to_plot,
                                   prefix='sample_e%04d_mb%05d_' % (_epoch, _idx))
                # Also plot reconstructions of the current batch.
                reconstructed = self._session.run(
                    self._reconstruct_x,
                    feed_dict={self._real_points_ph: batch_images,
                               self._is_training_ph: False})
                metrics.l2s = None
                metrics.make_plots(opts, counter, None, reconstructed,
                                   prefix='reconstr_e%04d_mb%05d_' % (_epoch, _idx))
            # Early stop caps the total number of optimization steps.
            if opts['early_stop'] > 0 and counter > opts['early_stop']:
                break
    # Final checkpoint after training.
    if _epoch > 0:
        os.path.join(opts['work_dir'], opts['ckpt_dir'])
        self._saver.save(self._session,
                         os.path.join(opts['work_dir'], opts['ckpt_dir'], 'trained-pot-final'),
                         global_step=counter)
def train(self, num_epoch):
    """DCGAN training loop using binary cross-entropy.

    Per batch: the discriminator maximizes log(D(x)) + log(1 - D(G(z))),
    then the generator maximizes log(D(G(z))) (non-saturating form).
    Losses are recorded and sample grids / snapshots saved at the
    configured intervals.

    Args:
        num_epoch: number of passes over self.train_dl.
    """
    criterion = nn.BCELoss()
    for epoch in range(num_epoch):
        for i, data in enumerate(tqdm(self.train_dl)):
            # (1) : maximize log(D(x)) + log(1 - D(G(z)))
            # also means minimize (-log(D(x))) + (-log(1 - D(G(z))))
            self.netD.zero_grad()
            # first, calculate -log(D(x)) and its gradients using real images
            # real images (bs, nc, 64, 64)
            real_images = data[0].to(self.device)
            bs = real_images.size(0)
            # real labels (bs)
            label = torch.full((bs, ), self.real_label, device=self.device)
            output = self.netD(real_images)  # (bs, 1, 1, 1)
            output = output.view(-1)  # (bs)
            # BCELoss of output(bs), and real label(bs)
            errD_real = criterion(output, label)  # -log(D(x))
            # calculate the gradients
            errD_real.backward()
            # second, calculate -log(1 - D(G(z))) and its gradients using fake images
            # noise (bs, nz, 1, 1), fake images (bs, nc, 64, 64)
            noise = generate_noise(bs, self.nz, self.device)
            fake_images = self.netG(noise)
            # fake labels (bs)
            label.fill_(self.fake_label)
            # detach: D update must not backprop into G.
            output = self.netD(fake_images.detach())  # (bs, 1, 1, 1)
            output = output.view(-1)  # (bs)
            # BCELoss of output(bs), and fake labels(bs)
            errD_fake = criterion(output, label)  # -log(1 - D(G(z)))
            # calculate the gradients
            errD_fake.backward()
            # calculate the final loss value, (-log(D(x))) + (-log(1 - D(G(z))))
            errD = errD_real + errD_fake
            # update D using the gradients calculated previously
            self.optimizerD.step()

            # (2) : maximize log(D(G(z)))
            # also means minimize -log(D(G(z)))
            self.netG.zero_grad()
            if (self.resample):
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
            # first, calculate -log(D(G(z))) and its gradients using fake images
            # real labels (bs)
            label.fill_(self.real_label)
            output = self.netD(fake_images)  # (bs, 1, 1, 1)
            output = output.view(-1)  # (bs)
            # BCELoss of output(bs), and real labels(bs)
            errG = criterion(output, label)  # -log(D(G(z)))
            # calculate the gradients
            errG.backward()
            # update G using the gradients calculated previously
            self.optimizerG.step()

            # ---- Bookkeeping and periodic outputs ----
            self.errD_records.append(float(errD))
            self.errG_records.append(float(errG))
            if (i % self.loss_interval == 0):
                print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
                      % (epoch + 1, num_epoch, i + 1, self.train_iteration_per_epoch, errD, errG))
            if (i % self.image_interval == 0):
                if (self.special == None):
                    # Image models: 4x4 grid rendered from the fixed noise.
                    sample_images_list = get_sample_images_list(
                        'Unsupervised', (self.fixed_noise, self.netG))
                    plot_fig = plot_multiple_images(sample_images_list, 4, 4)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' + str(i) + '.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()
                elif (self.special == 'Wave'):
                    # Audio models: spectrogram grid instead of images.
                    sample_audios_list = get_sample_images_list(
                        'Unsupervised_Audio', (self.fixed_noise, self.netG))
                    plot_fig = plot_multiple_spectrograms(
                        sample_audios_list, 4, 4, freq=16000)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' + str(i) + '.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()
            # Periodic model/optimizer snapshot.
            if (self.snapshot_interval is not None):
                if (i % self.snapshot_interval == 0):
                    save(
                        os.path.join(
                            self.save_snapshot_dir,
                            'Epoch' + str(epoch) + '_' + str(i) + '.state'),
                        self.netD, self.netG, self.optimizerD, self.optimizerG)
def eval(opt, netG):
    """Generate and visualize random video samples at the current scale.

    Regenerates the dataset frames for the scale, runs the (optionally
    DataParallel-wrapped) generator `opt.num_samples` times per iteration
    with fresh noise, logs real/fake videos to the summary writer, and
    saves all random samples to disk.

    Args:
        opt: options namespace (scale_idx, niter, batch_size, latent_dim,
            Noise_Amps, saver, summary, device, ...).
        netG: trained generator for this scale.
    """
    # Re-generate dataset frames for this scale.
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index
    # opt.tds.append(opt.td)
    opt.dataset.generate_frames(opt.scale_idx)
    torch.save(opt.dataset.frames, os.path.join(opt.saver.eval_dir, "real_full_scale.pth"))
    # Lazily compute the coarsest-scale noise shape once.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor, opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [opt.batch_size, opt.latent_dim, opt.td, *initial_size]
    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
    else:
        G_curr = netG
    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Generation scale [{}/{}]".format(opt.scale_idx + 1, opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)
    # NOTE(review): `data_loader` is not defined in this function (later
    # code restarts from `opt.data_loader`) — presumably a module-level
    # name or a latent NameError; confirm.
    iterator = iter(data_loader)
    random_samples = []
    for iteration in epoch_iterator:
        # Cycle the data loader indefinitely.
        try:
            data = next(iterator)
        except StopIteration:
            iterator = iter(opt.data_loader)
            data = next(iterator)
        # Finer scales also carry the scale-0 version of the clip.
        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
        else:
            real = data.to(opt.device)
        noise_init = utils.generate_noise(size=opt.Z_init_size, device=opt.device)
        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))
        with torch.no_grad():
            fake_var = []
            fake_vae_var = []
            # Draw num_samples random videos, resampling the noise each time.
            for _ in range(opt.num_samples):
                noise_init = utils.generate_noise(ref=noise_init)
                fake, fake_vae = G_curr(noise_init, opt.Noise_Amps, noise_init=noise_init, mode="rand")
                fake_var.append(fake)
                fake_vae_var.append(fake_vae)
            fake_var = torch.cat(fake_var, dim=0)
            fake_vae_var = torch.cat(fake_vae_var, dim=0)
        opt.summary.visualize_video(opt, iteration, real, 'Real')
        opt.summary.visualize_video(opt, iteration, fake_var, 'Fake var')
        opt.summary.visualize_video(opt, iteration, fake_vae_var, 'Fake VAE var')
        random_samples.append(fake_var)
    random_samples = torch.cat(random_samples, dim=0)
    torch.save(random_samples, os.path.join(opt.saver.eval_dir, "random_samples.pth"))
    epoch_iterator.close()
def eval(opt, netG):
    """Generate and visualize random image samples at the current scale.

    NOTE(review): this version contains hard-coded latent-manipulation
    experiments (the low-rank UV construction below replaces the sampled
    noise entirely) plus mid-function imports — it reads as exploratory
    debug code, not a clean evaluation path; confirm before relying on it.

    Args:
        opt: options namespace (scale_idx, niter, batch_size, latent_dim,
            Noise_Amps, saver, summary, device, ...).
        netG: trained generator for this scale.
    """
    # Re-generate dataset frames
    # Lazily compute the coarsest-scale noise shape once.
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor, opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [opt.batch_size, opt.latent_dim, *initial_size]
    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
    else:
        G_curr = netG
    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Training scale [{}/{}]".format(opt.scale_idx + 1, opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)
    # NOTE(review): `data_loader` is not defined in this function (later
    # code restarts from `opt.data_loader`) — presumably a module-level
    # name or a latent NameError; confirm.
    iterator = iter(data_loader)
    random_samples = []
    for iteration in epoch_iterator:
        # Cycle the data loader indefinitely.
        try:
            data = next(iterator)
        except StopIteration:
            iterator = iter(opt.data_loader)
            data = next(iterator)
        # Finer scales also carry the scale-0 version of the image.
        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
        else:
            real = data.to(opt.device)
        noise_init = utils.generate_noise(size=opt.Z_init_size, device=opt.device)
        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))
        G_curr.eval()
        import numpy as np
        import sys
        with torch.no_grad():
            fake_var = []
            fake_vae_var = []
            for _ in range(opt.num_samples):
                noise_init = utils.generate_noise(ref=noise_init)
                # Unused here; presumably left over from a channel-subset
                # experiment — confirm.
                channel_idxs = np.random.choice(np.arange(0, 128), 127, replace=False)
                # Build a rank-1 spatial pattern: U picks one latent channel
                # (scaled by 4), V marks a spatial region; their product is
                # standardized and used *instead of* the sampled noise.
                # U = torch.zeros(1, 128, 5).normal_(0, 1).to(noise_init.device)
                U = torch.zeros(1, 128, 1).to(noise_init.device)
                U[:, _] = 4
                # U[:, :120] =
                V = torch.zeros(1, 1, 22, 33).to(noise_init.device)
                # V.bernoulli_(p=0.01)
                V[:, :, 1:4, 20:32] = 1
                # V[:, :, 4:10, 8:10] = 1
                V = V.flatten(2)
                UV = torch.bmm(U, V).view(1, 128, 22, 33)
                UV = (UV - UV.mean()) / UV.std()
                # noise_init[:] = 0
                # noise_init[:, :, 5:11, 16:18] = _
                # noise_init[:, 108, 0:4, 0:4] = 100
                # noise_init[:, 21, _:_ + 1, 16:19] = 0.01
                # noise_init[:, :, 3:11, 16:18] = -10 / opt.num_samples
                # normed_z_vae = z_vae / ((z_vae ** 2).sum() + sys.float_info.epsilon)
                # noise_init = noise_init / ((noise_init ** 2).sum() + sys.float_info.epsilon)
                noise_init = UV
                fake, fake_vae = G_curr(noise_init, opt.Noise_Amps, noise_init=noise_init, mode="rand")
                fake_var.append(fake)
                fake_vae_var.append(fake_vae)
            fake_var = torch.cat(fake_var, dim=0)
            fake_vae_var = torch.cat(fake_vae_var, dim=0)
        opt.summary.visualize_image(opt, iteration, real, 'Real')
        opt.summary.visualize_image(opt, iteration, fake_var, 'Fake var')
        opt.summary.visualize_image(opt, iteration, fake_vae_var, 'Fake VAE var')
        random_samples.append(fake_var)
    random_samples = torch.cat(random_samples, dim=0)
    from torchvision.utils import save_image
    save_image(random_samples, 'test.png', normalize=True)
    torch.save(random_samples, os.path.join(opt.saver.eval_dir, "random_samples.pth"))
    epoch_iterator.close()
def train_single_scale(netD, netG, reals, Gs, Zs, in_s, NoiseAmp, opt, centers=None):
    """Train one scale of a SinGAN-style pyramid with WGAN-GP.

    Args:
        netD, netG: discriminator/generator for the current scale.
        reals:      image pyramid; the current scale's real is reals[len(Gs)].
        Gs, Zs:     trained generators / reconstruction noises of coarser scales.
        in_s:       zero input for the coarsest scale (threaded through scales).
        NoiseAmp:   per-scale noise amplitudes from coarser scales.
        opt:        options namespace (ker_size, num_layer, niter, alpha, mode, ...).
        centers:    color-quantization centers (only used in 'paint_train' mode).

    Returns:
        (z_opt, in_s, netG): reconstruction noise for this scale, the (possibly
        initialized) in_s, and the trained generator.
    """
    real = reals[len(Gs)]
    # Spatial noise dims follow the real image at this scale.
    opt.nzx = real.shape[2]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.nzy = real.shape[3]  #+(opt.ker_size-1)*(opt.num_layer)
    opt.receptive_field = opt.ker_size + ((opt.ker_size - 1) * (opt.num_layer - 1)) * opt.stride
    # Pad noise/images so the generator's valid convolutions line up with `real`.
    pad_noise = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2)
    if opt.mode == 'animation_train':
        opt.nzx = real.shape[2] + (opt.ker_size - 1) * (opt.num_layer)
        opt.nzy = real.shape[3] + (opt.ker_size - 1) * (opt.num_layer)
        pad_noise = 0
    m_noise = nn.ZeroPad2d(int(pad_noise))
    m_image = nn.ZeroPad2d(int(pad_image))

    alpha = opt.alpha  # reconstruction-loss weight; 0 disables the rec term

    # z_opt is the fixed noise used for the reconstruction objective; zeros for
    # non-coarsest scales (it is re-drawn below for the coarsest scale).
    fixed_noise = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
    z_opt = torch.full(fixed_noise.shape, 0, device=opt.device)
    z_opt = m_noise(z_opt)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, 0.999))
    schedulerD = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerD, milestones=[1600], gamma=opt.gamma)
    schedulerG = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizerG, milestones=[1600], gamma=opt.gamma)

    errD2plot = []
    errG2plot = []
    D_real2plot = []
    D_fake2plot = []
    z_opt2plot = []

    for epoch in range(opt.niter):
        # Coarsest scale (no trained Gs yet): draw fresh single-channel noise and
        # broadcast it to 3 channels; otherwise draw full nc_z-channel noise.
        if (Gs == []) & (opt.mode != 'SR_train'):
            z_opt = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            z_opt = m_noise(z_opt.expand(1, 3, opt.nzx, opt.nzy))
            noise_ = functions.generate_noise([1, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_.expand(1, 3, opt.nzx, opt.nzy))
        else:
            noise_ = functions.generate_noise([opt.nc_z, opt.nzx, opt.nzy], device=opt.device)
            noise_ = m_noise(noise_)

        ############################
        # (1) Update D network: maximize D(x) + D(G(z))
        ###########################
        for j in range(opt.Dsteps):
            # train with real
            netD.zero_grad()
            output = netD(real).to(opt.device)
            #D_real_map = output.detach()
            errD_real = -output.mean()  #-a  (WGAN critic: maximize D(real))
            errD_real.backward(retain_graph=True)
            D_x = -errD_real.item()

            # train with fake
            # First D-step of the first epoch: build `prev` (random prior from
            # coarser scales) and `z_prev` (reconstruction prior), and calibrate
            # opt.noise_amp from the RMSE of the coarser-scale reconstruction.
            if (j == 0) & (epoch == 0):
                if (Gs == []) & (opt.mode != 'SR_train'):
                    prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    in_s = prev
                    prev = m_image(prev)
                    z_prev = torch.full([1, opt.nc_z, opt.nzx, opt.nzy], 0, device=opt.device)
                    z_prev = m_noise(z_prev)
                    opt.noise_amp = 1
                elif opt.mode == 'SR_train':
                    z_prev = in_s
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
                    prev = z_prev
                else:
                    prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                    prev = m_image(prev)
                    z_prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rec', m_noise, m_image, opt)
                    criterion = nn.MSELoss()
                    RMSE = torch.sqrt(criterion(real, z_prev))
                    opt.noise_amp = opt.noise_amp_init * RMSE
                    z_prev = m_image(z_prev)
            else:
                # Fresh random prior from the coarser scales every D-step.
                prev = draw_concat(Gs, Zs, reals, NoiseAmp, in_s, 'rand', m_noise, m_image, opt)
                prev = m_image(prev)

            if opt.mode == 'paint_train':
                prev = functions.quant2centers(prev, centers)
                plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)

            # Generator input: pure noise at the coarsest scale, otherwise
            # amplitude-scaled noise added to the upsampled coarser output.
            if (Gs == []) & (opt.mode != 'SR_train'):
                noise = noise_
            else:
                noise = opt.noise_amp * noise_ + prev

            fake = netG(noise.detach(), prev)
            output = netD(fake.detach())
            errD_fake = output.mean()  # WGAN critic: minimize D(fake)
            errD_fake.backward(retain_graph=True)
            D_G_z = output.mean().item()

            gradient_penalty = functions.calc_gradient_penalty(
                netD, real, fake, opt.lambda_grad, opt.device)
            gradient_penalty.backward()

            errD = errD_real + errD_fake + gradient_penalty
            optimizerD.step()

        errD2plot.append(errD.detach())

        ############################
        # (2) Update G network: maximize D(G(z))
        ###########################
        for j in range(opt.Gsteps):
            netG.zero_grad()
            output = netD(fake)
            #D_fake_map = output.detach()
            errG = -output.mean()
            errG.backward(retain_graph=True)
            if alpha != 0:
                # Reconstruction term: G(z_opt) must reproduce `real`.
                loss = nn.MSELoss()
                if opt.mode == 'paint_train':
                    z_prev = functions.quant2centers(z_prev, centers)
                    plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
                Z_opt = opt.noise_amp * z_opt + z_prev
                rec_loss = alpha * loss(netG(Z_opt.detach(), z_prev), real)
                rec_loss.backward(retain_graph=True)
                rec_loss = rec_loss.detach()
            else:
                Z_opt = z_opt
                rec_loss = 0
            optimizerG.step()

        errG2plot.append(errG.detach() + rec_loss)
        D_real2plot.append(D_x)
        D_fake2plot.append(D_G_z)
        z_opt2plot.append(rec_loss)

        if epoch % 25 == 0 or epoch == (opt.niter - 1):
            print('scale %d:[%d/%d]' % (len(Gs), epoch, opt.niter))

        if epoch % 500 == 0 or epoch == (opt.niter - 1):
            plt.imsave('%s/fake_sample.png' % (opt.outf), functions.convert_image_np(fake.detach()), vmin=0, vmax=1)
            plt.imsave('%s/G(z_opt).png' % (opt.outf), functions.convert_image_np(
                netG(Z_opt.detach(), z_prev).detach()), vmin=0, vmax=1)
            #plt.imsave('%s/D_fake.png' % (opt.outf), functions.convert_image_np(D_fake_map))
            #plt.imsave('%s/D_real.png' % (opt.outf), functions.convert_image_np(D_real_map))
            #plt.imsave('%s/z_opt.png' % (opt.outf), functions.convert_image_np(z_opt.detach()), vmin=0, vmax=1)
            #plt.imsave('%s/prev.png' % (opt.outf), functions.convert_image_np(prev), vmin=0, vmax=1)
            #plt.imsave('%s/noise.png' % (opt.outf), functions.convert_image_np(noise), vmin=0, vmax=1)
            #plt.imsave('%s/z_prev.png' % (opt.outf), functions.convert_image_np(z_prev), vmin=0, vmax=1)
            torch.save(z_opt, '%s/z_opt.pth' % (opt.outf))

        schedulerD.step()
        schedulerG.step()

    functions.save_networks(netG, netD, z_opt, opt)
    return z_opt, in_s, netG
def train(self, res_num_epochs, res_percentage, bs):
    """Progressive-growing conditional GAN training over multiple resolutions.

    Args:
        res_num_epochs: epochs to train at each resolution.
        res_percentage: per-resolution fraction of epochs spent fading in the
                        new resolution (first entry is implicitly None).
        bs:             per-resolution batch sizes.

    `p` is the progression parameter passed to netG/netD: integer at a stable
    resolution, fractional while fading in (epoch < num_epoch_transition).
    """
    l1 = nn.L1Loss()
    p = 0
    res_percentage = [None] + res_percentage  # no fade-in at the first resolution
    for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
        train_dl = self.train_ds.get_loader(32 * (2**i), cur_bs)
        val_dl = list(self.val_ds.get_loader(32 * (2**i), 3))[0]
        train_dl_len = len(train_dl)
        if percentage is None:
            num_epoch_transition = 0
        else:
            num_epoch_transition = int(num_epoch * percentage)
        cnt = 1
        for epoch in range(num_epoch):
            p = i
            if self.resample:
                train_dl_iter = iter(train_dl)
            for j, (x, y) in enumerate(tqdm(train_dl)):
                # During fade-in, ramp p from i-1 toward i across all
                # transition batches.
                if epoch < num_epoch_transition:
                    p = i + cnt / (train_dl_len * num_epoch_transition) - 1
                    cnt += 1
                x = x.to(self.device)
                y = y.to(self.device)
                # Renamed from `bs` to avoid shadowing the `bs` parameter.
                b_size = x.size(0)
                noise = generate_noise(b_size, self.nz, self.device)
                fake_y = self.netG(x, p, noise)

                # ---- Discriminator step ----
                self.netD.zero_grad()
                c_xr = self.netD(x, y)
                c_xr = c_xr.view(-1)
                c_xf = self.netD(x, fake_y.detach())
                c_xf = c_xf.view(-1)
                if self.require_type == 0 or self.require_type == 1:
                    errD = self.loss.d_loss(c_xr, c_xf)
                elif self.require_type == 2:
                    errD = self.loss.d_loss(c_xr, c_xf, y, fake_y)
                if self.use_gradient_penalty:
                    # use_gradient_penalty doubles as the GP weight when truthy.
                    errD += self.use_gradient_penalty * self.gradient_penalty(x, y, fake_y)
                errD.backward()
                self.optimizerD.step()
                if self.weight_clip is not None:
                    for param in self.netD.parameters():
                        param.data.clamp_(-self.weight_clip, self.weight_clip)

                # ---- Generator step ----
                self.netG.zero_grad()
                if self.resample:
                    # Draw a fresh batch for the G update.
                    x, y = next(train_dl_iter)
                    x = x.to(self.device)
                    y = y.to(self.device)
                    b_size = x.size(0)
                    noise = generate_noise(b_size, self.nz, self.device)
                    fake_y = self.netG(x, p, noise)
                if self.require_type == 0:
                    c_xr = None
                    c_xf, f1 = self.netD(x, fake_y, True)  # (bs, 1, 1, 1)
                    c_xf = c_xf.view(-1)  # (bs)
                    errG_1 = self.loss.g_loss(c_xf)
                if self.require_type == 1 or self.require_type == 2:
                    c_xr, f2 = self.netD(x, y, True)  # (bs, 1, 1, 1)
                    c_xr = c_xr.view(-1)  # (bs)
                    c_xf, f1 = self.netD(x, fake_y, True)  # (bs, 1, 1, 1)
                    c_xf = c_xf.view(-1)  # (bs)
                    errG_1 = self.loss.g_loss(c_xr, c_xf)

                # Mode-seeking / diversity loss between two noise draws.
                if self.ds_weight == 0:
                    ds_loss = 0
                else:
                    noise1 = generate_noise(b_size, self.nz, self.device)
                    noise2 = generate_noise(b_size, self.nz, self.device)
                    # FIX: the progression parameter `p` was missing from these
                    # two calls; every other generator call is netG(x, p, noise).
                    fake_y1 = self.netG(x, p, noise1)
                    fake_y2 = self.netG(x, p, noise2)
                    ds_loss = self.ds_loss.get_loss(fake_y1, fake_y2, noise1, noise2)

                # Reconstruction loss: discriminator-feature matching or plain L1.
                if self.rec_weight == 0:
                    rec_loss = 0
                else:
                    if self.use_rec_feature:
                        rec_loss = 0
                        if c_xr is None:
                            c_xr, f2 = self.netD(x, y, True)  # (bs, 1, 1, 1)
                            c_xr = c_xr.view(-1)  # (bs)
                        for f1_, f2_ in zip(f1, f2):
                            rec_loss += (f1_ - f2_).abs().mean()
                        rec_loss /= len(f1)
                    else:
                        rec_loss = l1(fake_y, y)

                errG = errG_1 + rec_loss * self.rec_weight + ds_loss * self.ds_weight
                errG.backward()
                # update G using the gradients calculated previously
                self.optimizerG.step()

                if j % self.loss_interval == 0:
                    # FIX: batch progress is j (inner loop), not i (resolution index).
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
                          % (epoch+1, num_epoch, j+1, train_dl_len, errD, errG))

                if j % self.image_interval == 0:
                    if self.nz is None:
                        sample_images_list = get_sample_images_list((val_dl, self.netG, p, self.device))
                        plot_image = get_display_samples(sample_images_list, 3, 3)
                    else:
                        sample_images_list = get_sample_images_list_noise((val_dl, self.netG, p, self.fixed_noise, self.device))
                        plot_image = get_display_samples(sample_images_list, 9, 3)
                    cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(j)+'.jpg')
                    self.save_cnt += 1
                    cv2.imwrite(cur_file_name, plot_image)
def train(self, num_epoch):
    """Train a class-conditional GAN with the standard BCE (non-saturating) losses.

    Each iteration: one discriminator update on a real batch plus a fake batch
    generated from random class labels, then one generator update. Periodically
    logs losses, saves sample-image grids, and snapshots model/optimizer state.
    """
    criterion = nn.BCELoss()
    for epoch in range(num_epoch):
        for i, data in enumerate(tqdm(self.train_dl)):
            # (1) Discriminator: minimize BCE(D(x, y), 1) + BCE(D(G(z, y'), y'), 0)
            self.netD.zero_grad()
            real_images = data[0].to(self.device)
            real_class = data[1].to(self.device)
            bs = real_images.size(0)
            # real labels (bs)
            real_label = torch.full((bs, ), self.real_label, device=self.device)
            # fake labels (bs)
            fake_label = torch.full((bs, ), self.fake_label, device=self.device)
            # one hot labels (bs, n_classes) for the real batch
            one_hot_labels = torch.FloatTensor(bs, self.n_classes).to(
                self.device)
            one_hot_labels.zero_()
            one_hot_labels.scatter_(1, real_class.view(bs, 1), 1.0)
            # noise (bs, nz, 1, 1), fake images (bs, nc, 64, 64)
            noise = generate_noise(bs, self.nz, self.device)
            # Random target classes for the fakes, one-hot encoded.
            fake_class = torch.randint(0, self.n_classes,
                                       size=(bs, 1)).view(bs, 1).to(self.device)
            # one hot labels (bs, n_classes) for the fake batch
            one_hot_labels_fake = torch.FloatTensor(bs, self.n_classes).to(
                self.device)
            one_hot_labels_fake.zero_()
            one_hot_labels_fake.scatter_(1, fake_class.view(bs, 1).long(), 1.0)
            fake_images = self.netG(noise, one_hot_labels_fake)
            # calculate the discriminator results for both real & fake
            c_xr = self.netD(real_images, one_hot_labels)  # (bs, 1, 1, 1)
            c_xr = c_xr.view(-1)  # (bs)
            # detach() blocks gradients from flowing into G during the D update
            c_xf = self.netD(fake_images.detach(), one_hot_labels_fake)  # (bs, 1, 1, 1)
            c_xf = c_xf.view(-1)  # (bs)
            # calculate the discriminator loss
            errD = criterion(c_xr, real_label) + criterion(
                c_xf, fake_label)
            errD.backward()
            # update D using the gradients calculated previously
            self.optimizerD.step()

            # (2) Generator: minimize BCE(D(G(z, y'), y'), 1)
            self.netG.zero_grad()
            if (self.resample):
                # Regenerate noise + fakes (same random classes) for the G step.
                noise = generate_noise(bs, self.nz, self.device)
                one_hot_labels_fake = torch.FloatTensor(
                    bs, self.n_classes).to(self.device)
                one_hot_labels_fake.zero_()
                one_hot_labels_fake.scatter_(1, fake_class.view(bs, 1).long(), 1.0)
                fake_images = self.netG(noise, one_hot_labels_fake)
            # we updated the discriminator once, therefore recalculate c_xf
            c_xf = self.netD(fake_images, one_hot_labels_fake)  # (bs, 1, 1, 1)
            c_xf = c_xf.view(-1)  # (bs)
            # calculate the Generator loss (fakes scored against the real label)
            errG = criterion(c_xf, real_label)
            errG.backward()
            #update G using the gradients calculated previously
            self.optimizerG.step()

            self.errD_records.append(float(errD))
            self.errG_records.append(float(errG))

            if (i % self.loss_interval == 0):
                print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                      (epoch + 1, num_epoch, i + 1,
                       self.train_iteration_per_epoch, errD, errG))

            if (i % self.image_interval == 0):
                sample_images_list = get_sample_images_list(
                    'Conditional', (self.fixed_noise, self.fixed_one_hot_labels,
                                    self.n_classes, self.netG))
                plot_fig = plot_multiple_images(sample_images_list,
                                                self.n_classes, 1)
                cur_file_name = os.path.join(
                    self.save_img_dir,
                    str(self.save_cnt) + ' : ' + str(epoch) + '-' + str(i) + '.jpg')
                self.save_cnt += 1
                save_fig(cur_file_name, plot_fig)
                plot_fig.clf()

            if (self.snapshot_interval is not None):
                if (i % self.snapshot_interval == 0):
                    save(
                        os.path.join(
                            self.save_snapshot_dir,
                            'Epoch' + str(epoch) + '_' + str(i) + '.state'),
                        self.netD, self.netG, self.optimizerD,
                        self.optimizerG)
def train(opt, netG):
    """Train one scale of a hierarchical 3D VAE-GAN (video).

    At scales covered by opt.vae_levels the generator is trained as a VAE
    (reconstruction + KL); at finer scales it is trained adversarially with a
    WGAN-GP discriminator. Saves checkpoints through opt.saver on exit.

    Args:
        opt:  options namespace (device, scale_idx, vae_levels, data_loader, ...).
        netG: the multi-scale generator being grown.
    """
    # Re-generate dataset frames for this temporal scale.
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index

    with logger.LoggingBlock("Updating dataset", emph=True):
        logging.info("{}FPS :{} {}{}".format(green, clear, opt.fps, clear))
        logging.info("{}Time-Depth :{} {}{}".format(green, clear, opt.td, clear))
        logging.info("{}Sampling-Ratio :{} {}{}".format(
            green, clear, opt.sampling_rates[opt.fps_index], clear))
        opt.dataset.generate_frames(opt.scale_idx)

    # Initialize noise size once (batch, latent, time-depth, H, W).
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor, opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [
            opt.batch_size, opt.latent_dim, opt.td, *initial_size
        ]

    # A discriminator exists only for adversarial (non-VAE) scales; resume its
    # weights from the previous scale when available.
    if opt.vae_levels < opt.scale_idx + 1:
        D_curr = getattr(networks_3d, opt.discriminator)(opt).to(opt.device)

        if (opt.netG != '') and (opt.resumed_idx == opt.scale_idx):
            D_curr.load_state_dict(
                torch.load('{}/netD_{}.pth'.format(
                    opt.resume_dir, opt.scale_idx - 1))['state_dict'])
        elif opt.vae_levels < opt.scale_idx:
            D_curr.load_state_dict(
                torch.load(
                    '{}/netD_{}.pth'.format(opt.saver.experiment_dir,
                                            opt.scale_idx - 1))['state_dict'])

        # Current optimizers
        optimizerD = optim.Adam(D_curr.parameters(), lr=opt.lr_d, betas=(opt.beta1, 0.999))

    # Build the generator parameter groups with per-depth decayed learning rates.
    parameter_list = []
    if not opt.train_all:
        if opt.vae_levels < opt.scale_idx + 1:
            train_depth = min(opt.train_depth, len(netG.body) - opt.vae_levels + 1)
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**(len(netG.body[-train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-train_depth:])]
        else:
            # VAE
            parameter_list += [{
                "params": netG.encode.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }, {
                "params": netG.decoder.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }]
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**(len(netG.body[-opt.train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-opt.train_depth:])]
    else:
        if len(netG.body) < opt.train_depth:
            parameter_list += [{
                "params": netG.encode.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }, {
                "params": netG.decoder.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**opt.scale_idx)
            }]
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**(len(netG.body) - 1 - idx))
            } for idx, block in enumerate(netG.body)]
        else:
            parameter_list += [{
                "params": block.parameters(),
                "lr": opt.lr_g * (opt.lr_scale**(len(netG.body[-opt.train_depth:]) - 1 - idx))
            } for idx, block in enumerate(netG.body[-opt.train_depth:])]

    optimizerG = optim.Adam(parameter_list, lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
        if opt.vae_levels < opt.scale_idx + 1:
            D_curr = torch.nn.DataParallel(D_curr)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Training scale [{}/{}]".format(opt.scale_idx + 1, opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    # FIX: was `iter(data_loader)` (undefined name); the StopIteration handler
    # below already used `opt.data_loader`, so use it consistently here too.
    iterator = iter(opt.data_loader)

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            iterator = iter(opt.data_loader)
            data = next(iterator)

        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
            real_zero = real_zero.to(opt.device)
        else:
            real = data.to(opt.device)
            real_zero = real

        noise_init = utils.generate_noise(size=opt.Z_init_size, device=opt.device)

        ############################
        # calculate noise_amp
        ###########################
        if iteration == 0:
            if opt.const_amp:
                opt.Noise_Amps.append(1)
            else:
                with torch.no_grad():
                    if opt.scale_idx == 0:
                        opt.noise_amp = 1
                        opt.Noise_Amps.append(opt.noise_amp)
                    else:
                        # Calibrate this scale's noise amplitude from the RMSE
                        # of the previous scales' reconstruction.
                        opt.Noise_Amps.append(0)
                        z_reconstruction, _, _ = G_curr(real_zero, opt.Noise_Amps, mode="rec")
                        RMSE = torch.sqrt(F.mse_loss(real, z_reconstruction))
                        opt.noise_amp = opt.noise_amp_init * RMSE.item() / opt.batch_size
                        opt.Noise_Amps[-1] = opt.noise_amp

        ############################
        # (1) Update VAE network
        ###########################
        total_loss = 0

        generated, generated_vae, (mu, logvar) = G_curr(real_zero, opt.Noise_Amps, mode="rec")

        if opt.vae_levels >= opt.scale_idx + 1:
            # VAE scale: reconstruction + KL only, no discriminator.
            rec_vae_loss = opt.rec_loss(generated, real) + opt.rec_loss(
                generated_vae, real_zero)
            kl_loss = kl_criterion(mu, logvar)
            vae_loss = opt.rec_weight * rec_vae_loss + opt.kl_weight * kl_loss

            total_loss += vae_loss
        else:
            ############################
            # (2) Update D network: maximize D(x) + D(G(z))
            ###########################
            # train with real
            #################
            # Train 3D Discriminator
            D_curr.zero_grad()
            output = D_curr(real)
            errD_real = -output.mean()

            # train with fake
            #################
            fake, _ = G_curr(noise_init, opt.Noise_Amps, noise_init=noise_init, mode="rand")

            # Train 3D Discriminator
            output = D_curr(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = calc_gradient_penalty(D_curr, real, fake,
                                                     opt.lambda_grad, opt.device)
            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

            ############################
            # (3) Update G network: maximize D(G(z))
            ###########################
            errG_total = 0
            rec_loss = opt.rec_loss(generated, real)
            errG_total += opt.rec_weight * rec_loss

            # Train with 3D Discriminator
            output = D_curr(fake)
            errG = -output.mean() * opt.disc_loss_weight
            errG_total += errG

            total_loss += errG_total

        G_curr.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(G_curr.parameters(), opt.grad_clip)
        optimizerG.step()

        # Update progress bar
        epoch_iterator.set_description(
            'Scale [{}/{}], Iteration [{}/{}]'.format(
                opt.scale_idx + 1,
                opt.stop_scale + 1,
                iteration + 1,
                opt.niter,
            ))

        if opt.visualize:
            # Tensorboard
            # NOTE(review): noise_amp is logged twice per iteration (here and
            # below) — kept as-is to preserve behavior; likely redundant.
            opt.summary.add_scalar(
                'Video/Scale {}/noise_amp'.format(opt.scale_idx),
                opt.noise_amp, iteration)
            if opt.vae_levels >= opt.scale_idx + 1:
                opt.summary.add_scalar(
                    'Video/Scale {}/KLD'.format(opt.scale_idx),
                    kl_loss.item(), iteration)
            else:
                opt.summary.add_scalar(
                    'Video/Scale {}/rec loss'.format(opt.scale_idx),
                    rec_loss.item(), iteration)
            opt.summary.add_scalar(
                'Video/Scale {}/noise_amp'.format(opt.scale_idx),
                opt.noise_amp, iteration)
            if opt.vae_levels < opt.scale_idx + 1:
                opt.summary.add_scalar(
                    'Video/Scale {}/errG'.format(opt.scale_idx),
                    errG.item(), iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_fake'.format(opt.scale_idx),
                    errD_fake.item(), iteration)
                opt.summary.add_scalar(
                    'Video/Scale {}/errD_real'.format(opt.scale_idx),
                    errD_real.item(), iteration)
            else:
                opt.summary.add_scalar(
                    'Video/Scale {}/Rec VAE'.format(opt.scale_idx),
                    rec_vae_loss.item(), iteration)

            if iteration % opt.print_interval == 0:
                # Log a few random samples for visual inspection.
                with torch.no_grad():
                    fake_var = []
                    fake_vae_var = []
                    for _ in range(3):
                        noise_init = utils.generate_noise(ref=noise_init)
                        fake, fake_vae = G_curr(noise_init, opt.Noise_Amps,
                                                noise_init=noise_init, mode="rand")
                        fake_var.append(fake)
                        fake_vae_var.append(fake_vae)
                    fake_var = torch.cat(fake_var, dim=0)
                    fake_vae_var = torch.cat(fake_vae_var, dim=0)

                opt.summary.visualize_video(opt, iteration, real, 'Real')
                opt.summary.visualize_video(opt, iteration, generated, 'Generated')
                opt.summary.visualize_video(opt, iteration, generated_vae, 'Generated VAE')
                opt.summary.visualize_video(opt, iteration, fake_var, 'Fake var')
                opt.summary.visualize_video(opt, iteration, fake_vae_var, 'Fake VAE var')

    epoch_iterator.close()

    # Save data
    opt.saver.save_checkpoint({'data': opt.Noise_Amps}, 'Noise_Amps.pth')
    opt.saver.save_checkpoint(
        {
            'scale': opt.scale_idx,
            'state_dict': netG.state_dict(),
            'optimizer': optimizerG.state_dict(),
            'noise_amps': opt.Noise_Amps,
        }, 'netG.pth')
    if opt.vae_levels < opt.scale_idx + 1:
        opt.saver.save_checkpoint(
            {
                'scale': opt.scale_idx,
                'state_dict': D_curr.module.state_dict()
                if opt.device == 'cuda' else D_curr.state_dict(),
                'optimizer': optimizerD.state_dict(),
            }, 'netD_{}.pth'.format(opt.scale_idx))
def train(self, num_epoch):
    """Train an unconditional GAN with the relativistic average LSGAN objective.

    Per batch: one discriminator update, then one generator update (optionally
    on a freshly resampled batch). Periodically prints losses, writes sample
    grids (image or spectrogram depending on self.special), and snapshots
    network/optimizer state.
    """
    for epoch in range(num_epoch):
        resample_iter = iter(self.train_dl) if self.resample else None

        for step, batch in enumerate(tqdm(self.train_dl)):
            # (1) : minimizes mean((D(x) - mean(D(G(z))) - 1)**2) + mean((D(G(z)) - mean(D(x)) + 1)**2)
            self.netD.zero_grad()
            real_images = batch[0].to(self.device)
            batch_size = real_images.size(0)

            # Target vectors of shape (batch_size,)
            real_label = torch.full((batch_size, ), self.real_label, device=self.device)
            fake_label = torch.full((batch_size, ), self.fake_label, device=self.device)

            # Sample z -> fakes; score both batches with D.
            noise = generate_noise(batch_size, self.nz, self.device)
            fake_images = self.netG(noise)
            d_real = self.netD(real_images).view(-1)
            # detach() keeps the D step from backpropagating into G
            d_fake = self.netD(fake_images.detach()).view(-1)

            # Relativistic average least-squares discriminator loss.
            real_term = torch.mean((d_real - torch.mean(d_fake) - real_label)**2)
            fake_term = torch.mean((d_fake - torch.mean(d_real) + real_label)**2)
            errD = (real_term + fake_term) / 2.0
            errD.backward()
            self.optimizerD.step()

            # (2) : minimizes mean((D(G(z)) - mean(D(x)) - 1)**2) + mean((D(x) - mean(D(G(z))) + 1)**2)
            self.netG.zero_grad()
            if self.resample:
                # Score G against a fresh real batch and fresh noise.
                real_images = next(resample_iter)[0].to(self.device)
                noise = generate_noise(batch_size, self.nz, self.device)
                fake_images = self.netG(noise)

            # D changed after its update, so re-score both batches.
            d_real = self.netD(real_images).view(-1)
            d_fake = self.netD(fake_images).view(-1)

            # Same relativistic loss with the roles of real/fake swapped.
            fake_term = torch.mean((d_fake - torch.mean(d_real) - real_label)**2)
            real_term = torch.mean((d_real - torch.mean(d_fake) + real_label)**2)
            errG = (fake_term + real_term) / 2.0
            errG.backward()
            self.optimizerG.step()

            self.errD_records.append(float(errD))
            self.errG_records.append(float(errG))

            if step % self.loss_interval == 0:
                print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                      (epoch + 1, num_epoch, step + 1,
                       self.train_iteration_per_epoch, errD, errG))

            if step % self.image_interval == 0:
                sample_path = os.path.join(
                    self.save_img_dir,
                    str(self.save_cnt) + ' : ' + str(epoch) + '-' + str(step) + '.jpg')
                if self.special is None:
                    samples = get_sample_images_list(
                        'Unsupervised', (self.fixed_noise, self.netG))
                    fig = plot_multiple_images(samples, 4, 4)
                    self.save_cnt += 1
                    save_fig(sample_path, fig)
                    fig.clf()
                elif self.special == 'Wave':
                    samples = get_sample_images_list(
                        'Unsupervised_Audio', (self.fixed_noise, self.netG))
                    fig = plot_multiple_spectrograms(samples, 4, 4, freq=16000)
                    self.save_cnt += 1
                    save_fig(sample_path, fig)
                    fig.clf()

            if self.snapshot_interval is not None and step % self.snapshot_interval == 0:
                save(
                    os.path.join(
                        self.save_snapshot_dir,
                        'Epoch' + str(epoch) + '_' + str(step) + '.state'),
                    self.netD, self.netG, self.optimizerD, self.optimizerG)
def train(opt):
    """Train a SinGAN-style image pyramid with PaddlePaddle's (legacy) fluid
    dygraph API: one fresh Generator/Discriminator pair per pyramid scale.

    For each scale: WGAN-style critic updates (no gradient penalty — it is
    commented out), generator updates with an added MSE reconstruction term,
    then the trained pair and its outputs are stored as priors for the next
    (finer) scale. Checkpoints and sample dumps go to "<opt.out>/<scale>/".
    """
    images = creat_reals_pyramid(opt)
    place = fluid.CUDAPlace(0) if opt.use_gpu else fluid.CPUPlace()
    # Per-scale state carried to finer scales.
    priors = []        # last random sample per trained scale
    prior_recons = []  # last reconstruction sample per trained scale
    netD_arrs = []
    netG_arrs = []
    noiseamp_arrs = []
    opt.padd_size = 1
    for idx in range(0, len(images)):
        outdir = "%s/%d/" % (opt.out, idx)
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        with fluid.dygraph.guard():
            real = images[idx]
            in_s = np.zeros(shape=real.shape, dtype=np.float32)
            zero = fluid.layers.zeros(shape=[1], dtype='float32')
            #zero.stop_gradient = True
            one = fluid.layers.ones(shape=[1], dtype='float32')
            #one.stop_gradient = True
            # Reconstruction-loss weight as a fluid variable.
            alpha = to_variable(np.array([opt.alpha]).astype('float32'))
            # NOTE(review): both optimizers use opt.lr_d; opt.lr_g appears only
            # in the commented-out RMSProp lines — possibly unintended.
            optimizerG = fluid.optimizer.Adam(learning_rate=opt.lr_d,
                                              beta1=opt.beta1,
                                              beta2=0.999,
                                              name='net_GA')
            optimizerD = fluid.optimizer.Adam(learning_rate=opt.lr_d,
                                              beta1=opt.beta1,
                                              beta2=0.999,
                                              name='net_DA')
            backward_strategy = fluid.dygraph.BackwardStrategy()
            backward_strategy.sort_sum_gradient = True
            #optimizerD = fluid.optimizer.RMSPropOptimizer(learning_rate=opt.lr_d, name="opD")
            #optimizerG = fluid.optimizer.RMSPropOptimizer(learning_rate=opt.lr_g, name="opG")
            #fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(0,1))
            # Fresh networks for this scale.
            netD = Discriminator("DA", opt)
            netG = Generator("GA", opt)
            # fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(min=-0.01, max=0.01),param_list=[netD.parameters(),netG.parameters()])
            vreal = to_variable(real)
            for epoch in range(opt.niter):
                noise_epoch = generate_noise(real.shape, opt)
                # Select priors from the finest already-trained scale (the loop
                # leaves the last element's values in place).
                prev = in_s
                prev_rec = in_s
                opt.noise_amp = 1
                # NOTE(review): this inner `idx` shadows the outer scale index
                # until the next outer iteration rebinds it.
                for idx in range(len(netG_arrs)):
                    prev = priors[idx]
                    prev_rec = prior_recons[idx]
                    opt.noise_amp = noiseamp_arrs[idx]
                # Upsample priors to the current scale's resolution.
                prev = resize(prev, (real.shape[3], real.shape[2]))
                prev_rec = resize(prev_rec, (real.shape[3], real.shape[2]))
                vprev = to_variable(prev)
                vprev_rec = to_variable(prev_rec)
                # ---- Critic (D) updates ----
                for j in range(opt.Dsteps):
                    netD.clear_gradients()
                    outD_real = netD(vreal)
                    errD_real = fluid.layers.mean(outD_real)
                    errD_real = 0.0 - errD_real  # maximize D(real)
                    errD_real.backward(backward_strategy)
                    #errD_real = fluid.layers.elementwise_sub(zero, errD_real)
                    # errD_real.backward(backward_strategy)
                    noise = opt.noise_amp * noise_epoch + prev
                    vnoise = to_variable(noise)
                    outG_fake = netG(vnoise.detach(), vprev)
                    outD_fake = netD(outG_fake.detach())
                    errD_fake = fluid.layers.mean(outD_fake)  # minimize D(fake)
                    # errD_fake = fluid.layers.elementwise_sub(zero, errD_fake)
                    errD_fake.backward(backward_strategy)
                    #gradient_penalty = calc_gradient_penalty(netD, vreal, outG_fake, opt, backward_strategy)
                    #gradient_penalty.backward()
                    #errD = errD_real + errD_fake + gradient_penalty
                    errD = errD_real + errD_fake
                    # fluid-style manual step: compute grads for errD, apply.
                    params_d = optimizerD.backward(
                        errD, parameter_list=netD.parameters())
                    optimizerD.apply_gradients(params_d)
                # ---- Generator updates ----
                for j in range(opt.Gsteps):
                    netG.clear_gradients()
                    outD_fakeG = netD(outG_fake)
                    errD_fakeG = 0.0 - fluid.layers.mean(outD_fakeG)  # maximize D(G(z))
                    # errD_fakeG = fluid.layers.elementwise_add(zero, errD_fakeG)
                    errD_fakeG.backward(backward_strategy)
                    # Reconstruction pass driven by the reconstruction prior.
                    noise_fake = opt.noise_amp * noise_epoch + prev_rec
                    noise_fake = to_variable(noise_fake)
                    outG_fake_rec = netG(noise_fake.detach(), vprev_rec)
                    rec_loss = fluid.layers.mse_loss(vreal, outG_fake_rec)
                    # RMSE is kept (as a numpy value) to set the next scale's
                    # noise amplitude after training finishes.
                    RMSE = fluid.layers.sqrt(rec_loss).numpy()
                    rec_loss = fluid.layers.elementwise_mul(alpha, rec_loss)
                    rec_loss.backward(backward_strategy)
                    errG = rec_loss + errD_fakeG
                    #errG = errD_fakeG
                    params_g = optimizerG.backward(
                        errG, parameter_list=netG.parameters())
                    optimizerG.apply_gradients(params_g)
                #netD.clear_gradients()
                #netG.clear_gradients()
                if epoch % 25 == 0 or epoch == (opt.niter - 1):
                    print(
                        'shape %s [epoch:%d/%d][errD:%.5f][errG:%.5f][rec_loss:%.5f][noise_amp:%.5f][errD_real:%.5f][errD_fake:%.5f][outD_fakeG:%.5f]'
                        % (real.shape, epoch, opt.niter, errD.numpy(),
                           errG.numpy(), rec_loss.numpy(), opt.noise_amp,
                           errD_real.numpy(), errD_fake.numpy(),
                           errD_fakeG.numpy()))
                if epoch % 500 == 0 or epoch == (opt.niter - 1):
                    dump_img(
                        outG_fake.numpy(),
                        os.path.join(
                            outdir,
                            "fake_sample_%d_%s" % (epoch, opt.input_name)))
                    dump_img(
                        outG_fake_rec.numpy(),
                        os.path.join(
                            outdir,
                            "G(z_opt)_%d_%s" % (epoch, opt.input_name)))
                    fluid.dygraph.save_dygraph(netD.state_dict(),
                                               os.path.join(outdir, "DA"))
                    fluid.dygraph.save_dygraph(netG.state_dict(),
                                               os.path.join(outdir, "GA"))
                    fluid.dygraph.save_dygraph(optimizerD.state_dict(),
                                               os.path.join(outdir, "DA"))
                    fluid.dygraph.save_dygraph(optimizerG.state_dict(),
                                               os.path.join(outdir, "GA"))
            # Calibrate the next scale's noise amplitude from the final RMSE.
            opt.noise_amp = opt.noise_amp_init * RMSE
            # Freeze this scale and keep its outputs as priors for the next one.
            netD.eval()
            netD_arrs.append(netD)
            netG.eval()
            netG_arrs.append(netG)
            priors.append(outG_fake.numpy())
            prior_recons.append(outG_fake_rec.numpy())
            noiseamp_arrs.append(opt.noise_amp)