def training_step_discrim(discriminator, optimizer, images, generated_images):
    """Perform a single gradient update of the discriminator.

    Both image batches are passed through ``discriminator`` while a
    ``tf.GradientTape`` records the operations. The adversarial loss comes
    from ``losses.discriminator_loss`` (called with the fake predictions
    first, the real predictions second) and its gradients are applied to
    the discriminator's trainable variables.

    Args:
        discriminator: Callable model exposing ``trainable_variables``.
        optimizer: Optimizer whose ``apply_gradients`` performs the update.
        images: Batch of real images.
        generated_images: Batch of fake (generated) images.

    Returns:
        The discriminator loss computed for this step.
    """
    with tf.GradientTape() as disc_tape:
        # Real batch first, then the generated batch.
        scores_real = discriminator(images)
        scores_fake = discriminator(generated_images)
        # Adversarial loss of the discriminator.
        discr_loss = losses.discriminator_loss(scores_fake, scores_real)

    weights = discriminator.trainable_variables
    grads = disc_tape.gradient(discr_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))
    return discr_loss
def discriminator_train_step(self, batch):
    """Compute the discriminator loss and accuracy metrics for one batch.

    The batch is split in half along dimension 0. The first half is run
    through the generator and its last output is scored by the
    discriminator as "fake"; the second half's event volume is scored as
    "real". Real and fake logits are then randomly swapped element-wise
    (label flipping) before the hinge discriminator loss is evaluated.

    Args:
        batch: Mapping providing the entries "prev_image", "next_image",
            "prev_image_gt", "next_image_gt" and "event_volume".

    Returns:
        A tuple ``(d_loss, losses, outputs)``: the discriminator loss, a
        dict with the loss and the fake/real classification accuracies
        (computed on the un-flipped logits), and an empty dict.
    """
    gen_model = self.models_dict['gen']
    dis_model = self.models_dict['dis']

    # Previous and next frames are concatenated along dimension 1.
    dis_images = torch.cat([batch["prev_image"], batch["next_image"]], 1)
    dis_images_gt = torch.cat(
        [batch["prev_image_gt"], batch["next_image_gt"]], 1)
    dis_event_volume = batch["event_volume"]

    # First half of the discriminator data is for fake, second half for real.
    split_size = dis_images.shape[0] // 2
    dis_images_fake = dis_images[:split_size, ...]
    dis_images_fake_gt = dis_images_gt[:split_size, ...]
    dis_images_real_gt = dis_images_gt[split_size:, ...]
    event_volume_real = dis_event_volume[split_size:, ...]

    # Inference with the generator; only its last output is scored.
    dis_gen_fake_volume = gen_model(dis_images_fake)
    fake_logits = dis_model([dis_gen_fake_volume[-1]], dis_images_fake_gt)
    real_logits = dis_model([event_volume_real], dis_images_real_gt)

    # Label flipping: wherever rand_n is NOT greater than flip_label the real
    # and fake logits swap roles.  The mask is True where labels are kept.
    # torch.where needs a boolean condition (uint8 masks are rejected by
    # current PyTorch), so the NumPy bool array is converted directly.
    rand_n = np.random.random(real_logits.shape)
    keep_labels = torch.from_numpy(
        np.greater(rand_n, self.options.flip_label)).to(self.device)
    real_logits_maybe_flipped = torch.where(keep_labels, real_logits,
                                            fake_logits)
    fake_logits_maybe_flipped = torch.where(keep_labels, fake_logits,
                                            real_logits)

    d_loss = discriminator_loss("hinge", real_logits_maybe_flipped,
                                fake_logits_maybe_flipped)

    # sign(x) == -1 exactly when x < 0 and sign(x) == 1 exactly when x > 0,
    # so the accuracies are plain sign tests on the original logits.
    losses = {
        'discriminator': d_loss,
        'fake_class_acc': torch.mean((fake_logits < 0).float()),
        'real_class_acc': torch.mean((real_logits > 0).float()),
    }
    outputs = {}
    return d_loss, losses, outputs
def train(x, y):
    """Run one GAN iteration: ``config.num_D_steps`` D updates, then one G update.

    Args:
        x: Batch of real samples; split into chunks of ``config.batch_size``.
        y: Labels matching ``x``; split the same way.

    Returns:
        Dict with float entries 'G_loss', 'D_loss_real' and 'D_loss_fake'
        (the D components come from the last discriminator step).
    """
    G.optim.zero_grad()
    D.optim.zero_grad()

    batch_size = config.batch_size
    real_chunks = torch.split(x, batch_size)
    label_chunks = torch.split(y, batch_size)
    # The chunk index is never advanced (the increment is disabled in this
    # variant), so every discriminator step consumes the first chunk.
    chunk = 0

    # ---- discriminator phase: only D receives gradients ----
    toggle_grad(D, True)
    toggle_grad(G, False)
    for _ in range(config.num_D_steps):
        z_.sample_()
        y_.sample_()
        D_fake, D_real = GD(z_[:batch_size], y_[:batch_size],
                            real_chunks[chunk], label_chunks[chunk],
                            train_G=False)
        D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
        (D_loss_real + D_loss_fake).backward()
        D.optim.step()

    # ---- generator phase: only G receives gradients ----
    toggle_grad(D, False)
    toggle_grad(G, True)
    # Clear G's gradients before its own backward pass, for safety.
    G.optim.zero_grad()
    z_.sample_()
    y_.sample_()
    G_loss = losses.generator_loss(GD(z_, y_, train_G=True))
    G_loss.backward()
    G.optim.step()

    # G's loss and the components of D's loss.
    return {
        'G_loss': float(G_loss.item()),
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item()),
    }
def train(x, y):
    """Run one GAN iteration with optional Crammer-Singer / feature-matching losses.

    The discriminator is updated ``config['num_D_steps']`` times (each step
    accumulating ``config['num_D_accumulations']`` backward passes), then the
    generator is updated once after ``config['num_G_accumulations']`` passes.

    Loss selection, as read from ``config``:
      * ``mh_csc_loss`` or ``mh_loss``: D uses ``losses.crammer_singer_criterion``
        with real labels for real samples and the extra class index
        ``n_classes`` for fake samples; otherwise ``losses.discriminator_loss``.
      * ``fm_loss``: G minimises the feature-matching loss only.
      * ``mh_csc_loss``: G uses ``losses.crammer_singer_complement_criterion``.
      * ``mh_loss``: G uses a weighted sum of feature matching and
        ``losses.mh_loss``.
      * otherwise: ``losses.generator_loss``.

    Args:
        x: Batch of real samples; split into chunks of ``config['batch_size']``.
        y: Labels matching ``x``; split the same way.

    Returns:
        Dict with float entries 'G_loss', 'D_loss_real' and 'D_loss_fake'
        (values from the last accumulation pass of each phase).
    """
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # "Fake" class label used by the Crammer-Singer losses: index n_classes.
    # It does not depend on the step, so it is built once up front instead of
    # once per D step; this also keeps it defined for the G phase even when
    # num_D_steps is 0.
    lossy = torch.full((config['batch_size'],), config['n_classes'],
                       dtype=torch.long).cuda()

    def _feature_matching_loss():
        # Mean absolute difference between the dim-0 means of what GD returns
        # for fake and real inputs when called with feat=True.
        D_feat_fake, D_feat_real = GD(z_, y_, x[-1], None, train_G=True,
                                      split_D=config['split_D'], feat=True)
        return torch.mean(
            torch.abs(torch.mean(D_feat_fake, 0) - torch.mean(D_feat_real, 0)))

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an optimizer step
        D.optim.zero_grad()
        for accumulation_index in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake, D_real = GD(z_[:config['batch_size']],
                                y_[:config['batch_size']],
                                x[counter], y[counter],
                                train_G=False, split_D=config['split_D'])
            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            if config['mh_csc_loss'] or config['mh_loss']:
                D_loss_real = losses.crammer_singer_criterion(
                    D_real, y[counter])
                D_loss_fake = losses.crammer_singer_criterion(
                    D_fake, lossy[:config['batch_size']])
            else:
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake) / float(
                config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D.optim.step()

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config['num_G_accumulations']):
        # Without 'resampling', the first pass reuses the noise left over from
        # the last D pass; later passes always draw fresh noise.
        if config['resampling'] or (accumulation_index > 0):
            z_.sample_()
            y_.sample_()

        if config['fm_loss']:
            # NOTE(review): unlike the other branches this loss is not divided
            # by num_G_accumulations -- confirm that is intended.
            G_loss = _feature_matching_loss()
        else:
            D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
            if config['mh_csc_loss']:
                G_loss = losses.crammer_singer_complement_criterion(
                    D_fake, lossy[:config['batch_size']]) / float(
                        config['num_G_accumulations'])
            elif config['mh_loss']:
                fm_loss = _feature_matching_loss()
                oth_loss = losses.mh_loss(D_fake, y_[:config['batch_size']])
                G_loss = (config['mh_fmloss_weight'] * fm_loss +
                          config['mh_loss_weight'] * oth_loss) / float(
                              config['num_G_accumulations'])
            else:
                G_loss = losses.generator_loss(D_fake) / float(
                    config['num_G_accumulations'])
        G_loss.backward()

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in G
        print('using modified ortho reg in G')
        # Don't ortho reg shared, it makes no sense. Really we should
        # blacklist any embeddings for this
        utils.ortho(G, config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])
    G.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {
        'G_loss': float(G_loss.item()),
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item())
    }
    # Return G's loss and the components of D's loss.
    return out
def train(x, y, stage):
    """Run one training iteration for D, then for M (and G when ``stage == 2``).

    The discriminator ``D`` is updated ``config['num_D_steps']`` times. After
    that the loss ``losses.generator_loss(D_fake, M_regu)`` is back-propagated;
    ``M.optim`` always steps, ``G.optim`` steps only when ``stage == 2``.

    Args:
        x: Batch of real samples; split into chunks of ``config['batch_size']``.
        y: Labels matching ``x``; split the same way.
        stage: Training stage. With ``toggle_grads`` enabled, ``stage == 1``
            keeps G's gradients switched off during the second phase.

    Returns:
        Dict with float entries 'G_loss' (the M loss of the last pass),
        'D_loss_real' and 'D_loss_fake'.
    """
    for net in (G, D, M):
        net.optim.zero_grad()

    batch_size = config['batch_size']
    real_chunks = torch.split(x, batch_size)
    label_chunks = torch.split(y, batch_size)
    chunk = 0
    toggling = config['toggle_grads']

    # ---- discriminator phase ----
    if toggling:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)
        utils.toggle_grad(M, False)

    for _ in range(config['num_D_steps']):
        # Gradients are accumulated over several passes per optimizer step.
        D.optim.zero_grad()
        for _ in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake, D_real = GD(z_[:batch_size], y_[:batch_size],
                                real_chunks[chunk], label_chunks[chunk],
                                train_G=False, split_D=config['split_D'])
            # Average the loss components over the accumulation passes.
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake) / float(
                config['num_D_accumulations'])
            D_loss.backward()
            chunk += 1

        if config['D_ortho'] > 0.0:
            # Debug print to indicate ortho reg is active in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D.optim.step()

    # ---- M / G phase ----
    if toggling:
        utils.toggle_grad(D, False)
        # G receives gradients in every stage except stage 1.
        utils.toggle_grad(G, stage != 1)
        utils.toggle_grad(M, True)

    # Start the second phase from clean gradients.
    G.optim.zero_grad()
    M.optim.zero_grad()

    for _ in range(config['num_G_accumulations']):
        z_.sample_()
        y_.sample_()
        D_fake, M_regu = GD(z_, y_, train_G=True, split_D=config['split_D'],
                            train_M=True, M_regu=True)
        M_loss = losses.generator_loss(D_fake, M_regu) / float(
            config['num_G_accumulations'])
        M_loss.backward()

    if config['G_ortho'] > 0.0:
        # Debug print to indicate ortho reg is active in G.
        print('using modified ortho reg in G')
        # The shared parameters are excluded from ortho reg.
        shared_params = [param for param in G.shared.parameters()]
        utils.ortho(G, config['G_ortho'], blacklist=shared_params)

    if stage == 2:
        G.optim.step()
    M.optim.step()

    # Update the EMA whenever it is enabled, whether or not it is used at test time.
    if config['ema']:
        ema.update(state_dict['itr'])

    # The M loss is reported under the 'G_loss' key, plus D's loss components.
    return {
        'G_loss': float(M_loss.item()),
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item()),
    }
def train(x, y):
    """Run one GAN iteration with DiffAugment and optional consistency regularisation.

    The discriminator is updated ``config['num_D_steps']`` times (each step
    accumulating ``config['num_D_accumulations']`` backward passes). Unless
    ``config['fix_G']`` is set, the generator is then updated once after
    ``config['num_G_accumulations']`` passes.

    When ``config['CR'] > 0`` a consistency term -- the mean squared
    difference between D's scores on augmented and plain real samples, scaled
    by ``config['CR']`` -- is added to the discriminator loss.

    Args:
        x: Batch of real samples; split into chunks of ``config['batch_size']``.
        y: Labels matching ``x``; split the same way.

    Returns:
        Dict with float entries 'G_loss' (0 when G is fixed), 'D_loss_real',
        'D_loss_fake' and, when ``config['CR'] > 0``, 'D_loss_CR'.
    """
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an optimizer step
        D.optim.zero_grad()
        for accumulation_index in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            D_scores = GD(z_[:config['batch_size']],
                          y_[:config['batch_size']],
                          x[counter], y[counter], train_G=False,
                          policy=config['DiffAugment'],
                          CR=config['CR'] > 0,
                          CR_augment=config['CR_augment'])
            D_loss_CR = 0
            if config['CR'] > 0:
                # Previously this branch skipped the pass entirely, so D was
                # never trained with CR enabled and the metrics below were
                # undefined.
                # NOTE(review): assumes GD returns (fake, real, real_aug)
                # scores when CR=True -- confirm against GD's implementation.
                D_fake, D_real, D_real_aug = D_scores
                D_loss_CR = torch.mean(
                    (D_real_aug - D_real) ** 2) * config['CR']
            else:
                D_fake, D_real = D_scores

            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            D_loss_real, D_loss_fake = losses.discriminator_loss(
                D_fake, D_real)
            D_loss = D_loss_real + D_loss_fake + D_loss_CR
            D_loss = D_loss / float(config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D.optim.step()

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    if not config['fix_G']:
        # If accumulating gradients, loop multiple times
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake = GD(z_, y_, train_G=True, policy=config['DiffAugment'])
            G_loss = losses.generator_loss(D_fake) / float(
                config['num_G_accumulations'])
            G_loss.backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G
            print('using modified ortho reg in G')
            # Don't ortho reg shared, it makes no sense. Really we should
            # blacklist any embeddings for this
            utils.ortho(
                G, config['G_ortho'],
                blacklist=[param for param in G.shared.parameters()])
        G.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {
        'G_loss': float(G_loss.item()) if not config['fix_G'] else 0,
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item()),
    }
    if config['CR'] > 0:
        out['D_loss_CR'] = float(D_loss_CR.item())
    # Return G's loss and the components of D's loss.
    return out
def train(x_s, y, yd):
    """Run one GAN iteration with auxiliary-classifier ('AC' / 'Twin_AC') losses.

    The discriminator is updated ``config['num_D_steps']`` times, then the
    generator ``config['num_G_steps']`` times; each step accumulates
    gradients over the configured number of passes.  Two label sets are
    used throughout: ``y`` / ``y_`` and ``yd`` / ``yd_``.

    Args:
        x_s: Batch of real samples; split into chunks of ``config['batch_size']``.
        y: Labels for ``x_s``; cast to long and split the same way.
        yd: Second label set for ``x_s``; cast to long and split the same way.

    Returns:
        Dict with 'G_loss', 'D_loss_real', 'D_loss_fake' as floats and
        'C_loss', 'MI_loss', 'CD_loss', 'MID_loss' as left by the last
        generator pass (tensors, or the int 0 when the term was not used).
    """
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    y = y.long()
    yd = yd.long()
    x_s = torch.split(x_s, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    yd = torch.split(yd, config['batch_size'])
    counter = 0

    # Toggle D and G's "require_grad": only D is trained in this phase
    utils.toggle_grad(D, True)
    utils.toggle_grad(G, False)

    # NOTE(review): D's gradients are only zeroed once at the top of the
    # function, not per step -- with num_D_steps > 1 they carry over between
    # optimizer steps; confirm this is intended.
    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an optimizer step
        for accumulation_index in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            yd_.sample_()
            D_fake, D_real, mi, c_cls, mid, c_clsd, G_z = GD(
                z_, y_, yd_, x_s[counter], y[counter], yd[counter],
                train_G=False, split_D=config['split_D'], return_G_z=True)
            D_loss_real, D_loss_fake = losses.discriminator_loss(
                D_fake, D_real)
            C_loss = 0
            if config['AC']:
                # The classifier outputs are sliced at D_fake.shape[0]: the
                # leading rows are used with the sampled labels (fake part),
                # the remaining rows with the real labels.
                fake_mi = mi[:D_fake.shape[0]]
                fake_cls = c_cls[:D_fake.shape[0]]
                c_cls_rs = c_cls[D_fake.shape[0]:]
                fake_mid = mid[:D_fake.shape[0]]
                c_clsd = c_clsd[D_fake.shape[0]:]
                # print(yd)
                # print(yd_)
                if config['loss_type'] == 'Twin_AC':
                    # The real-sample class term only uses rows whose yd
                    # label is non-zero.
                    C_loss += F.cross_entropy(c_clsd, yd[counter]) + F.cross_entropy(fake_mid, yd_) + \
                        0.5*F.cross_entropy(c_cls_rs[yd[counter]!=0], y[counter][yd[counter]!=0]) + 0.5*F.cross_entropy(fake_cls, y_) + 1.0*F.cross_entropy(fake_mi, y_)
                    # if state_dict['itr'] > 0000:
                    #     C_loss += 0.2*F.cross_entropy(c_cls_ft, y_[yd_!=0]) + 0.2*F.cross_entropy(fake_mi_t[yd_!=0], y_[yd_!=0])#F.cross_entropy(fake_mi[yd_ == 0], y_[yd_ == 0])
                if config['loss_type'] == 'AC':
                    # NOTE(review): `c_cls_fs` and `y_f_s` are not defined in
                    # this function, and `yd` here is the tuple of chunks
                    # rather than `yd[counter]` -- this branch fails unless
                    # those names exist in an enclosing scope; confirm.
                    C_loss += F.cross_entropy(
                        c_cls_fs, y_f_s) + F.cross_entropy(c_clsd, yd)

            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            if config['Pac']:
                # NOTE(review): `x_t` is not defined in this function --
                # presumably a second data source from an enclosing scope;
                # confirm before enabling 'Pac'.
                x_pack = torch.cat([x_s[counter], x_t[counter]], dim=0)
                # Fold groups of 4 samples into the channel dimension.
                T_img = x_pack.view(-1, 4 * x_pack.size()[1],
                                    x_pack.size()[2], x_pack.size()[3])
                F_img = G_z.view(-1, 4 * G_z.size()[1],
                                 G_z.size()[2], G_z.size()[3])
                pack_img = torch.cat([T_img, F_img], dim=0)
                pack_out, _, _ = D(pack_img, pack=True)
                D_real_pac = pack_out[:T_img.size()[0]]
                D_fake_pac = pack_out[T_img.size()[0]:]
                D_loss_real_pac, D_loss_fake_pac = losses.discriminator_loss(
                    D_fake_pac, D_real_pac)
                D_loss_real += D_loss_real_pac
                D_loss_fake += D_loss_fake_pac

            D_loss = (D_loss_real + D_loss_fake +
                      C_loss * config['AC_weight']) / float(
                          config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D.optim.step()

    # Toggle "requires_grad": only G is trained from here on
    utils.toggle_grad(D, False)
    utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    for step_index in range(config['num_G_steps']):
        for accumulation_index in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            yd_.sample_()
            D_fake, mi, cls, mid, clsd, G_z = GD(z_, y_, yd_,
                                                 train_G=True,
                                                 split_D=config['split_D'],
                                                 return_G_z=True)
            C_loss = 0
            MI_loss = 0
            CD_loss = 0
            MID_loss = 0
            G_loss = losses.generator_loss(D_fake)
            if config['loss_type'] == 'AC' or config[
                    'loss_type'] == 'Twin_AC':
                C_loss = 1.0 * F.cross_entropy(
                    cls, y_)  #+ 0.5*F.cross_entropy(cls[yd_!=0], y_[yd_!=0])
                CD_loss = F.cross_entropy(clsd, yd_)
                if config['loss_type'] == 'Twin_AC':
                    MI_loss = 1.0 * F.cross_entropy(mi, y_)
                    # if state_dict['itr'] > 0000:
                    #     MI_loss += 0.5*F.cross_entropy(mi_t[yd_!=0], y_[yd_!=0])
                    MID_loss = F.cross_entropy(mid, yd_)

            if config['Pac']:
                # Same 4-sample packing as in the D phase, fake samples only.
                F_img = G_z.view(-1, 4 * G_z.size()[1],
                                 G_z.size()[2], G_z.size()[3])
                D_fake_pac, _, _ = D(F_img, pack=True)
                G_loss_pac = losses.generator_loss(D_fake_pac)
                G_loss += G_loss_pac

            G_loss = G_loss / float(config['num_G_accumulations'])
            C_loss = C_loss / float(config['num_G_accumulations'])
            MI_loss = MI_loss / float(config['num_G_accumulations'])
            CD_loss = CD_loss / float(config['num_G_accumulations'])
            MID_loss = MID_loss / float(config['num_G_accumulations'])
            # The classifier terms are added and the MI terms subtracted,
            # all scaled by AC_weight.
            (G_loss + (C_loss - MI_loss + CD_loss - MID_loss) *
             config['AC_weight']).backward()

        # Optionally apply modified ortho reg in G
        if config['G_ortho'] > 0.0:
            print('using modified ortho reg in G'
                  )  # Debug print to indicate we're using ortho reg in G
            # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
            utils.ortho(
                G, config['G_ortho'],
                blacklist=[param for param in G.shared.parameters()])
        G.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {
        'G_loss': float(G_loss.item()),
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item()),
        'C_loss': C_loss,
        'MI_loss': MI_loss,
        'CD_loss': CD_loss,
        'MID_loss': MID_loss
    }
    # Return G's loss and the components of D's loss.
    return out
def build_model(self):
    """Assemble the training graph: networks, losses, train ops and summaries.

    Creates the attributes ``d_adv_loss``, ``g_adv_loss``, ``d_loss``,
    ``rec_loss``, ``g_loss`` (and ``perceptual_loss`` when enabled), the
    train ops ``d_opt``, ``g_opt`` and ``g_rec_opt``, plus ``merged``,
    ``saver``, ``best_saver`` and ``writer``.
    """
    self.build_data_loader()
    x_lr, x_hr = self.inputs

    g_fake = self.generator(x_lr)

    # PatchGAN-wise
    d_fake = self.discriminator(g_fake)
    d_real = self.discriminator(x_hr, reuse=True)

    # ---- losses ----
    self.d_adv_loss = discriminator_loss(self.gan_type, d_real, d_fake,
                                         use_ra=self.use_ra)
    self.g_adv_loss = generator_loss(self.gan_type, d_real, d_fake,
                                     use_ra=self.use_ra)

    # The gradient penalty only applies to the wgan-family and dragan types.
    needs_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    gp = self.gradient_penalty(real=x_hr, fake=g_fake) if needs_penalty else 0.
    self.d_loss = self.d_adv_loss + gp

    # Reconstruction term: mean absolute difference to the HR target.
    self.rec_loss = tf.reduce_mean(tf.abs(g_fake - x_hr))
    self.g_loss = (self.weight_adv_loss * self.g_adv_loss
                   + self.weight_rec_loss * self.rec_loss)

    if self.use_perceptual_loss:
        # Both images are resized to 224x224 before going through VGG19.
        vgg_size = (224, 224)
        x_real = tf.image.resize_images(x_hr, size=vgg_size, align_corners=False)
        x_fake = tf.image.resize_images(g_fake, size=vgg_size, align_corners=False)

        vgg19_real = self.build_vgg19_model(x_real)
        vgg19_fake = self.build_vgg19_model(x_fake, reuse=True)

        self.perceptual_loss = tf.reduce_mean(tf.square(vgg19_real - vgg19_fake))
        self.g_loss += self.weight_perceptual_loss * self.perceptual_loss

    # ---- optimizers ----
    def _clipped_adam_op(loss, learning_rate, var_list):
        # Adam step on `loss` w.r.t. `var_list`, each gradient clipped by norm.
        optimizer = tf.train.AdamOptimizer(learning_rate,
                                           beta1=self.beta1, beta2=self.beta2)
        grads, variables = zip(*optimizer.compute_gradients(loss, var_list=var_list))
        clipped = [tf.clip_by_norm(grad, self.grad_clip_norm) for grad in grads]
        return optimizer.apply_gradients(zip(clipped, variables))

    trainables = tf.trainable_variables()
    d_vars = [var for var in trainables if "discriminator" in var.name]
    g_vars = [var for var in trainables if "generator" in var.name]

    # All train ops wait for the collected UPDATE_OPS.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        self.d_opt = _clipped_adam_op(self.d_loss, self.d_lr, d_vars)
        self.g_opt = _clipped_adam_op(self.g_loss, self.g_lr, g_vars)
        # Generator op driven by the reconstruction loss alone.
        self.g_rec_opt = _clipped_adam_op(self.rec_loss, self.g_lr, g_vars)

    # ---- summaries ----
    scalar_summaries = [
        ("loss/d_adv_loss", self.d_adv_loss),
        ("loss/g_adv_loss", self.g_adv_loss),
        ("loss/rec_loss", self.rec_loss),
        ("loss/g_loss", self.g_loss),
    ]
    if self.use_perceptual_loss:
        scalar_summaries.append(("loss/perceptual_loss", self.perceptual_loss))
    for tag, value in scalar_summaries:
        tf.summary.scalar(tag, value)

    for tag, image in (("real/x_lr", x_lr),
                       ("real/x_hr", x_hr),
                       ("fake/gen", g_fake)):
        tf.summary.image(tag, image, max_outputs=1)

    self.merged = tf.summary.merge_all()

    self.saver = tf.train.Saver(max_to_keep=5)
    self.best_saver = tf.train.Saver(max_to_keep=1)
    self.writer = tf.summary.FileWriter(self.checkpoint_dir, self.sess.graph)
def train(x, y):
    """Run one GAN iteration that also scores nine augmented views of each batch.

    ``GD`` returns scores for the plain batch plus three rotations
    (90/180/270), four crops (tl/tr/bl/br), a translation and a cutout.
    Each group's losses are summed together with the plain loss and added
    with weight ``lambda/4`` (rotations), ``lambda/5`` (crops) and
    ``lambda/2`` (translation and cutout each), using ``config['lambda_D']``
    for the discriminator and ``config['lambda_G']`` for the generator.

    Args:
        x: Batch of real samples; split into chunks of ``config['batch_size']``.
        y: Labels matching ``x``; split the same way.

    Returns:
        Dict with float entries 'G_loss' (0 when ``config['fix_G']``),
        'D_loss_real', 'D_loss_fake' (plain view only) and, when
        ``config['CR'] > 0``, 'D_loss_CR'.
    """
    G.optim.zero_grad()
    D.optim.zero_grad()

    batch_size = config['batch_size']
    real_chunks = torch.split(x, batch_size)
    label_chunks = torch.split(y, batch_size)
    chunk = 0
    lambda_D = config['lambda_D']
    lambda_G = config['lambda_G']

    def _real_fake_losses(score_pair):
        # (fake scores, real scores) -> (real loss, fake loss)
        fake_scores, real_scores = score_pair
        return losses.discriminator_loss(fake_scores, real_scores)

    # ---- discriminator phase ----
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for _ in range(config['num_D_steps']):
        # Gradients are accumulated over several passes per optimizer step.
        D.optim.zero_grad()
        for _ in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            (D_scores,
             rot90, rot180, rot270,
             crop_tl, crop_tr, crop_bl, crop_br,
             shifted, cut) = GD(z_[:batch_size], y_[:batch_size],
                                real_chunks[chunk], label_chunks[chunk],
                                train_G=False,
                                policy=config['DiffAugment'],
                                CR=config['CR'] > 0,
                                CR_augment=config['CR_augment'])

            # Consistency term on the plain view only.
            D_loss_CR = 0
            if config['CR'] > 0:
                D_fake, D_real, D_real_aug = D_scores
                D_loss_CR = torch.mean(
                    (D_real_aug - D_real)**2) * config['CR']
            else:
                D_fake, D_real = D_scores

            # (real, fake) loss components for the plain view and every
            # augmented view, evaluated in the order GD returned them.
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            rotation_parts = [_real_fake_losses(pair)
                              for pair in (rot90, rot180, rot270)]
            crop_parts = [_real_fake_losses(pair)
                          for pair in (crop_tl, crop_tr, crop_bl, crop_br)]
            shift_real, shift_fake = _real_fake_losses(shifted)
            cut_real, cut_fake = _real_fake_losses(cut)

            plain_loss = D_loss_real + D_loss_fake + D_loss_CR
            rotation_losses = [real + fake for real, fake in rotation_parts]
            crop_losses = [real + fake for real, fake in crop_parts]
            shift_loss = shift_real + shift_fake
            cut_loss = cut_real + cut_fake

            # sum(..., plain_loss) adds left to right starting from the plain loss.
            D_loss = (plain_loss
                      + lambda_D / 4 * sum(rotation_losses, plain_loss)
                      + lambda_D / 5 * sum(crop_losses, plain_loss)
                      + lambda_D / 2 * (plain_loss + shift_loss)
                      + lambda_D / 2 * (plain_loss + cut_loss))
            D_loss = D_loss / float(config['num_D_accumulations'])
            D_loss.backward(retain_graph=True)
            chunk += 1

        if config['D_ortho'] > 0.0:
            # Debug print to indicate ortho reg is active in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D.optim.step()

    # ---- generator phase ----
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Clear G's gradients before its own backward passes, for safety.
    G.optim.zero_grad()

    if not config['fix_G']:
        for _ in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            (D_fake,
             rot90, rot180, rot270,
             crop_tl, crop_tr, crop_bl, crop_br,
             shifted, cut) = GD(z_, y_, train_G=True,
                                policy=config['DiffAugment'])

            scale = float(config['num_G_accumulations'])
            g_plain = losses.generator_loss(D_fake) / scale
            g_rotations = [losses.generator_loss(scores) / scale
                           for scores in (rot90, rot180, rot270)]
            g_crops = [losses.generator_loss(scores) / scale
                       for scores in (crop_tl, crop_tr, crop_bl, crop_br)]
            g_shift = losses.generator_loss(shifted) / scale
            g_cut = losses.generator_loss(cut) / scale

            G_loss = (g_plain
                      + lambda_G / 4. * sum(g_rotations, g_plain)
                      + lambda_G / 5. * sum(g_crops, g_plain)
                      + lambda_G / 2. * (g_plain + g_shift)
                      + lambda_G / 2. * (g_plain + g_cut))
            G_loss.backward()

        if config['G_ortho'] > 0.0:
            # Debug print to indicate ortho reg is active in G.
            print('using modified ortho reg in G')
            # The shared parameters are excluded from ortho reg.
            shared_params = [param for param in G.shared.parameters()]
            utils.ortho(G, config['G_ortho'], blacklist=shared_params)
        G.optim.step()

    # Update the EMA whenever it is enabled, whether or not it is used at test time.
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {
        'G_loss': float(G_loss.item()) if not config['fix_G'] else 0,
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item()),
    }
    if config['CR'] > 0:
        out['D_loss_CR'] = float(D_loss_CR.item())
    # G's loss and the components of D's loss.
    return out
def train(data, epochs, batch_size=1, gen_lr=5e-6, disc_lr=5e-7, epoch_offset=0):
    """Train the generator/discriminator pair and write per-epoch artefacts.

    Existing 'generator.h5' / 'discriminator.h5' weights in the checkpoint
    directory are loaded before training. After every epoch the function
    writes 'history.csv', generates spectrogram images and audio for each
    target of the training generator, saves both networks' weights and
    calls ``on_epoch_end`` on the training generator.

    Args:
        data: Mapping whose 'training' entry is the batch generator; it is
            indexed per batch and exposes ``origin``, ``target`` and
            ``on_epoch_end``.
        epochs: Number of epochs to run.
        batch_size: Only reported in the start-up banner.
        gen_lr: Learning rate of the generator's Adam optimizer.
        disc_lr: Learning rate of the discriminator's Adam optimizer.
        epoch_offset: Added to the epoch number in the log line and in the
            per-epoch output directory name.
    """
    train_set = data['training']

    generator = Generator(input_shape=[None, None, 2])
    discriminator = Discriminator(input_shape=[None, None, 1])
    generator_optimizer = tf.keras.optimizers.Adam(gen_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(disc_lr)

    model_name = train_set.origin + '_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    generator_weights = os.path.join(checkpoint_prefix, 'generator.h5')
    discriminator_weights = os.path.join(checkpoint_prefix, 'discriminator.h5')

    if os.path.isdir(checkpoint_prefix):
        # Resume from whatever weights are already there.
        if os.path.isfile(generator_weights):
            generator.load_weights(generator_weights, by_name=True)
            print('Generator weights restorred from ' + checkpoint_prefix)
        if os.path.isfile(discriminator_weights):
            discriminator.load_weights(discriminator_weights, by_name=True)
            print('Discriminator weights restorred from ' + checkpoint_prefix)
    else:
        os.makedirs(checkpoint_prefix)

    # Number of batches in the training set
    epoch_size = len(train_set)

    print()
    print("Started training with the following parameters: ")
    for label, value in (("\tCheckpoints: \t", checkpoint_prefix),
                         ("\tEpochs: \t", epochs),
                         ("\tgen_lr: \t", gen_lr),
                         ("\tdisc_lr: \t", disc_lr),
                         ("\tBatchSize: \t", batch_size),
                         ("\tnBatches: \t", epoch_size)):
        print(label, value)
    print()

    def _load_test_slices(audio_name):
        # Load a test audio, convert it to dB magnitude slices and rescale
        # them with x * 2 - 1; the phase is returned alongside.
        audio = load_audio(os.path.join(TEST_AUDIOS_PATH, audio_name + '.wav'))
        magnitude, audio_phase = forward_transform(audio)
        magnitude = amplitude_to_db(magnitude)
        slices = slice_magnitude(magnitude, magnitude.shape[0])
        return (slices * 2) - 1, audio_phase

    # Precompute the test input and targets used for the per-epoch outputs.
    test_input, phase = _load_test_slices(train_set.origin)
    test_inputs, test_targets = [], []
    for target_name in train_set.target:
        test_target, _ = _load_test_slices(target_name)
        # Target slices are shuffled along the first axis before being
        # stacked onto the input as an extra channel.
        shuffled = test_target[np.random.permutation(test_target.shape[0]), :, :, :]
        test_inputs.append(np.concatenate([test_input, shuffled], axis=3))
        test_targets.append(test_target)

    # One training column and one validation column per metric, in this order.
    metric_names = ('gen_mae', 'gen_loss', 'disc_loss')
    history_columns = {}
    for name in metric_names:
        history_columns[name] = []
        history_columns[name + '_val'] = []

    for epoch in range(epochs):
        epoch_number = (epoch + 1) + epoch_offset
        running = {name: 0 for name in metric_names}
        # Validation totals are never updated in this loop and stay at 0.
        running_val = {name: 0 for name in metric_names}

        print('Epoch {}/{}'.format(epoch_number, epochs + epoch_offset))
        progbar = tf.keras.utils.Progbar(epoch_size)

        for i in range(epoch_size):
            # Get the data from the DataGenerator
            input_image, target = train_set[i]
            first_channel = input_image[:, :, :, 0:1]

            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                # Generate a fake image
                gen_output = generator(input_image, training=True)

                # Score the real pair and the generated pair
                disc_real_output = discriminator([first_channel, target], training=True)
                disc_generated_output = discriminator([first_channel, gen_output], training=True)

                # Compute the losses
                gen_mae = l1_loss(target, gen_output)
                gen_loss = generator_loss(disc_generated_output, gen_mae)
                disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

            # Compute, then apply, the gradients of both networks
            generator_gradients = gen_tape.gradient(
                gen_loss, generator.trainable_variables)
            discriminator_gradients = disc_tape.gradient(
                disc_loss, discriminator.trainable_variables)
            generator_optimizer.apply_gradients(
                zip(generator_gradients, generator.trainable_variables))
            discriminator_optimizer.apply_gradients(
                zip(discriminator_gradients, discriminator.trainable_variables))

            # Update the running totals and the progress bar
            step_values = [
                ("gen_mae", gen_mae.numpy()),
                ("gen_loss", gen_loss.numpy()),
                ("disc_loss", disc_loss.numpy()),
            ]
            for name, value in step_values:
                running[name] += value
            progbar.add(1, values=step_values)

        for name in metric_names:
            history_columns[name].append(running[name] / epoch_size)
            history_columns[name + '_val'].append(running_val[name] / epoch_size)

        history = pd.DataFrame(history_columns)
        write_csv(history, os.path.join(checkpoint_prefix, 'history.csv'))

        epoch_output = os.path.join(OUTPUT_PATH, model_name,
                                    str(epoch_number).zfill(3))
        init_directory(epoch_output)

        # Generate audios and save spectrograms for the entire audios
        for j, target_name in enumerate(train_set.target):
            prediction = generator(test_inputs[j], training=False)
            prediction = (prediction + 1) / 2
            generate_images(prediction,
                            (test_inputs[j] + 1) / 2,
                            (test_targets[j] + 1) / 2,
                            os.path.join(epoch_output, 'spectrogram_' + target_name))
            generate_audio(prediction, phase,
                           os.path.join(epoch_output, 'audio_' + target_name + '.wav'))
        print('Epoch outputs saved in ' + epoch_output)

        # Save the weights
        generator.save_weights(generator_weights)
        discriminator.save_weights(discriminator_weights)
        print('Weights saved in ' + checkpoint_prefix)

        # Callback at the end of the epoch for the DataGenerator
        train_set.on_epoch_end()
def create_train_op(input, labels, params):
    """Build the TF1 graph-mode training op for a BigGAN generator/discriminator.

    A single Adam optimizer minimizes the sum of the generator loss and both
    discriminator loss terms over every trainable variable.

    Args:
        input: Mapping whose ``'reals'`` entry unpacks into
            ``(reals, reals_class_id)``; ``reals`` must have a rank-4
            static shape.
        labels: Must be ``None``; class ids arrive through ``input``.
        params: Mapping; ``params['use_tpu']`` selects wrapping the optimizer
            in ``tf.tpu.CrossShardOptimizer``.

    Returns:
        Tuple ``(train_op, loss)`` where ``loss`` is
        ``G_loss + D_loss_real + D_loss_fake``.

    Raises:
        ValueError: If a trainable variable's name does not start with
            ``Generator/``, ``Discriminator/`` or ``linear/w:``.
    """
    assert labels is None
    reals, reals_class_id = input['reals']
    pp(['input', input])
    pp(['reals', reals])
    pp(['reals_class_id', reals_class_id])
    pp(['params', params])

    mdl = BigGAN.GAN()
    BigGAN.instance = mdl
    dim_z = mdl.gan.generator.dim_z
    nclasses = mdl.gan.discriminator.n_class
    # Unpacking into four names also enforces that `reals` is rank 4.
    batch_size, _, _, _ = reals.shape.as_list()
    fakes_z, fakes_class_id = utils.prepare_z_y(
        G_batch_size=batch_size, dim_z=dim_z, nclasses=nclasses)
    reals_y = tf.one_hot(reals_class_id, nclasses)
    fakes_y = tf.one_hot(fakes_class_id, nclasses)
    fakes = mdl.gan.generator(fakes_z, fakes_y)
    reals_D = mdl.gan.discriminator(reals, reals_y)
    fakes_D = mdl.gan.discriminator(fakes, fakes_y)
    global_step = tflex.get_or_create_global_step()

    # Every trainable variable must belong to G, D, or the shared embedding.
    T_vars = tf.trainable_variables()
    known_prefixes = ('Generator/', 'Discriminator/', 'linear/w:')
    leftover_vars = [v for v in T_vars
                     if not v.name.startswith(known_prefixes)]
    if leftover_vars:
        # Log the offenders instead of dropping into an interactive debugger,
        # which would hang an unattended training job.
        tf.logging.error("Unexpected trainable variables: %s",
                         [v.name for v in leftover_vars])
        raise ValueError("Unexpected trainable variables")

    def should_train_variable(v):
        # Hook for freezing variables; currently everything is trained.
        return True

    def paramcount(vs):
        return sum([np.prod(v.shape.as_list()) for v in vs])

    def logvars(variables, label, print_variables=False):
        if print_variables:
            tf.logging.info("%s (%s parameters): %s", label,
                            paramcount(variables), pps(variables))
        else:
            tf.logging.info("%s (%s parameters)", label,
                            paramcount(variables))
        return variables

    train_vars = [v for v in tf.trainable_variables()
                  if should_train_variable(v)]
    non_train_vars = [v for v in tf.trainable_variables()
                      if not should_train_variable(v)]
    # Variable names are unique within a graph, so a name lookup is
    # equivalent to scanning both lists.
    trainable_names = {v.name for v in train_vars}
    trainable_names.update(v.name for v in non_train_vars)
    other_vars = [v for v in tf.global_variables()
                  if v.name not in trainable_names]
    local_vars = list(tf.local_variables())

    tf.logging.info(
        "Training %d parameters (%.2fM) out of %d parameters (%.2fM)" % (
            paramcount(train_vars),
            paramcount(train_vars) / (1024.0 * 1024.0),
            paramcount(tf.trainable_variables()),
            paramcount(tf.trainable_variables()) / (1024.0 * 1024.0),
        ))
    tf.logging.info("---------")
    tf.logging.info("Variable details:")
    logvars(train_vars, "trainable variables", print_variables=True)
    logvars(non_train_vars, "non-trainable variables", print_variables=True)
    logvars(other_vars, "other global variables", print_variables=True)
    logvars(local_vars, "other local variables", print_variables=True)
    tf.logging.info("---------")
    tf.logging.info("Variable summary:")
    logvars(train_vars, "trainable variables")
    logvars(non_train_vars, "non-trainable variables")
    logvars(other_vars, "other global variables")
    logvars(local_vars, "other local variables")

    G_loss = losses.generator_loss(fakes_D)
    D_loss_real, D_loss_fake = losses.discriminator_loss(reals_D, fakes_D)
    D_loss = D_loss_real + D_loss_fake
    loss = G_loss + D_loss

    optimizer = tf.train.AdamOptimizer()
    if params['use_tpu']:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # Run pending update ops (e.g. batchnorm statistics) with each step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    pp(['tf.GraphKeys.UPDATE_OPS', update_ops])
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, var_list=T_vars,
                                      global_step=global_step)
    return train_op, loss
def train(x, y):
    """Run one GAN training iteration: D updates first, then G updates.

    Args:
        x: Batch of real samples; split into chunks of ``config.batch_size``,
            one chunk consumed per discriminator step.
        y: Labels for ``x``, split the same way.

    Returns:
        Dict with the loss terms of the most recent step: ``'G_loss'``,
        ``'D_loss_real'``, ``'D_loss_fake'``, ``'C_loss'`` and ``'MI_loss'``
        (the last two stay ``0`` when the loss type does not use them).
    """
    G.module.optim.zero_grad()
    D.module.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config.batch_size)
    y = torch.split(y, config.batch_size)
    counter = 0

    # Discriminator phase: only D receives gradients.
    toggle_grad(D, True)
    toggle_grad(G, False)
    for _ in range(config.num_D_steps):
        # Clear gradients left by the previous optimizer step so they are
        # not applied a second time.
        D.module.optim.zero_grad()
        z_.sample_()
        y_.sample_()
        fake_z = z_[:config.batch_size]
        fake_y = y_[:config.batch_size]
        D_fake, D_real, mi, c_cls = GD(fake_z, fake_y,
                                       x[counter], y[counter],
                                       train_G=False)
        D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
        D_loss = D_loss_real + D_loss_fake
        n_fake = D_fake.shape[0]
        if config.loss_type == 'Twin_AC':
            # Rows [n_fake:] of c_cls are scored against the real labels and
            # rows [:n_fake] of mi against the labels the fakes were
            # generated from.
            D_loss = D_loss + config.C_w * (
                F.cross_entropy(c_cls[n_fake:], y[counter])
                + F.cross_entropy(mi[:n_fake], fake_y))
        elif config.loss_type == 'AC':
            D_loss = D_loss + config.C_w * F.cross_entropy(
                c_cls[n_fake:], y[counter])
        D_loss.backward()
        counter += 1
        D.module.optim.step()

    # Generator phase: only G receives gradients.
    toggle_grad(D, False)
    toggle_grad(G, True)
    # Zero G's gradients by default before training G, for safety.
    G.module.optim.zero_grad()
    for _ in range(config.num_G_steps):
        # Same reasoning as for D: each step starts from clean gradients.
        G.module.optim.zero_grad()
        z_.sample_()
        y_.sample_()
        fake_y = y_[:config.batch_size]
        D_fake, mi, c_cls = GD(z_[:config.batch_size], fake_y, train_G=True)
        G_loss = losses.generator_loss(D_fake)
        C_loss = 0
        MI_loss = 0
        if config.loss_type == 'Twin_AC':
            MI_loss = F.cross_entropy(mi, fake_y)
            C_loss = F.cross_entropy(c_cls, fake_y)
            total_loss = G_loss - config.C_w * MI_loss + config.C_w * C_loss
        elif config.loss_type == 'AC':
            C_loss = F.cross_entropy(c_cls, fake_y)
            total_loss = G_loss + config.C_w * C_loss
        else:
            total_loss = G_loss
        total_loss.backward()
        G.module.optim.step()

    out = {
        'G_loss': G_loss,
        'D_loss_real': D_loss_real,
        'D_loss_fake': D_loss_fake,
        'C_loss': C_loss,
        'MI_loss': MI_loss,
    }
    # Return G's loss and the components of D's loss.
    return out
def train(x, y):
    """Run one GAN training iteration with gradient accumulation.

    Performs ``config['num_D_steps']`` discriminator updates followed by
    ``config['num_G_steps']`` generator updates. Each update accumulates
    gradients over the configured number of accumulations.

    Args:
        x: Batch of real samples; split into chunks of
            ``config['batch_size']``, one chunk per D accumulation.
        y: Labels for ``x``, split the same way.

    Returns:
        Dict with ``'G_loss'``, ``'D_loss_real'`` and ``'D_loss_fake'`` as
        Python floats, plus ``'C_loss'`` and ``'MI_loss'`` from the last
        generator accumulation (already divided by the accumulation count).
    """
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Discriminator phase: only D receives gradients.
    utils.toggle_grad(D, True)
    utils.toggle_grad(G, False)
    for _ in range(config['num_D_steps']):
        # Gradients must accumulate only within one optimizer step, so
        # clear what the previous step left behind.
        D.optim.zero_grad()
        for _ in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            fake_y = y_[:config['batch_size']]
            D_fake, D_real, mi, c_cls = GD(z_[:config['batch_size']], fake_y,
                                           x[counter], y[counter],
                                           train_G=False,
                                           split_D=config['split_D'])
            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations.
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            n_fake = D_fake.shape[0]
            C_loss = 0
            # Auxiliary terms: c_cls rows [n_fake:] against the real labels,
            # mi rows [:n_fake] against the labels used to generate fakes.
            if config['loss_type'] == 'Twin_AC':
                C_loss += (F.cross_entropy(c_cls[n_fake:], y[counter])
                           + F.cross_entropy(mi[:n_fake], fake_y))
            if config['loss_type'] == 'Twin_AC_M':
                C_loss += (hinge_multi(c_cls[n_fake:], y[counter])
                           + hinge_multi(mi[:n_fake], fake_y))
            if config['loss_type'] == 'AC':
                C_loss += F.cross_entropy(c_cls[n_fake:], y[counter])
            D_loss = (D_loss_real + D_loss_fake
                      + C_loss * config['AC_weight']) / float(
                          config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D.
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])
        D.optim.step()

    # Generator phase: only G receives gradients.
    utils.toggle_grad(D, False)
    utils.toggle_grad(G, True)
    # Zero G's gradients by default before training G, for safety.
    G.optim.zero_grad()
    for _ in range(config['num_G_steps']):
        # Same reasoning as for D: each step starts from clean gradients.
        G.optim.zero_grad()
        for _ in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake, G_z, mi, c_cls = GD(z_, y_, train_G=True,
                                        split_D=config['split_D'],
                                        return_G_z=True)
            C_loss = 0
            MI_loss = 0
            if config['loss_type'] in ('AC', 'Twin_AC'):
                C_loss = F.cross_entropy(c_cls, y_)
            if config['loss_type'] == 'Twin_AC':
                MI_loss = F.cross_entropy(mi, y_)
            if config['loss_type'] == 'Twin_AC_M':
                C_loss = hinge_multi(c_cls, y_, hinge=False)
                MI_loss = hinge_multi(mi, y_, hinge=False)
            n_acc = float(config['num_G_accumulations'])
            G_loss = losses.generator_loss(D_fake) / n_acc
            C_loss = C_loss / n_acc
            MI_loss = MI_loss / n_acc
            (G_loss + (C_loss - MI_loss) * config['AC_weight']).backward()

        # Optionally apply modified ortho reg in G.
        if config['G_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in G.
            print('using modified ortho reg in G')
            # Don't ortho reg shared, it makes no sense. Really we should
            # blacklist any embeddings for this.
            utils.ortho(G, config['G_ortho'],
                        blacklist=[param for param in G.shared.parameters()])
        G.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not.
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {'G_loss': float(G_loss.item()),
           'D_loss_real': float(D_loss_real.item()),
           'D_loss_fake': float(D_loss_fake.item()),
           'C_loss': C_loss,
           'MI_loss': MI_loss}
    # Return G's loss and the components of D's loss.
    return out
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
    """Train the generator/discriminator pair for one epoch under AMP.

    Every batch does a discriminator update followed by a generator update,
    both scaled through the shared ``scaler``. The discriminator's negative
    sample is a blend of the generated audio and a segment-shuffled copy of
    the real audio. On rank 0 the function logs, writes summaries, evaluates
    and checkpoints at the intervals given in ``hps.train``. The module-level
    ``global_step`` counter is advanced once per batch.
    """
    global global_step

    net_g, net_d = nets
    optim_g, optim_d = optims
    scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    net_g.train()
    net_d.train()

    for batch_idx, (spec, spec_lengths, y, y_lengths) in enumerate(train_loader):
        spec = spec.cuda(rank, non_blocking=True)
        spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
        y = y.cuda(rank, non_blocking=True)
        y_lengths = y_lengths.cuda(rank, non_blocking=True)

        with autocast(enabled=hps.train.fp16_run):
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax)
            generator_out = net_g(mel, spec_lengths, spec, spec_lengths)
            y_hat, ids_slice, x_mask, z_mask, latents = generator_out
            z, z_p, m_p, logs_p, m_q, logs_q = latents

            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax
            )
            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size)

            # Cut the real audio into four pieces and shuffle them.
            n_items = y.size(0)
            pieces = y.view(n_items, 4, -1)
            order = torch.randperm(4)
            shuffled = pieces[:, order, :].view(n_items, 1, -1)
            # An identity shuffle would just be the real audio again, so
            # fall back to the generated audio in that case.
            if torch.equal(order, torch.arange(4)):
                y_jigsaw = y_hat
            else:
                y_jigsaw = shuffled
            y_negative = 0.75 * y_hat + 0.25 * y_jigsaw

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_negative.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g)
                loss_disc_all = loss_disc

        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl

        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]['lr']
                tracked = (loss_disc, loss_gen, loss_fm, loss_mel, loss_kl)
                progress = 100. * batch_idx / len(train_loader)
                logger.info('Train Epoch: {} [{:.0f}%]'.format(epoch, progress))
                logger.info([t.item() for t in tracked] + [global_step, lr])

                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                    "loss/g/fm": loss_fm,
                    "loss/g/mel": loss_mel,
                    "loss/g/kl": loss_kl,
                }
                for i, v in enumerate(losses_gen):
                    scalar_dict["loss/g/{}".format(i)] = v
                for i, v in enumerate(losses_disc_r):
                    scalar_dict["loss/d_r/{}".format(i)] = v
                for i, v in enumerate(losses_disc_g):
                    scalar_dict["loss/d_g/{}".format(i)] = v

                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict)

            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)
                g_path = os.path.join(hps.model_dir, "G_{}.pth".format(global_step))
                utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, g_path)
                d_path = os.path.join(hps.model_dir, "D_{}.pth".format(global_step))
                utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, d_path)
        global_step += 1

    if rank == 0:
        logger.info('====> Epoch: {}'.format(epoch))
def train(x, y):
    """Run one GAN training iteration, optionally with packed D inputs.

    Performs ``config['num_D_steps']`` discriminator updates followed by
    ``config['num_G_steps']`` generator updates, each accumulating gradients
    over the configured number of accumulations. When ``config['Pac']`` is
    set, samples are additionally reshaped so that four samples share the
    channel axis and scored by ``D(..., pack=True)``; those losses are added
    to the ordinary adversarial terms.

    Args:
        x: Batch of real samples; split into chunks of
            ``config['batch_size']``, one chunk per D accumulation.
        y: Labels for ``x``, split the same way.

    Returns:
        Dict with ``'G_loss'``, ``'D_loss_real'`` and ``'D_loss_fake'`` as
        Python floats, plus ``'C_loss'`` and ``'MI_loss'`` from the last
        generator accumulation (already divided by the accumulation count).
    """
    def _pack(batch):
        # Fold groups of four samples into the channel axis.
        return batch.view(-1, 4 * batch.size()[1],
                          batch.size()[2], batch.size()[3])

    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Discriminator phase: only D receives gradients.
    utils.toggle_grad(D, True)
    utils.toggle_grad(G, False)
    for _ in range(config['num_D_steps']):
        # Gradients must accumulate only within one optimizer step, so
        # clear what the previous step left behind.
        D.optim.zero_grad()
        for _ in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            fake_y = y_[:config['batch_size']]
            D_fake, D_real, mi, c_cls, G_z = GD(
                z_[:config['batch_size']], fake_y,
                x[counter], y[counter],
                train_G=False, split_D=config['split_D'], return_G_z=True)
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations.
            n_fake = D_fake.shape[0]
            C_loss = 0
            # Auxiliary terms: c_cls rows [n_fake:] against the real labels,
            # mi rows [:n_fake] against the labels used to generate fakes.
            if config['loss_type'] == 'Twin_AC':
                C_loss += (F.cross_entropy(c_cls[n_fake:], y[counter])
                           + F.cross_entropy(mi[:n_fake], fake_y))
            if config['loss_type'] == 'AC':
                C_loss += F.cross_entropy(c_cls[n_fake:], y[counter])
            if config['Pac']:
                T_img = _pack(x[counter])
                F_img = _pack(G_z)
                pack_img = torch.cat([T_img, F_img], dim=0)
                pack_out, _, _ = D(pack_img, pack=True)
                # Packed reals come first in the concatenated batch.
                n_real_pac = T_img.size()[0]
                D_real_pac = pack_out[:n_real_pac]
                D_fake_pac = pack_out[n_real_pac:]
                D_loss_real_pac, D_loss_fake_pac = losses.discriminator_loss(
                    D_fake_pac, D_real_pac)
                D_loss_real += D_loss_real_pac
                D_loss_fake += D_loss_fake_pac
            D_loss = (D_loss_real + D_loss_fake
                      + C_loss * config['AC_weight']) / float(
                          config['num_D_accumulations'])
            D_loss.backward()
            counter += 1
        D.optim.step()

    # Generator phase: only G receives gradients.
    utils.toggle_grad(D, False)
    utils.toggle_grad(G, True)
    # Zero G's gradients by default before training G, for safety.
    G.optim.zero_grad()
    for _ in range(config['num_G_steps']):
        # Same reasoning as for D: each step starts from clean gradients.
        G.optim.zero_grad()
        for _ in range(config['num_G_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake, G_z, mi, c_cls = GD(z_, y_, train_G=True,
                                        split_D=config['split_D'],
                                        return_G_z=True)
            C_loss = 0
            MI_loss = 0
            G_loss = losses.generator_loss(D_fake)
            if config['loss_type'] in ('AC', 'Twin_AC'):
                C_loss = F.cross_entropy(c_cls, y_)
            if config['loss_type'] == 'Twin_AC':
                MI_loss = F.cross_entropy(mi, y_)
            if config['Pac']:
                D_fake_pac, _, _ = D(_pack(G_z), pack=True)
                G_loss_pac = losses.generator_loss(D_fake_pac)
                G_loss += G_loss_pac
            n_acc = float(config['num_G_accumulations'])
            G_loss = G_loss / n_acc
            C_loss = C_loss / n_acc
            MI_loss = MI_loss / n_acc
            (G_loss + (C_loss - MI_loss) * config['AC_weight']).backward()
        G.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not.
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {
        'G_loss': float(G_loss.item()),
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item()),
        'C_loss': C_loss,
        'MI_loss': MI_loss,
    }
    # Return G's loss and the components of D's loss.
    return out
def train(self):
    """Train the text-conditioned generator against its discriminators.

    For every batch: embed the captions with the (detached) text encoder,
    generate fake images, update each discriminator in ``netsD`` on its own
    image pair, then update the generator on the combined generator loss
    plus the KL term. A running average of the generator parameters
    (decay 0.999) is maintained and used for the periodic image dumps and
    for the saved checkpoints.
    """
    text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models()
    avg_param_G = copy_G_params(netG)
    optimizerG, optimizersD = self.define_optimizers(netG, netsD)
    real_labels, fake_labels, match_labels = self.prepare_labels()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

    gen_iterations = 0
    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()

        data_iter = iter(self.data_loader)
        step = 0
        while step < self.num_batches:
            ######################################################
            # (1) Prepare training data and Compute text embeddings
            ######################################################
            # The builtin next() works with every iterator; the `.next()`
            # method does not exist on Python 3 iterators.
            data = next(data_iter)
            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

            #######################################################
            # (3) Update D network
            ######################################################
            errD_total = 0
            D_logs = ''
            for i in range(len(netsD)):
                netsD[i].zero_grad()
                errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                          sent_emb, real_labels, fake_labels)
                # backward and update parameters
                errD.backward()
                optimizersD[i].step()
                errD_total += errD
                D_logs += 'errD%d: %.2f ' % (i, errD.item())

            #######################################################
            # (4) Update G network: maximize log(D(G(z)))
            ######################################################
            step += 1
            gen_iterations += 1

            netG.zero_grad()
            errG_total, G_logs = \
                generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                               words_embs, sent_emb, match_labels,
                               cap_lens, class_ids)
            kl_loss = KL_loss(mu, logvar)
            errG_total += kl_loss
            G_logs += 'kl_loss: %.2f ' % kl_loss.item()
            # backward and update parameters
            errG_total.backward()
            optimizerG.step()
            # Exponential moving average of G's weights. Written as
            # add_(scaled tensor) because the add_(scalar, tensor) overload
            # is deprecated in newer PyTorch releases.
            for p, avg_p in zip(netG.parameters(), avg_param_G):
                avg_p.mul_(0.999).add_(0.001 * p.data)

            if gen_iterations % 100 == 0:
                print(D_logs + '\n' + G_logs)
            # save images
            if gen_iterations % 1000 == 0:
                # Temporarily swap in the averaged weights for the dump.
                backup_para = copy_G_params(netG)
                load_params(netG, avg_param_G)
                self.save_img_results(netG, fixed_noise, sent_emb,
                                      words_embs, mask, image_encoder,
                                      captions, cap_lens, epoch,
                                      name='average')
                load_params(netG, backup_para)
        end_t = time.time()

        print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
              % (epoch, self.max_epoch, self.num_batches,
                 errD_total.item(), errG_total.item(),
                 end_t - start_t))

        if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
            self.save_model(netG, avg_param_G, netsD, epoch)

    self.save_model(netG, avg_param_G, netsD, self.max_epoch)
def train(x, y):
    """Run one training iteration for the discriminator, generator and encoder.

    The discriminator is updated ``config['num_D_steps']`` times on chunks of
    the batch. Generator and encoder then share one update on the adversarial
    loss plus the weighted VAE reconstruction and KL terms (and an optional
    additional generator term when ``GDE`` returns six outputs). Setting
    ``config['stop_gradient']`` skips all optimizer steps.

    Args:
        x: Batch of real samples; split into chunks of
            ``config['batch_size']``.
        y: Labels for ``x``, split the same way.

    Returns:
        Dict of Python floats from the last accumulation: ``'G_loss'``,
        ``'D_loss_real'``, ``'D_loss_fake'``, ``'VAE_recon_loss'`` and
        ``'VAE_KLD_loss'``.
    """
    G.optim.zero_grad()
    D.optim.zero_grad()
    E.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    print("split - x {}".format(len(x)))
    counter = 0

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)
        utils.toggle_grad(E, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an optimizer step
        D.optim.zero_grad()
        for accumulation_index in range(config['num_D_accumulations']):
            # Corner case for the last batch: no chunks left to consume.
            if counter >= len(x):
                break
            D_fake, D_real = GDE(x[counter], y[counter], config,
                                 state_dict['itr'], img_pool,
                                 train_G=False, split_D=config['split_D'])
            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            D_loss_real, D_loss_fake = losses.discriminator_loss(
                D_fake, D_real, config['clip'])
            D_loss = (D_loss_real + D_loss_fake) / \
                float(config['num_D_accumulations'])
            print("D_loss: {}; D_fake {}, D_real {}".format(
                D_loss.item(), D_loss_fake.item(), D_loss_real.item()))
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        # stop gradient for testing purpose
        if config['stop_gradient']:
            print("!!! D is not optimized since you turn on `stop_gradient`!!!!!!")
        else:
            D.optim.step()

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)
        utils.toggle_grad(E, True)

    # Zero G/E's gradients by default before training G, for safety
    G.optim.zero_grad()
    E.optim.zero_grad()

    # If accumulating gradients, loop multiple times
    counter = 0  # reset counter for data split
    for accumulation_index in range(config['num_G_accumulations']):
        if counter >= len(x):
            break
        output = GDE(x[counter], y[counter], config, state_dict['itr'],
                     img_pool, train_G=True, split_D=config['split_D'],
                     return_G_z=True)
        D_fake = output[0]
        G_z = output[2]
        mu, log_var = output[3], output[4]
        # A sixth output, when present, is an extra generator penalty.
        has_additional = len(output) == 6
        if has_additional:
            G_additional = output[5]

        G_loss = losses.generator_loss(
            D_fake) / float(config['num_G_accumulations'])
        VAE_recon_loss = losses.vae_recon_loss(G_z, x[counter])
        VAE_kld_loss = losses.vae_kld_loss(mu, log_var, config['clip'])
        GE_loss = (G_loss
                   + VAE_recon_loss * config['lambda_vae_recon']
                   + VAE_kld_loss * config['lambda_vae_kld'])
        # The logged GE_loss deliberately excludes the additional term,
        # which is reported separately below.
        log_loss_str = f"GE_loss {GE_loss.item()}; VAE_recon_loss {VAE_recon_loss.item()}; VAE_kld_loss {VAE_kld_loss.item()} "
        # add G_additional loss
        if has_additional:
            G_additional_loss = config['lambda_g_additional'] * G_additional.sum()
            GE_loss += G_additional_loss
            log_loss_str += f"G_additional {G_additional_loss.item()}"
        # print out loss
        print(log_loss_str)
        # optimization
        GE_loss.backward()
        counter += 1

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in G
        print('using modified ortho reg in G')
        # Don't ortho reg shared, it makes no sense. Really we should
        # blacklist any embeddings for this
        utils.ortho(G, config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])

    # stop gradient for testing purpose
    if config['stop_gradient']:
        print("!!! G and E is not optimized since you turn on `stop_gradient`!!!!!!")
    else:
        G.optim.step()
        E.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    # 'VAE_KLD_loss' must report the KL term, not the reconstruction term.
    out = {'G_loss': float(G_loss.item()),
           'D_loss_real': float(D_loss_real.item()),
           'D_loss_fake': float(D_loss_fake.item()),
           'VAE_recon_loss': float(VAE_recon_loss.item()),
           'VAE_KLD_loss': float(VAE_kld_loss.item())}
    # Return G's loss and the components of D's loss.
    return out
def train(x, y): G.optim.zero_grad() D.optim.zero_grad() inner_iter_count = 0 partial_test_input = 0 # How many chunks to split x and y into? #x = torch.split(x, config['batch_size']) #y = torch.split(y, config['batch_size']) #print('x len{}'.format(len(x))) #print('y len{}'.format(len(y))) #assert len(x) == config['num_D_accumulations'] == len(y) #D_fake, D_real, G_fake, gy = None, None, None, None # Optionally toggle D and G's "require_grad" if config['toggle_grads']: utils.toggle_grad(D, True) utils.toggle_grad(G, False) for step_index in range(config['num_D_steps']): # If accumulating gradients, loop multiple times before an optimizer step D.optim.zero_grad() d_reals = None#[None for _ in x] g_fakes = None#[None for _ in x] #gys = [None for _ in x] #zs = [None for _ in x] #zs_.sample_() #ys_.sample_() #gy = ys_[:config['batch_size']] #z = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :5] if state_dict['epoch'] < 0: #for accumulation_index in range(config['num_D_accumulations']): # doesn't mean anything right now # for fb_iter in range(config['num_feedback_iter']): # if fb_iter == 0: # z_ = zs_[:config['batch_size']] # gy = ys_[:config['batch_size']] # print('z_ shape {}'.format(z_.shape)) # z_ = z_.view(zs_.size(0), 9, 8, 8)[:, :5] zs_.sample_() z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 8, 8)[:,20] # [:, :5] #z_ = z_.view(z_.size(0), -1) # zs[accumulation_index] = z # z_ = torch.cat([z, torch.zeros(zs_.size(0), 4, 8, 8).cuda()], 1) ys_.sample_() gy = ys_[:config['batch_size']] # gys[accumulation_index] = gy.detach() # else: # D_real = D_real#.repeat(1,3,1,1)# * g_fakes[accumulation_index] # print('zs_ shape 0 {}'.format(zs_.shape)) # print('\n\n\n\n') # print('r shape {}'.format(r.shape)) # print('g fake shape {}'.format(g_fakes[accumulation_index].shape)) # print('\n\n\n\n') # z_ = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :8] # G_fake = nn.AvgPool2d(4)(g_fakes[accumulation_index]) # print('z shape 5 {}'.format(z_.shape)) 
# z_=z_[:,:3] # print('z shape 10 {}'.format(z_.shape)) # z_ = torch.cat([d_reals[accumulation_index], G_fake, zs[accumulation_index]], 1) # print('z shape 15 {}'.format(z_.shape)) # gy = gys[accumulation_index] D_fake, D_real, G_fake = GD(z_, gy, x=x,#[accumulation_index], dy=y,#[accumulation_index], train_G=False, split_D=config['split_D']) #print('D shape {}'.format(D_fake.shape)) #print('G fake shape {}'.format(nn.AvgPool2d(4)(G_fake).shape)) #print('D real shape {}'.format(D_real.shape)) #print('z shape {}'.format(z_.shape)) if state_dict['itr'] % 1000 == 0: ##and accumulation_index == 6: print('saving img') torchvision.utils.save_image(x.float().cpu(),#[accumulation_index].float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_xreal.jpg'.format( time, state_dict['itr']), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) torchvision.utils.save_image(D_fake.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_dfake.jpg'.format( time, state_dict['itr']), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) torchvision.utils.save_image(D_real.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_dreal.jpg'.format( time, state_dict['itr']), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) # d_reals[accumulation_index] = D_real.detach() # g_fakes[accumulation_index] = G_fake.detach() # Compute components of D's loss, average them, and divide by # the number of gradient accumulations D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real) D_loss = (D_loss_real + D_loss_fake)# / float(config['num_D_accumulations']) D_loss.backward() # counter += 1 # Optionally apply ortho reg in D if config['D_ortho'] > 0.0: # Debug print to indicate we're using ortho reg in D. 
print('using modified ortho reg in D') utils.ortho(D, config['D_ortho']) D.optim.step() # D.optim.zero_grad() # Optionally toggle "requires_grad" else: for fb_iter in range(config['num_feedback_iter_D']): #for accumulation_index in range(config['num_D_accumulations']): #doesn't mean anything right now #for fb_iter in range(config['num_feedback_iter']): zs_.sample_() z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 32, 32)[:, :20] ys_.sample_() gy = ys_[:config['batch_size']] if fb_iter == 0: # z_ = zs_[:config['batch_size']] # gy = ys_[:config['batch_size']] #print('z_ shape {}'.format(z_.shape)) #z_ = z_.view(zs_.size(0), 9, 8, 8)[:, :5] #zs_.sample_() #z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 8, 8)[:, :20] #zs[accumulation_index] = z_ #print('three channel x input train D shape before {}'.format(x[:, :3].shape)) #init_x = nn.AvgPool2d(4)(x[:, :3]) init_x = x[:, :3] z_ = torch.cat([z_, init_x, torch.ones(zs_.size(0), 1, 32, 32).cuda()], 1) #print('three channel x input train D shape after {}'.format(nn.AvgPool2d(4)(x[:, :3]).shape)) #ys_.sample_() #gy = ys_[:config['batch_size']] #gys[accumulation_index] = gy.detach() else: #D_real = D_real#.repeat(1,3,1,1)# * g_fakes[accumulation_index] #print('zs_ shape 0 {}'.format(zs_.shape)) #print('\n\n\n\n') #print('r shape {}'.format(r.shape)) #print('g fake shape {}'.format(g_fakes[accumulation_index].shape)) #print('\n\n\n\n') #z_ = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :8] g_fake = 0.1 * g_fake + 0.9 * init_x#[accumulation_index] #print('z shape 5 {}'.format(z_.shape)) #z_=z_[:,:3] # print('z shape 10 {}'.format(z_.shape)) # print('g fake shape 10 {}'.format(G_fake.shape)) # print('d real shape 10 {}'.format(d_reals.shape)) #z_ = torch.cat([zs[accumulation_index],d_reals[accumulation_index], G_fake,], 1) z_ = torch.cat([z_, g_fake, nn.functional.interpolate(d_reals, 32, mode='bilinear')#[accumulation_index] ,], 1) #z_ = z_.view(z_.size(0),-1) #print('z shape 15 {}'.format(z_.shape)) 
#gy = gys[accumulation_index] # if state_dict['itr'] % 42 == 0: # partial_test_input = partial_test_input + torch.cat([g_fakes, d_fakes]) D_fake, D_real, G_fake = GD(z_, gy, x=x,#[accumulation_index], dy=y,#[accumulation_index], train_G=False, split_D=config['split_D']) #print('D shape {}'.format(D_fake.shape)) if state_dict['itr'] % 1000 == 0:# and accumulation_index == 6: print('saving img') torchvision.utils.save_image(x.float().cpu(),#[accumulation_index].float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_xreal.jpg'.format( time, state_dict['itr'], fb_iter), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) torchvision.utils.save_image(G_fake.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_Gfake_d.jpg'.format( time,state_dict['itr'],fb_iter),nrow=int(D_fake.shape[0] ** 0.5),normalize=True) if fb_iter > 1: torchvision.utils.save_image(g_fake.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_gfake_d.jpg'.format( time,state_dict['itr'],fb_iter),nrow=int(D_fake.shape[0] ** 0.5),normalize=True) D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real) if not fb_iter == 0: # d_real_enforcement = losses.loss_enforcing(d_reals#[accumulation_index] # , D_real) # g_fakes_enforcement = losses.loss_enforcing(g_fakes #[accumulation_index] # , nn.AvgPool2d(4)(G_fake)) D_loss = (D_loss_real + D_loss_fake)# + 0.1 * d_real_enforcement)# / float(config['num_D_accumulations']) else: D_loss = (D_loss_real + D_loss_fake)# / float(config['num_D_accumulations']) #d_reals[accumulation_index] = D_real.detach() d_reals = D_real.detach() #g_fakes[accumulation_index] = nn.AvgPool2d(4)(G_fake).detach() g_fake = G_fake.detach() #g_fakes = G_fake.detach() # Compute components of D's loss, average them, and divide by # the number of gradient accumulations # D_loss_real, D_loss_fake = 
losses.discriminator_loss(D_fake, D_real) # if not fb_iter == 0: # D_loss = (D_loss_real + D_loss_fake + d_real_enforcement + g_fakes_enforcement) / float(config['num_D_accumulations']) # else: # D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations']) D_loss.backward() #counter += 1 # Optionally apply ortho reg in D if config['D_ortho'] > 0.0: # Debug print to indicate we're using ortho reg in D. # print('using modified ortho reg in D') utils.ortho(D, config['D_ortho']) D.optim.step() #D.optim.zero_grad() # Optionally toggle "requires_grad" if config['toggle_grads']: utils.toggle_grad(D, False) utils.toggle_grad(G, True) # Zero G's gradients by default before training G, for safety G.optim.zero_grad() #d_fakes = [None for _ in range(config['num_G_accumulations'])] #g_fakes = [None for _ in range(config['num_G_accumulations'])] #gys = [None for _ in range(config['num_G_accumulations'])] #for fb_iter in range(config['num_feedback_iter']): # If accumulating gradients, loop multiple times d_fakes = None#[None for _ in x] g_fakes = None#[None for _ in x] #gys = [None for _ in x] #zs = [None for _ in x] if state_dict['epoch'] < 0: #for accumulation_index in range(config['num_G_accumulations']): # doesn't mean anything right now zs_.sample_() z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 32, 32)[:, :20] #zs[accumulation_index] = z_[:, :5] # z_ = torch.cat([z, torch.zeros(zs_.size(0), 4, 8, 8).cuda()],1) ys_.sample_() gy = ys_ #gys[accumulation_index] = gy.detach() # D_fake = D_fake.repeat(1,3,1,1) # z_ = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :5] #G_fake = nn.AvgPool2d(4)(g_fakes[accumulation_index]) #z_ = torch.cat([d_fakes[accumulation_index], G_fake, zs[accumulation_index]], 1) # gy = gys[accumulation_index] z_ = z_.view(z_.size(0), -1) D_fake, G_z = GD(z=z_, gy=gy, train_G=True, split_D=config['split_D'], return_G_z=True) G_loss = losses.generator_loss(D_fake)# / float(config['num_G_accumulations']) G_loss.backward() if 
state_dict['itr'] % 1000 == 0:# and accumulation_index == 6: print('saving img') torchvision.utils.save_image(D_fake.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_dfake.jpg'.format( time, state_dict['itr'],), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) torchvision.utils.save_image(G_z.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_pre_G_z.jpg'.format( time, state_dict['itr'],), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) #g_fakes[accumulation_index] = G_z.detach() #d_fakes[accumulation_index] = D_fake.detach() # Optionally apply modified ortho reg in G if config['G_ortho'] > 0.0: print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this utils.ortho(G, config['G_ortho'], blacklist=[param for param in G.shared.parameters()]) G.optim.step() # G.optim.zero_grad() else: for fb_iter in range(config['num_feedback_iter']): #for accumulation_index in range(config['num_G_accumulations']): #doesn't mean anything right now zs_.sample_() z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 32, 32)[:, :20] ys_.sample_() gy = ys_ if fb_iter <= 1: #zs_.sample_() #z_ = zs_[:config['batch_size']].view(zs_.size(0), 24, 8, 8)[:, :20] #zs[accumulation_index] = z_ #print('three channel x input train G shape before {}'.format(x.shape)) #init_x = nn.AvgPool2d(4)(x[:, :3]) init_x = x[:, :3] z_ = torch.cat([z_, init_x, torch.ones(zs_.size(0), 1, 32, 32).cuda()], 1) #print('three channel x input train G shape after {}'.format(nn.AvgPool2d(4)(x[:, :3]).shape)) #ys_.sample_() #gy = ys_ #gys[accumulation_index] = gy.detach() else: #D_fake = D_fake.repeat(1,3,1,1) #z_ = zs_[:config['batch_size']].view(zs_.size(0), 9, 8, 8)[:, :5] #G_fake = g_fakes#[accumulation_index] g_fake = 0.05 * g_fakes + 0.95 * init_x # 
[accumulation_index] d_fakes = nn.functional.interpolate(d_fakes, 32, mode='bilinear')#[accumulation_index] #z_ = torch.cat([zs[accumulation_index], d_fakes[accumulation_index], G_fake, ], 1) z_ = torch.cat([z_, g_fake, d_fakes #[accumulation_index] ,], 1) if ((not (state_dict['itr'] % config['save_every'])) or (not (state_dict['itr'] % config['test_every']))): partial_test_input = partial_test_input + torch.cat([g_fake, d_fakes], 1) inner_iter_count = inner_iter_count + 1 #gy = gys[accumulation_index] #z_ = z_.view(z_.size(0), -1) D_fake, G_z = GD(z=z_, gy=gy, train_G=True, split_D=config['split_D'], return_G_z=True) if not fb_iter == 0: #g_fakes_enforcement = losses.loss_enforcing(g_fakes#[accumulation_index] #, G_z) # d_fakes_enforcement = losses.loss_enforcing(d_fakes#[accumulation_index] # , D_fake) G_loss = (losses.generator_loss(D_fake))# + 0.1 * g_fakes_enforcement) #/ float(config['num_G_accumulations']) else: G_loss = (losses.generator_loss(D_fake))# / float(config['num_G_accumulations']) G_loss.backward() if state_dict['itr'] % 1000 == 0:# and accumulation_index == 6: print('saving img') # torchvision.utils.save_image(D_fake.float().cpu(), # '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_dfake.jpg'.format(time, # state_dict['itr'], fb_iter), # nrow=int(D_fake.shape[0] ** 0.5), # normalize=True) torchvision.utils.save_image(G_z.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_G_z.jpg'.format(time, state_dict['itr'], fb_iter), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) if fb_iter > 1: torchvision.utils.save_image(g_fake.float().cpu(), '/ubc/cs/research/shield/projects/cshen001/BigGAN-original/BigGAN-PyTorch/samples_new/{}_it{}_fb{}_G_z_input.jpg'.format(time, state_dict['itr'], fb_iter), nrow=int(D_fake.shape[0] ** 0.5), normalize=True) #g_fakes[accumulation_index] = nn.AvgPool2d(4)(G_z).detach() g_fakes = G_z.detach() #g_fakes = 
G_z.detach() #d_fakes[accumulation_index] = D_fake.detach() d_fakes = D_fake.detach() # Optionally apply modified ortho reg in G if config['G_ortho'] > 0.0: print('using modified ortho reg in G') # Debug print to indicate we're using ortho reg in G # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this utils.ortho(G, config['G_ortho'], blacklist=[param for param in G.shared.parameters()]) G.optim.step() #G.optim.zero_grad() # If we have an ema, update it, regardless of if we test with it or not if config['ema']: ema.update(state_dict['itr']) out = {'G_loss': float(G_loss.item()), 'D_loss_real': float(D_loss_real.item()), 'D_loss_fake': float(D_loss_fake.item())} # Return G's loss and the components of D's loss. partial_test_input = partial_test_input / (inner_iter_count + 1e-9) return out, partial_test_input
def train(self):
    """Train the generator and every discriminator up to ``self.max_epoch``.

    For each batch the captions are embedded with the text encoder, fake
    images are generated from fresh noise, each discriminator in ``netsD``
    takes one optimizer step, and the generator is then updated with the
    loss returned by ``generator_loss`` plus ``KL_loss(mu, logvar)``. A
    running average of the generator parameters (decay 0.999) is kept and
    temporarily loaded into the generator for the periodic image dumps.

    At the end of every epoch the model is validated, one line is appended
    to ``self.val_logger`` and ``self.losses_logger``, and a snapshot is
    saved whenever ``epoch % self.opts.TRAIN.SNAPSHOT_INTERVAL == 0``. A
    final snapshot tagged with ``self.max_epoch`` is written after the loop.
    """
    (text_encoder, image_encoder, netG, netsD,
     start_epoch) = self.build_models()
    ema_params = copy_G_params(netG)
    optimizerG, optimizersD = self.define_optimizers(netG, netsD)
    real_labels, fake_labels, match_labels = self.prepare_labels()

    latent_shape = (self.batch_size, self.opts.GAN.Z_DIM)
    noise = Variable(torch.FloatTensor(*latent_shape))
    fixed_noise = Variable(torch.FloatTensor(*latent_shape).normal_(0, 1))
    noise = noise.to(self.device)
    fixed_noise = fixed_noise.to(self.device)

    total_g_updates = 0

    schedulers = []
    if self.use_lr_scheduler:
        for optimizer_d in optimizersD:
            scheduler = LambdaLR(optimizer_d,
                                 lr_lambda=lambda ep: 0.998**ep)
            # Replay the epochs already trained so the decay continues from
            # where ``start_epoch`` says the run left off.
            for _ in range(start_epoch):
                scheduler.step()
            schedulers.append(scheduler)

    for epoch in range(start_epoch, self.max_epoch):
        epoch_start = time.time()
        batches = iter(self.train_loader)
        for scheduler in schedulers:
            scheduler.step()

        step = 0
        while step < self.num_batches:
            # (1) Fetch a batch and compute the (detached) text embeddings.
            batch = next(batches)
            imgs, captions, class_ids, captions_mask = prepare_data(
                batch, self.device)
            words_embs, sent_emb = self.text_encoder_forward(
                text_encoder, captions, captions_mask)
            words_embs = words_embs.detach()
            sent_emb = sent_emb.detach()

            # Mask marks caption positions equal to 0, trimmed to the
            # number of word embeddings actually produced.
            mask = captions == 0
            max_words = words_embs.size(2)
            if mask.size(1) > max_words:
                mask = mask[:, :max_words]

            # (2) Generate fake images from freshly sampled noise.
            noise.data.normal_(0, 1)
            fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs,
                                            mask)

            # (3) One optimizer step per discriminator.
            errD_total = 0
            d_log_parts = []
            for idx, net_d in enumerate(netsD):
                net_d.zero_grad()
                err_d = discriminator_loss(net_d, imgs[idx], fake_imgs[idx],
                                           sent_emb, real_labels,
                                           fake_labels)
                err_d.backward()
                optimizersD[idx].step()
                errD_total += err_d
                d_log_parts.append('errD%d: %.2f ' %
                                   (idx, err_d.data.item()))
            D_logs = ''.join(d_log_parts)

            # (4) Generator update.
            step += 1
            total_g_updates += 1

            netG.zero_grad()
            errG_total, G_logs = generator_loss(netsD, image_encoder,
                                                fake_imgs, real_labels,
                                                words_embs, sent_emb,
                                                match_labels, class_ids,
                                                self.opts)
            kl = KL_loss(mu, logvar)
            errG_total += kl
            G_logs += 'kl_loss: %.2f ' % kl.data.item()
            errG_total.backward()
            optimizerG.step()

            # Update the running average of the generator parameters.
            for param, ema_param in zip(netG.parameters(), ema_params):
                ema_param.mul_(0.999).add_(0.001, param.data)

            if total_g_updates % 10 == 0:
                print("Epoch: " + str(epoch) + " Step: " + str(step) + " " +
                      D_logs + '\n' + G_logs)

            if total_g_updates % 300 == 0:
                # Dump images with the averaged weights, then restore the
                # live weights so training continues unaffected.
                live_params = copy_G_params(netG)
                load_params(netG, ema_params)
                self.save_img_results(netG, fixed_noise, sent_emb,
                                      words_embs, mask, image_encoder,
                                      captions, epoch, step, name='average')
                load_params(netG, live_params)

        is_mean, is_std, error_G_val = self.validate(netG, netsD,
                                                     text_encoder,
                                                     image_encoder)
        self.val_logger.write("{} {} {}\n".format(epoch, is_mean, is_std))
        self.val_logger.flush()
        self.losses_logger.write("{} {} {}\n".format(
            epoch, errG_total.data.item(), error_G_val))
        self.losses_logger.flush()

        epoch_end = time.time()
        print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
              (epoch, self.max_epoch, self.num_batches,
               errD_total.data.item(), errG_total.data.item(),
               epoch_end - epoch_start))
        print("IS: {} {}".format(is_mean, is_std))

        if epoch % self.opts.TRAIN.SNAPSHOT_INTERVAL == 0:
            self.save_model(netG, ema_params, netsD, epoch)

    self.save_model(netG, ema_params, netsD, self.max_epoch)
def find_lr(data, batch_size=1, start_lr=1e-9, end_lr=1):
    """Run a one-epoch learning-rate sweep and plot loss against LR.

    Both optimizers start at ``start_lr``. After every batch the learning
    rate is multiplied by ``(end_lr / start_lr) ** (1 / epoch_size)``, so it
    grows geometrically towards ``end_lr`` over one pass of the training
    data. The sweep stops early as soon as any tracked loss reaches 100x
    its best value seen so far. The three loss curves are written as TIFF
    plots into the checkpoint directory and the best losses are printed.

    Args:
        data: Mapping whose ``'training'`` entry supports ``len()`` and
            integer indexing (yielding ``(input_image, target)``) and has
            an ``origin`` attribute used to name the checkpoint directory.
        batch_size: Only reported in the printed summary; batching is
            whatever ``data['training']`` yields per index.
        start_lr: Learning rate used for the first batch.
        end_lr: Learning rate the sweep grows towards.

    Raises:
        ZeroDivisionError: If ``data['training']`` is empty.
    """
    generator = Generator(input_shape=[None, None, 2])
    discriminator = Discriminator(input_shape=[None, None, 1])

    generator_optimizer = tf.keras.optimizers.Adam(lr=start_lr)
    discriminator_optimizer = tf.keras.optimizers.Adam(lr=start_lr)

    training_data = data['training']
    model_name = training_data.origin + '_2_any'
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, model_name)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(checkpoint_prefix, exist_ok=True)

    epoch_size = len(training_data)
    lr_mult = (end_lr / start_lr) ** (1 / epoch_size)

    loss_names = ('gen_mae', 'gen_loss', 'disc_loss')
    lrs = []
    losses = {name: [] for name in loss_names}
    best_losses = {name: 1e9 for name in loss_names}

    print()
    print("Finding the optimal LR with the following parameters: ")
    print("\tCheckpoints: \t", checkpoint_prefix)
    print("\tEpochs: \t", 1)
    print("\tBatchSize: \t", batch_size)
    print("\tnBatches: \t", epoch_size)
    print()

    print('Epoch {}/{}'.format(1, 1))
    progbar = tf.keras.utils.Progbar(epoch_size)

    for i in range(epoch_size):
        # Get the data from the DataGenerator
        input_image, target = training_data[i]

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generate a fake image
            gen_output = generator(input_image, training=True)

            # The discriminator is conditioned on the first input channel.
            condition = input_image[:, :, :, 0:1]
            disc_real_output = discriminator([condition, target],
                                             training=True)
            disc_generated_output = discriminator([condition, gen_output],
                                                  training=True)

            # Compute the losses
            gen_mae = l1_loss(target, gen_output)
            gen_loss = generator_loss(disc_generated_output, gen_mae)
            disc_loss = discriminator_loss(disc_real_output,
                                           disc_generated_output)

        # Compute and apply the gradients
        generator_gradients = gen_tape.gradient(
            gen_loss, generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(
            disc_loss, discriminator.trainable_variables)
        generator_optimizer.apply_gradients(
            zip(generator_gradients, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, discriminator.trainable_variables))

        # Convert losses to numpy
        current = {
            'gen_mae': gen_mae.numpy(),
            'gen_loss': gen_loss.numpy(),
            'disc_loss': disc_loss.numpy(),
        }

        # Update the progress bar
        progbar.add(1, values=[(name, current[name]) for name in loss_names])

        # Record the LR that was just used, then raise it for the next batch.
        lr = tf.keras.backend.get_value(generator_optimizer.lr)
        lrs.append(lr)
        lr *= lr_mult
        tf.keras.backend.set_value(generator_optimizer.lr, lr)
        tf.keras.backend.set_value(discriminator_optimizer.lr, lr)

        # Update the loss histories and the best value of each loss.
        for name in loss_names:
            losses[name].append(current[name])
            if best_losses[name] > current[name]:
                best_losses[name] = current[name]

        # Stop once any loss is at least 100x its best value so far.
        if any(current[name] >= 100 * best_losses[name]
               for name in loss_names):
            break

    for name in loss_names:
        plot_loss_findlr(
            losses[name], lrs,
            os.path.join(checkpoint_prefix, 'LRFinder_' + name + '.tiff'))

    print('Best losses:')
    print('gen_mae =', best_losses['gen_mae'])
    print('gen_loss =', best_losses['gen_loss'])
    print('disc_loss =', best_losses['disc_loss'])
def optimize(self, data, current_step):
    """Run one discriminator update followed by one generator update.

    The encoder is applied to ``data['image_a']`` and ``data['image_b']``;
    the embeddings of image a are rotated to the pseudo labels of image b
    and decoded into ``image_b_hat``. The discriminator is stepped on real
    ``image_b`` versus the detached ``image_b_hat`` and then frozen. The
    generator loss is assembled from the L1 reconstruction, gaze/head
    angular losses, the adversarial term (scaled by a warm-up coefficient
    that ramps linearly to 1 over 20000 steps), the optional consistency /
    disentanglement terms, and the redirection losses computed with
    ``self.GazeHeadNet_train``.

    Args:
        data: Mapping with at least ``image_a``, ``image_b``, ``gaze_a``,
            ``gaze_b``, ``head_a`` and ``head_b``.
        current_step: Training step used for the warm-up coefficient.

    Returns:
        ``(losses_dict, image_b_hat)``: an ``OrderedDict`` of the individual
        loss terms and the redirected image produced by the decoder.
    """
    if config.use_apex:
        from apex import amp

    def backward(loss, optimizers):
        # Backpropagate, going through apex loss scaling when it is enabled.
        if config.use_apex:
            with amp.scale_loss(loss, optimizers) as scaled:
                scaled.backward()
        else:
            loss.backward()

    def flatten(embeddings):
        # Concatenate every embedding as one (batch, -1) matrix.
        return torch.cat([e.reshape(e.shape[0], -1) for e in embeddings],
                         dim=1)

    loss_log = OrderedDict()

    for p in self.discriminator.parameters():
        p.requires_grad = True
    for p in self.generator_params:
        p.requires_grad = True

    pseudo_labels_a, embeddings_a = self.encoder(data['image_a'])
    pseudo_labels_b, embeddings_b = self.encoder(data['image_b'])

    if config.use_mixing:
        # One random 0/1 switch per embedding unit and sample decides
        # whether that unit is rotated.
        num_0d_units = 1 if config.size_0d_unit > 0 else 0
        mix_shape = [
            num_0d_units + config.num_1d_units + config.num_2d_units,
            config.batch_size, 1, 1
        ]
        mix_bits = np.random.randint(2, size=mix_shape).tolist()
        random_tensor = torch.tensor(mix_bits,
                                     dtype=torch.float,
                                     requires_grad=False).to("cuda")

        normalized_mix = self.rotate(embeddings_a,
                                     pseudo_labels_a,
                                     random_tensor,
                                     inverse=True)
        embeddings_a_to_mix = self.rotate(normalized_mix, pseudo_labels_b,
                                          random_tensor)
        image_mix_hat = self.decoder(embeddings_a_to_mix)
        pseudo_labels_mix_hat, embeddings_mix_hat = self.encoder(
            image_mix_hat)

    # a -> b
    normalized_a = self.rotate(embeddings_a, pseudo_labels_a, inverse=True)
    embeddings_a_to_b = self.rotate(normalized_a, pseudo_labels_b)
    image_b_hat = self.decoder(embeddings_a_to_b)

    # ---- discriminator update ----
    real = self.discriminator(data['image_b'])
    fake = self.discriminator(image_b_hat.detach())
    loss_log['discriminator'] = losses.discriminator_loss(real=real,
                                                          fake=fake)
    loss_log['generator'] = losses.generator_loss(fake=fake)
    d_loss = loss_log['discriminator'] * config.coeff_discriminator_loss

    # Warm up period for generator losses
    warmup = max(min(1.0, current_step / 20000.0), 0.0)
    loss_log['discrim_coeff'] = torch.tensor(warmup)

    self.discriminator_optimizer.zero_grad()
    backward(d_loss, self.discriminator_optimizer)
    self.discriminator_optimizer.step()

    # Freeze the discriminator for the generator update.
    for p in self.discriminator.parameters():
        p.requires_grad = False

    # ---- generator update ----
    loss_log['l1'] = losses.reconstruction_l1_loss(x=data['image_b'],
                                                   x_hat=image_b_hat)
    total_loss = loss_log['l1'] * config.coeff_l1_loss

    if config.semi_supervised:
        loss_log['gaze_a'] = losses.gaze_angular_loss(
            y=data['gaze_a'], y_hat=pseudo_labels_a[-1])
        loss_log['head_a'] = losses.gaze_angular_loss(
            y=data['head_a'], y_hat=pseudo_labels_a[-2])
        loss_log['gaze_a_unlabeled'] = losses.gaze_angular_loss(
            y=data['gaze_b'], y_hat=pseudo_labels_b[-1])
        loss_log['head_a_unlabeled'] = losses.gaze_angular_loss(
            y=data['head_b'], y_hat=pseudo_labels_b[-2])
    else:
        gaze_from_a = losses.gaze_angular_loss(y=data['gaze_a'],
                                               y_hat=pseudo_labels_a[-1])
        gaze_from_b = losses.gaze_angular_loss(y=data['gaze_b'],
                                               y_hat=pseudo_labels_b[-1])
        loss_log['gaze_a'] = (gaze_from_a + gaze_from_b) / 2
        head_from_a = losses.gaze_angular_loss(y=data['head_a'],
                                               y_hat=pseudo_labels_a[-2])
        head_from_b = losses.gaze_angular_loss(y=data['head_b'],
                                               y_hat=pseudo_labels_b[-2])
        loss_log['head_a'] = (head_from_a + head_from_b) / 2

    total_loss += (loss_log['gaze_a'] +
                   loss_log['head_a']) * config.coeff_gaze_loss

    fake = self.discriminator(image_b_hat)
    g_adv_loss = losses.generator_loss(fake=fake)
    total_loss += (g_adv_loss * config.coeff_discriminator_loss *
                   loss_log['discrim_coeff'])

    if config.coeff_embedding_consistency_loss != 0:
        normalized_a = self.rotate(embeddings_a,
                                   pseudo_labels_a,
                                   inverse=True)
        normalized_b = self.rotate(embeddings_b,
                                   pseudo_labels_b,
                                   inverse=True)
        flat_a = flatten(normalized_a)
        flat_b = flatten(normalized_b)
        loss_log['embedding_consistency'] = (1.0 - torch.mean(
            F.cosine_similarity(flat_a, flat_b, dim=-1)))
        total_loss += (loss_log['embedding_consistency'] *
                       config.coeff_embedding_consistency_loss)

    if config.coeff_disentangle_embedding_loss != 0:
        assert config.use_mixing is True
        flat_before = flatten(embeddings_a_to_mix)
        flat_after = flatten(embeddings_mix_hat)
        loss_log['embedding_disentangle'] = (1.0 - torch.mean(
            F.cosine_similarity(flat_before, flat_after, dim=-1)))
        total_loss += (loss_log['embedding_disentangle'] *
                       config.coeff_disentangle_embedding_loss)

    if config.coeff_disentangle_pseudo_label_loss != 0:
        assert config.use_mixing is True
        # Expected labels of the mixed image: label b where the switch is
        # 1, label a where it is 0.
        mixed_labels = []
        for i, label_a in enumerate(pseudo_labels_a):
            label_b = pseudo_labels_b[i]
            if label_b is None:
                mixed_labels.append(None)
                continue
            switch = random_tensor[i].squeeze(-1)
            mixed_labels.append(label_b * switch + label_a * (1 - switch))

        label_term = 0
        for y, y_hat in zip(mixed_labels[-2:], pseudo_labels_mix_hat[-2:]):
            if y is not None:
                label_term += losses.gaze_angular_loss(y, y_hat)
        loss_log['label_disentangle'] = label_term
        total_loss += (loss_log['label_disentangle'] *
                       config.coeff_disentangle_pseudo_label_loss)

    feature_h, gaze_h, head_h = self.GazeHeadNet_train(image_b_hat, True)
    feature_t, gaze_t, head_t = self.GazeHeadNet_train(
        data['image_b'], True)

    feature_term = 0
    for i, feat_h in enumerate(feature_h):
        feature_term += nn.functional.mse_loss(feat_h,
                                               feature_t[i].detach())
    loss_log['redirection_feature_loss'] = feature_term
    total_loss += (loss_log['redirection_feature_loss'] *
                   config.coeff_redirection_feature_loss)

    loss_log['gaze_redirection'] = losses.gaze_angular_loss(
        y=gaze_t.detach(), y_hat=gaze_h)
    total_loss += (loss_log['gaze_redirection'] *
                   config.coeff_redirection_gaze_loss)

    loss_log['head_redirection'] = losses.gaze_angular_loss(
        y=head_t.detach(), y_hat=head_h)
    total_loss += (loss_log['head_redirection'] *
                   config.coeff_redirection_gaze_loss)

    self.generator_optimizer.zero_grad()
    backward(total_loss, [self.generator_optimizer])
    self.generator_optimizer.step()

    return loss_log, image_b_hat
def train(x, y):
    """Run one GAN training iteration on XLA devices.

    Performs ``config['num_D_steps']`` discriminator updates (each
    accumulating gradients over ``config['num_D_accumulations']`` chunks of
    ``x`` / ``y``) followed by a single generator update accumulated over
    ``config['num_G_accumulations']`` samples. Optimizer steps go through
    ``xm.optimizer_step``.

    Args:
        x: Real samples; split into chunks of ``config['batch_size']``.
        y: Labels for ``x``; split the same way.

    Returns:
        Dict with ``G_loss``, ``D_loss_real`` and ``D_loss_fake`` from the
        last accumulation step. The values are returned as produced by the
        loss functions (they are not converted with ``.item()``).
    """
    G.optim.zero_grad()
    D.optim.zero_grad()

    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an
        # optimizer step
        D.optim.zero_grad()
        for accumulation_index in range(config['num_D_accumulations']):
            z_, y_ = sample()
            D_fake, D_real = GD(z_[:config['batch_size']],
                                y_[:config['batch_size']],
                                x[counter], y[counter],
                                train_G=False,
                                split_D=config['split_D'])

            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            D_loss_real, D_loss_fake = losses.discriminator_loss(
                D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake) / \
                float(config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D.
            xm.master_print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        xm.optimizer_step(D.optim)

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config['num_G_accumulations']):
        z_, y_ = sample()
        D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
        G_loss = losses.generator_loss(
            D_fake) / float(config['num_G_accumulations'])
        G_loss.backward()

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in G. Uses
        # xm.master_print like the D branch so the message is emitted once
        # rather than by every process.
        xm.master_print('using modified ortho reg in G')
        # Don't ortho reg shared, it makes no sense. Really we should
        # blacklist any embeddings for this
        utils.ortho(G, config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])

    xm.optimizer_step(G.optim)

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    # Return G's loss and the components of D's loss.
    return {'G_loss': G_loss,
            'D_loss_real': D_loss_real,
            'D_loss_fake': D_loss_fake}
def train(x, y, tensor_writer=None, iteration=None):
    """Run one training iteration for G, D and (optionally) Dv, with logging.

    ``config['num_D_steps']`` discriminator updates are performed, each
    accumulating gradients over ``config['num_D_accumulations']`` chunks of
    ``x`` / ``y``; when ``config['no_Dv']`` is False the second
    discriminator ``Dv`` is trained alongside ``D``. One generator update
    follows, accumulated over ``config['num_G_accumulations']`` samples.

    Args:
        x: Real samples. Logged with ``add_video`` after ``(x + 1) / 2`` and
            reduced over dims ``[0, 1, 3, 4]`` for per-channel means
            (assumes a 5-D tensor with channels on dim 2 and values in
            [-1, 1] -- TODO confirm against the data loader).
        y: Labels for ``x``.
        tensor_writer: Optional writer providing ``add_video``,
            ``add_scalar`` and ``add_text``. Nothing is logged when None.
        iteration: Current iteration; used as the logging step, for the
            ``log_results_every`` schedule and for the
            ``pixel_loss_kicksin`` threshold.

    Returns:
        Dict of Python floats: ``G_loss``, ``D_loss_real``, ``D_loss_fake``
        and, when Dv is used, ``Dv_loss_real`` and ``Dv_loss_fake``.
    """
    print('Summation will be taken', config['D_hinge_loss_sum'],
          'D hinge loss')

    # Flags keep the original ``== False`` comparison so that non-bool
    # config values are interpreted exactly as before.
    use_dv = config['no_Dv'] == False
    use_avg_pixel_loss = config['no_avg_pixel_loss'] == False
    track_convgru = config['no_convgru'] == False
    # Videos, labels and per-channel means are only logged every
    # ``log_results_every`` iterations.
    log_media = (tensor_writer is not None
                 and iteration % config['log_results_every'] == 0)

    def log_channel_means(frames, source):
        # Log the mean red/green/blue value of ``frames`` rescaled by
        # (frames + 1) / 2.
        channel_means = torch.mean((frames + 1) / 2, dim=[0, 1, 3, 4])
        for idx, colour in enumerate(('Red', 'Green', 'Blue')):
            tensor_writer.add_scalar(
                'Pixel vals/Mean {} Pixel values, {}'.format(colour, source),
                float(channel_means[idx].item()), iteration)

    G.optim.zero_grad()
    D.optim.zero_grad()
    if use_dv:
        Dv.optim.zero_grad()

    if log_media:
        tensor_writer.add_video('Loaded Data', (x + 1) / 2, iteration)
        log_channel_means(x, 'real data')
        y_text = [idx_to_classes[yi.item()] for yi in y]
        tensor_writer.add_text('Loaded Labels', ' | '.join(y_text),
                               iteration)

    # Added by Xiaodan: prepare for avg pixel loss
    if use_avg_pixel_loss:
        mean_pixel_val_real = torch.mean((x + 1) / 2)

    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Optionally toggle D and G's "require_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        if use_dv:
            utils.toggle_grad(Dv, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an
        # optimizer step
        D.optim.zero_grad()
        if use_dv:
            Dv.optim.zero_grad()

        for accumulation_index in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            gd_out = GD(z_[:config['batch_size']],
                        y_[:config['batch_size']],
                        x[counter],
                        y[counter],
                        train_G=False,
                        split_D=config['split_D'],
                        tensor_writer=tensor_writer,
                        iteration=iteration)
            if use_dv:
                D_fake, D_real, Dv_fake, Dv_real, G_z = gd_out
            else:
                D_fake, D_real, G_z = gd_out

            # D scores arrive as [B*k, 1]; regroup to [B, k, 1] so the k
            # scores of one sample can be summed.
            D_fake = D_fake.contiguous().view(-1, GD.module.k,
                                              *D_fake.shape[1:])
            D_real = D_real.contiguous().view(-1, GD.module.k,
                                              *D_real.shape[1:])
            if config['D_hinge_loss_sum'] == 'before':
                # Add the k scores before the hinge loss -> [B, 1].
                D_fake = torch.sum(D_fake, 1)
                D_real = torch.sum(D_real, 1)

            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            D_loss_real, D_loss_fake = losses.discriminator_loss(
                D_fake, D_real, config['D_hinge_loss_sum'])
            D_loss = D_loss_real + D_loss_fake

            if use_dv:
                if config['T_into_B'] == True:
                    # Dv scores are [B*T*, 1]; regroup to [B, T*, 1].
                    Dv_fake = Dv_fake.contiguous().view(
                        D_fake.shape[0], -1, *Dv_fake.shape[1:])
                    Dv_real = Dv_real.contiguous().view(
                        D_real.shape[0], -1, *Dv_real.shape[1:])
                    if config['Dv_hinge_loss_sum'] == 'before':
                        Dv_fake = torch.sum(Dv_fake, 1)
                        Dv_real = torch.sum(Dv_real, 1)
                    Dv_loss_real, Dv_loss_fake = losses.discriminator_loss(
                        Dv_fake, Dv_real, config['Dv_hinge_loss_sum'])
                else:
                    # If T_into_B is False, must use "before" for hinge loss.
                    Dv_loss_real, Dv_loss_fake = losses.discriminator_loss(
                        Dv_fake, Dv_real, 'before')
                D_loss = D_loss + Dv_loss_fake + Dv_loss_real

            D_loss = D_loss / float(config['num_D_accumulations'])
            D_loss.backward()
            counter += 1

        # Optionally apply ortho reg in D (and Dv)
        if config['D_ortho'] > 0.0:
            if use_dv:
                print('using modified ortho reg in D and Dv')
                utils.ortho(Dv, config['D_ortho'])
            else:
                print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D.optim.step()
        if use_dv:
            Dv.optim.step()

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        if use_dv:
            utils.toggle_grad(Dv, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    # If accumulating gradients, loop multiple times
    for accumulation_index in range(config['num_G_accumulations']):
        z_.sample_()
        y_.sample_()
        gd_out = GD(z_,
                    y_,
                    train_G=True,
                    split_D=config['split_D'],
                    tensor_writer=tensor_writer,
                    iteration=iteration)
        if use_dv:
            D_fake, Dv_fake, G_z = gd_out
        else:
            D_fake, G_z = gd_out

        # Average the k scores of each sample before the generator loss.
        D_fake = D_fake.contiguous().view(-1, GD.module.k,
                                          *D_fake.shape[1:])
        D_fake = torch.mean(D_fake, 1)
        G_loss = config['D_loss_weight'] * losses.generator_loss(
            D_fake) / float(config['num_G_accumulations'])

        if use_dv:
            if config['T_into_B'] == True:
                Dv_fake = Dv_fake.contiguous().view(D_fake.shape[0], -1,
                                                    *Dv_fake.shape[1:])
                Dv_fake = torch.mean(Dv_fake, 1)
            G_loss += losses.generator_loss(Dv_fake) / float(
                config['num_G_accumulations'])

        # Added by Xiaodan: take the average pixel value into account as an
        # additional loss.
        if use_avg_pixel_loss:
            mean_pixel_val_fake = torch.mean((G_z + 1) / 2)
            # NOTE(review): the difference is built from ``.item()`` floats,
            # so it carries no autograd history; unless
            # ``losses.avg_pixel_loss`` reintroduces one, this term cannot
            # contribute gradients to G -- confirm this is intended.
            mean_pixel_val_diff = abs(
                float(mean_pixel_val_fake.item()) -
                float(mean_pixel_val_real.item()))
            mean_pixel_loss = losses.avg_pixel_loss(
                mean_pixel_val_diff,
                config['avg_pixel_loss_weight']) / float(
                    config['num_G_accumulations'])
            if iteration >= config['pixel_loss_kicksin']:
                G_loss += mean_pixel_loss
            else:
                # Report 0 until the pixel loss becomes active.
                mean_pixel_loss = 0

        G_loss.backward()

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in G
        print('using modified ortho reg in G')
        # Don't ortho reg shared, it makes no sense. Really we should
        # blacklist any embeddings for this
        utils.ortho(G, config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])

    if track_convgru:
        # Gradient and weight magnitudes, read before the optimizer step.
        gru_cell = G.convgru.convgru.cell_list[0]
        first_conv = G.blocks[0][0].conv1
        G_grad_gates = gru_cell.conv_gates.weight.grad.abs().sum()
        G_grad_can = gru_cell.conv_can.weight.grad.abs().sum()
        G_grad_first_layer = first_conv.weight.grad.abs().sum()
        G_weight_gates = gru_cell.conv_gates.weight.abs().mean()
        G_weight_can = gru_cell.conv_can.weight.abs().mean()
        G_weight_first_layer = first_conv.weight.abs().mean()

    G.optim.step()

    # If we have an ema, update it, regardless of if we test with it or not
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {
        'G_loss': float(G_loss.item()),
        'D_loss_real': float(D_loss_real.item()),
        'D_loss_fake': float(D_loss_fake.item())
    }
    if use_dv:
        out['Dv_loss_real'] = float(Dv_loss_real.item())
        out['Dv_loss_fake'] = float(Dv_loss_fake.item())

    if log_media:
        tensor_writer.add_video('Video Results', (G_z + 1) / 2, iteration)
        log_channel_means(G_z, 'fake data')
        y_Gz_text = [idx_to_classes[yi.item()] for yi in y_]
        tensor_writer.add_text('Generated Labels', ' | '.join(y_Gz_text),
                               iteration)

    # Scalar logging is skipped when no writer was supplied (the default),
    # instead of failing on ``None.add_scalar``.
    if tensor_writer is not None:
        if use_avg_pixel_loss:
            tensor_writer.add_scalar('Loss/avg_pixel_loss', mean_pixel_loss,
                                     iteration)
        tensor_writer.add_scalar('Loss/G_loss', out['G_loss'], iteration)
        tensor_writer.add_scalar('Loss/D_loss_real', out['D_loss_real'],
                                 iteration)
        tensor_writer.add_scalar('Loss/D_loss_fake', out['D_loss_fake'],
                                 iteration)
        if use_dv:
            tensor_writer.add_scalar('Loss/Dv_loss_fake',
                                     out['Dv_loss_fake'], iteration)
            tensor_writer.add_scalar('Loss/Dv_loss_real',
                                     out['Dv_loss_real'], iteration)
        if track_convgru:
            tensor_writer.add_scalar('Gradient/G_grad_gates', G_grad_gates,
                                     iteration)
            tensor_writer.add_scalar('Gradient/G_grad_can', G_grad_can,
                                     iteration)
            tensor_writer.add_scalar('Gradient/G_grad_first_layer',
                                     G_grad_first_layer, iteration)
            tensor_writer.add_scalar('Weight/G_weight_gates', G_weight_gates,
                                     iteration)
            tensor_writer.add_scalar('Weight/G_weight_can', G_weight_can,
                                     iteration)
            tensor_writer.add_scalar('Weight/G_weight_first_layer',
                                     G_weight_first_layer, iteration)

    # Return G's loss and the components of D's loss.
    return out