# Standard-library and third-party imports used below. Project-internal
# helpers (GeneratorFactory, DiscriminatorFactory, ImageEncoder,
# ConditionEncoder, SentenceEncoder, UnNormalize, Logger, OPTIM, LOSSES,
# kl_penalty, gradient_penalty, _recurrent_gan) are assumed to be imported
# from the surrounding package.
import datetime
import gc
import os
import re

import numpy as np
import torch
import torch.nn as nn
from torch.nn import DataParallel
from torchvision import transforms


class RecurrentGAN():
    def __init__(self, cfg):
        """A recurrent GAN model. At each time step, the previously generated
        image (x~_{t-1}) and the current question q_{t} are fed to the RNN to
        produce the conditioning vector for the GAN.

        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x~_{t-1})
            - x~_{t} = G(z | c_{t})
        """
        super(RecurrentGAN, self).__init__()

        # region Models-Instantiation

        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim, cfg.hidden_dim,
                                          batch_first=False), dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()

        # endregion

        # region Optimizers

        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(),
            cfg.generator_lr,
            cfg.generator_beta1,
            cfg.generator_beta2,
            cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(),
            cfg.discriminator_lr,
            cfg.discriminator_beta1,
            cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](
            self.rnn.parameters(), cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params, cfg.feature_encoder_lr)

        # endregion

        # region Criterion

        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

    def train_batch(self, batch, epoch, iteration, visualizer, logger):
        """
        The training scheme is as follows:
            - The discriminator and the generator are updated at every
              time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are updated
              once per sequence.
        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0) \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        # Initial inputs for the RNN are set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = \
                self.image_encoder(prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)
            output, hidden = self.rnn(rnn_condition, hidden)
            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            visualizer.track_sigma(sigma)

            # Keep only the objects newly added at this turn
            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar)

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            # Truncate back-propagation through time every two steps
            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient, \
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

            visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]
            teller_images.extend(image[:4].data.cpu().numpy())
            drawer_images.extend(fake_image[:4].data.cpu().numpy())
            entities = str.join(',', list(batch['entities'][hamming > 0]))
            added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss,
                              iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []

            self._plot_gradients(visualizer, rnn_gradient,
                                 generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)

            self._draw_images(visualizer, teller_images, drawer_images,
                              nrow=4)
            self.logger.write(epoch, iteration, d_real, d_fake, d_loss,
                              g_loss)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path,
                                self.cfg.exp_name)
            self._save(fake_image[:4], path, epoch, iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()
        fake_images, mu, logvar, sigma = self.generator(noise, condition,
                                                        image_feature_maps)
        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self, real_images, fake_images, prev_image,
                                condition, mask, objects, gp_reg=0,
                                aux_reg=0):
        """The discriminator is updated at every time step, independently of
        the RNN and the generator."""
        wrong_images = torch.cat((real_images[1:], real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition,
                                           wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(d_real, d_fake,
                                                           d_wrong, aux_real,
                                                           aux_fake, objects,
                                                           aux_reg, mask)

        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(d_real, real_images,
                                                         mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() \
            if isinstance(aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()
        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return (d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar,
                grad_norm_scalar)

    def _optimize_generator(self, fake_images, prev_image, condition,
                            objects, aux_reg, mask, mu, logvar):
        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask)

        g_loss.backward(retain_graph=True)
        gen_grad_norm = _recurrent_gan.get_grad_norm(
            self.generator.parameters())
        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                       self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()

        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended to avoid
        back-propagation through padding."""
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(
                    d_real[b], d_fake[b], d_wrong[b],
                    self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * (
                        self.aux_criterion(aux_real[b], objects[b]).mean() +
                        self.aux_criterion(aux_fake[b], objects[b]).mean())
                    sample_loss += aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()
        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self, d_fake, aux_fake, objects, aux_reg, mu,
                               logvar, mask):
        """Accumulates losses only for sequences that have not ended to avoid
        back-propagation through padding."""
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * self.aux_criterion(
                        aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * kl_penalty(mu[b],
                                                                logvar[b])
                else:
                    kl_loss = 0

                g_loss.append(sample_loss + aux_loss + kl_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        gp_reg = gradient_penalty(d_real, real_images).mean()
        return gp_reg

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss,
                                    aux_loss, iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc,
                                       gru, ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        _recurrent_gan.draw_images(self, visualizer, real, fake, nrow)

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)
    # endregion
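
# A minimal, self-contained sketch (not part of the original training code)
# of the masking pattern used by `_discriminator_masked_loss` and
# `_generator_masked_loss` above: per-sample losses are averaged only over
# dialogues that have not ended, so padded time steps contribute no gradient.
# The `_demo_` name and the toy shapes are illustrative assumptions.
def _demo_masked_loss_mean():
    per_sample_loss = torch.randn(4, requires_grad=True)
    ended_mask = [False, False, True, True]  # last two dialogues are padding

    kept = [per_sample_loss[b] for b, ended in enumerate(ended_mask)
            if not ended]
    masked_loss = torch.stack(kept).mean()
    masked_loss.backward()
    # Gradient flows only to the two live samples; padded entries stay zero.
    assert per_sample_loss.grad[2:].abs().sum().item() == 0
    return masked_loss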

class RecurrentGAN_Mingyang():
    def __init__(self, cfg):
        """A recurrent GAN model. At each time step, the previously generated
        image (x~_{t-1}) and the current question q_{t} are fed to the RNN to
        produce the conditioning vector for the GAN.

        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x~_{t-1})
            - x~_{t} = G(z | c_{t})
        """
        super(RecurrentGAN_Mingyang, self).__init__()

        # region Models-Instantiation
        ########################## Original DataParallel ######################
        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim, cfg.hidden_dim,
                                          batch_first=False), dim=1).cuda()
        # self.rnn = DistributedDataParallel(nn.GRU(cfg.input_dim,
        #                                           cfg.hidden_dim,
        #                                           batch_first=False),
        #                                    dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()
        ########################################################################
        # Single-GPU variants kept for reference:
        # self.generator = GeneratorFactory.create_instance(cfg).cuda()
        # self.discriminator = DiscriminatorFactory.create_instance(cfg).cuda()
        # self.rnn = nn.GRU(cfg.input_dim, cfg.hidden_dim,
        #                   batch_first=False).cuda()
        # self.layer_norm = nn.LayerNorm(cfg.hidden_dim).cuda()
        # self.image_encoder = ImageEncoder(cfg).cuda()
        # self.condition_encoder = ConditionEncoder(cfg).cuda()
        # self.sentence_encoder = SentenceEncoder(cfg).cuda()
        # endregion

        # region Optimizers

        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(),
            cfg.generator_lr,
            cfg.generator_beta1,
            cfg.generator_beta2,
            cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(),
            cfg.discriminator_lr,
            cfg.discriminator_beta1,
            cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](self.rnn.parameters(),
                                                      cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params, cfg.feature_encoder_lr)

        # endregion

        # region Criterion

        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        # Added by Mingyang for the segmentation loss
        if cfg.balanced_seg:
            # Per-class pixel frequencies of the 22 segmentation labels
            label_weights = np.array([
                3.02674201e-01, 1.91545454e-03, 2.90009221e-04,
                7.50949673e-04, 1.08670452e-03, 1.11353785e-01,
                4.00971053e-04, 1.06240113e-02, 1.59590824e-01,
                5.38960105e-02, 3.36431602e-02, 3.99029734e-02,
                1.88888847e-02, 2.06441476e-03, 6.33775290e-02,
                5.81920411e-03, 3.79528817e-03, 7.87975754e-02,
                2.73547355e-03, 1.08308135e-01, 0.00000000e+00,
                8.44408475e-05])
            # Invert the frequencies so rare classes get larger weights;
            # class 20 never occurs (frequency 0), so zero out its weight
            label_weights = 1 / label_weights
            label_weights[20] = 0
            label_weights = label_weights / np.min(label_weights[:20])
            # Convert the numpy array to a float tensor
            label_weights = torch.from_numpy(label_weights)
            label_weights = label_weights.type(torch.FloatTensor)
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss(weight=label_weights)).cuda()
        else:
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss()).cuda()
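
        # Worked example of the inverse-frequency weighting above (values
        # rounded, for illustration only): the most frequent class, index 0
        # ("sky" in LABEL2COLOR below, freq ~0.303), ends up with weight
        # (1/0.303)/(1/0.303) = 1.0, while a rare class such as index 2
        # (freq ~2.90e-4) gets (1/2.90e-4)/(1/0.303) ~ 1.04e3; class 20 is
        # pinned to 0 to avoid the division-by-zero infinity.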

        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

        # define unorm
        self.unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

        # define the label distribution

    def train_batch(self, batch, epoch, iteration, visualizer, logger,
                    total_iters=0, current_batch_t=0):
        """
        The training scheme is as follows:
            - The discriminator and the generator are updated at every
              time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are updated
              once per sequence.
        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        # Initial inputs for the RNN are set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = \
                self.image_encoder(prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)
            # Added by Mingyang to resolve the non-contiguous-weights
            # warning when the GRU is wrapped in DataParallel
            self.rnn.module.flatten_parameters()
            output, hidden = self.rnn(rnn_condition, hidden)
            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            # Build the segmentation-loss inputs for the "seg" GAN variants
            seg_fake = None
            seg_gt = None
            if re.search(r"seg", self.cfg.gan_type):
                assert self.cfg.seg_reg > 0, "seg_reg must be larger than 0"
                if self.cfg.gan_type == "recurrent_gan_mingyang_img64_seg":
                    # seg_fake is reshaped to (batch, N, C), where N = H*W
                    seg_fake = fake_image.view(fake_image.size(0),
                                               fake_image.size(1),
                                               -1).permute(0, 2, 1)
                    # seg_gt holds per-pixel class indices derived from the
                    # one-hot ground-truth image
                    seg_gt = torch.argmax(image, dim=1).view(image.size(0),
                                                             -1)
            else:
                assert self.cfg.seg_reg == 0, "seg_reg must be equal to 0"

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar,
                                         self.cfg.seg_reg,
                                         seg_fake,
                                         seg_gt)
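
            # Shape note: per sample b, seg_fake[b] is an (N, C) matrix of
            # class logits and seg_gt[b] an (N,) vector of class indices,
            # matching the (input, target) pair that
            # torch.nn.CrossEntropyLoss expects inside _generator_masked_loss.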

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            # Truncate back-propagation through time every two steps
            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient, \
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

            visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]

            # Convert the first four teller (ground-truth) images into
            # displayable arrays according to the image-generation mode
            new_teller_images = []
            for x in image[:4].data.cpu():
                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())
                new_teller_images.append(new_x)
            teller_images.extend(new_teller_images)

            # Convert the first four drawer (generated) images likewise
            new_drawer_images = []
            for x in fake_image[:4].data.cpu():
                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.cpu().numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())
                new_drawer_images.append(new_x)
            drawer_images.extend(new_drawer_images)
            # entities = str.join(',', list(batch['entities'][hamming > 0]))
            # added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss,
                              iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []

            self._plot_gradients(visualizer, rnn_gradient,
                                 generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)

            self._draw_images(visualizer, teller_images, drawer_images,
                              nrow=4)

            # Estimate the remaining wall-clock time from the per-batch time
            remaining_time = str(datetime.timedelta(
                seconds=current_batch_t * (total_iters - iteration)))
            self.logger.write(epoch, "{}/{}".format(iteration, total_iters),
                              d_real, d_fake, d_loss, g_loss,
                              expected_finish_time=remaining_time)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            # visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path,
                                self.cfg.exp_name)
            # self._save(fake_image[:4], path, epoch, iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()
        fake_images, mu, logvar, sigma = self.generator(
            noise, condition, image_feature_maps)
        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self, real_images, fake_images, prev_image,
                                condition, mask, objects, gp_reg=0,
                                aux_reg=0):
        """The discriminator is updated at every time step, independently of
        the RNN and the generator."""
        wrong_images = torch.cat((real_images[1:], real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition,
                                           wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(d_real, real_images,
                                                         mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() \
            if isinstance(aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()
        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return (d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar,
                grad_norm_scalar)

    def _optimize_generator(self, fake_images, prev_image, condition,
                            objects, aux_reg, mask, mu, logvar, seg_reg=0,
                            seg_fake=None, seg_gt=None):
        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask,
                                             seg_reg, seg_fake, seg_gt)

        g_loss.backward(retain_graph=True)
        gen_grad_norm = _recurrent_gan.get_grad_norm(
            self.generator.parameters())
        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                       self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()

        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm
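
    # Note: the RNN, sentence-encoder and feature-encoder gradients above
    # are only zeroed here, once per sequence, so they accumulate over every
    # time step of the dialogue before a single clipped update is applied.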

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended to avoid
        back-propagation through padding."""
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(
                    d_real[b], d_fake[b], d_wrong[b],
                    self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * (
                        self.aux_criterion(aux_real[b], objects[b]).mean() +
                        self.aux_criterion(aux_fake[b], objects[b]).mean())
                    sample_loss += aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()
        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self, d_fake, aux_fake, objects, aux_reg, mu,
                               logvar, mask, seg_reg=0, seg_fake=None,
                               seg_gt=None):
        """Accumulates losses only for sequences that have not ended to avoid
        back-propagation through padding.

        Also adds the segmentation loss to the generator loss:
            seg_fake[b]: (N, C) class logits per sample
            seg_gt[b]: (N,) class indices per sample
        """
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * \
                        self.aux_criterion(aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * \
                        kl_penalty(mu[b], logvar[b])
                else:
                    kl_loss = 0
                # Add the segmentation loss to the per-sample generator loss
                if seg_reg > 0:
                    # CrossEntropyLoss averages over the N pixels by default
                    seg_loss = seg_reg * self.seg_criterion(seg_fake[b],
                                                            seg_gt[b])
                else:
                    seg_loss = 0
                g_loss.append(sample_loss + aux_loss + kl_loss + seg_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        gp_reg = gradient_penalty(d_real, real_images).mean()
        return gp_reg

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss,
                                    aux_loss, iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc,
                                       gru, ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        # Changed by Mingyang Zhou
        _recurrent_gan.draw_images_gandraw(self, visualizer, real, fake,
                                           nrow)

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)

    def unormalize(self, x):
        """Un-normalize a real image tensor and return a (C, H, W) uint8
        RGB array."""
        new_x = self.unorm(x)
        new_x = transforms.ToPILImage()(new_x).convert('RGB')
        new_x = np.moveaxis(np.array(new_x), -1, 0)
        return new_x

    def unormalize_segmentation(self, x):
        # Map values from [-1, 1] back to [0, 255]
        new_x = (x + 1) * 127.5
        return new_x

    def unormalize_segmentation_onehot(self, x):
        """Convert a one-hot segmentation map into a color image."""
        LABEL2COLOR = {
            0: {"name": "sky", "color": np.array([134, 193, 46])},
            1: {"name": "dirt", "color": np.array([30, 22, 100])},
            2: {"name": "gravel", "color": np.array([163, 164, 153])},
            3: {"name": "mud", "color": np.array([35, 90, 74])},
            4: {"name": "sand", "color": np.array([196, 15, 241])},
            5: {"name": "clouds", "color": np.array([198, 182, 115])},
            6: {"name": "fog", "color": np.array([76, 60, 231])},
            7: {"name": "hill", "color": np.array([190, 128, 82])},
            8: {"name": "mountain", "color": np.array([122, 101, 17])},
{ "name": "river", "color": np.array([97, 140, 33]) }, 10: { "name": "rock", "color": np.array([90, 90, 81]) }, 11: { "name": "sea", "color": np.array([255, 252, 51]) }, 12: { "name": "snow", "color": np.array([51, 255, 252]) }, 13: { "name": "stone", "color": np.array([106, 107, 97]) }, 14: { "name": "water", "color": np.array([0, 255, 0]) }, 15: { "name": "bush", "color": np.array([204, 113, 46]) }, 16: { "name": "flower", "color": np.array([0, 0, 255]) }, 17: { "name": "grass", "color": np.array([255, 0, 0]) }, 18: { "name": "straw", "color": np.array([255, 51, 252]) }, 19: { "name": "tree", "color": np.array([255, 51, 175]) }, 20: { "name": "wood", "color": np.array([66, 18, 120]) }, 21: { "name": "road", "color": np.array([255, 255, 0]) }, } seg_map = np.argmax(x, axis=0) new_x = np.zeros((3, seg_map.shape[0], seg_map.shape[1]), dtype=np.uint8) for i in range(seg_map.shape[0]): for j in range(seg_map.shape[1]): new_x[:, i, j] = LABEL2COLOR[seg_map[i, j]]["color"] return new_x