def parse_config():
    parser = argparse.ArgumentParser(fromfile_prefix_chars='@')

    # Integers
    parser.add_argument('--batch_size', type=int, default=64,
                        help='Batch size used in training. Default: 64')
    parser.add_argument('--conditioning_dim', type=int, default=128,
                        help='Dimensionality of the projected text '
                             'representation after the conditional '
                             'augmentation. Default: 128')
    parser.add_argument('--disc_cond_channels', type=int, default=256,
                        help='For `conditioning` == concat, this flag '
                             'decides the number of channels to be '
                             'concatenated. Default: 256')
    parser.add_argument('--embedding_dim', type=int, default=1024,
                        help='Dimensionality of the text representation. '
                             'Default: 1024')
    parser.add_argument('--epochs', type=int, default=300,
                        help='Number of epochs for the experiment. '
                             'Default: 300')
    parser.add_argument('--hidden_dim', type=int, default=256,
                        help='Dimensionality of the RNN hidden state which '
                             'is used as condition for the generator. '
                             'Default: 256')
    parser.add_argument('--img_size', type=int, default=128,
                        help='Image size to use for training. '
                             'Options = {128}')
    parser.add_argument('--image_feat_dim', type=int, default=512,
                        help='Number of channels of the image encoding in '
                             'the recurrent setup. Default: 512')
    parser.add_argument('--input_dim', type=int, default=1024,
                        help='RNN condition dimension; also the '
                             'dimensionality of the image encoder and '
                             'question projector.')
    parser.add_argument('--inception_count', type=int, default=5000,
                        help='Number of images to use for inception score.')
    parser.add_argument('--noise_dim', type=int, default=100,
                        help='Dimensionality of the noise vector that is '
                             'used as input to the generator. Default: 100')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Degree of parallelism to use. Default: 8')
    parser.add_argument('--num_objects', type=int, default=58,
                        help='Number of objects the auxiliary object '
                             'detection objective is trained on.')
    parser.add_argument('--projected_text_dim', type=int, default=1024,
                        help='Pre-fusion text projection dimension for the '
                             'recurrent setup. Default: 1024')
    parser.add_argument('--save_rate', type=int, default=5000,
                        help='Number of iterations between saving current '
                             'model sample generations. Default: 5000')
    parser.add_argument('--sentence_embedding_dim', type=int, default=1024,
                        help='Dimensionality of the sentence encoding '
                             'trained on top of GloVe word embeddings. '
                             'Default: 1024')
    parser.add_argument('--vis_rate', type=int, default=20,
                        help='Number of iterations between each '
                             'visualization. Default: 20')

    # Floats
    parser.add_argument('--aux_reg', type=float, default=5,
                        help='Weighting factor for the aux loss. Default: 5')
    parser.add_argument('--cond_kl_reg', type=float, default=None,
                        help='CA Net KL penalty regularization weighting '
                             'factor. Default: None (no regularization).')
    parser.add_argument('--discriminator_lr', type=float, default=0.0004,
                        help='Learning rate used for optimizing the '
                             'discriminator. Default: 0.0004')
    parser.add_argument('--discriminator_beta1', type=float, default=0,
                        help='Beta1 value for the Adam optimizer. Default: 0')
    parser.add_argument('--discriminator_beta2', type=float, default=0.9,
                        help='Beta2 value for the Adam optimizer. '
                             'Default: 0.9')
    parser.add_argument('--discriminator_weight_decay', type=float, default=0,
                        help='Weight decay for the discriminator. Default: 0')
    parser.add_argument('--feature_encoder_lr', type=float, default=2e-3,
                        help='Learning rate for the image encoder and '
                             'condition encoder. Default: 2e-3')
    parser.add_argument('--generator_lr', type=float, default=0.0001,
                        help='Learning rate used for optimizing the '
                             'generator. Default: 0.0001')
    parser.add_argument('--generator_beta1', type=float, default=0,
                        help='Beta1 value for the Adam optimizer. Default: 0')
    parser.add_argument('--generator_beta2', type=float, default=0.9,
                        help='Beta2 value for the Adam optimizer. '
                             'Default: 0.9')
    parser.add_argument('--generator_weight_decay', type=float, default=0,
                        help='Weight decay for the generator. Default: 0')
    parser.add_argument('--gp_reg', type=float, default=None,
                        help='Gradient penalty regularization weighting '
                             'factor. Default: None (no regularization).')
    parser.add_argument('--grad_clip', type=float, default=4,
                        help='Gradient clipping threshold for the RNN and '
                             'GRU. Default: 4')
    parser.add_argument('--gru_lr', type=float, default=0.0001,
                        help='Sentence encoder optimizer learning rate.')
    parser.add_argument('--rnn_lr', type=float, default=0.0005,
                        help='RNN optimizer learning rate.')
    parser.add_argument('--wrong_fake_ratio', type=float, default=0.5,
                        help='Ratio of wrong:fake losses. Default: 0.5')

    # Strings
    parser.add_argument('--activation', type=str, default='relu',
                        help='Activation function to use. '
                             'Options = [relu, leaky_relu, selu]')
    parser.add_argument('--arch', type=str, default='resblocks',
                        help='Network architecture to use. '
                             'Options = {resblocks}')
    parser.add_argument('--conditioning', type=str, default=None,
                        help='Method of conditioning on text. '
                             'Options: {concat, projection}. Default: None')
    parser.add_argument('--condition_encoder_optimizer', type=str,
                        default='adam',
                        help='Image encoder and text projection optimizer.')
    parser.add_argument('--criterion', type=str, default='hinge',
                        help='Loss function to use. Options: '
                             '{classical, hinge}. Default: hinge')
    parser.add_argument('--dataset', type=str, default='codraw',
                        help='Dataset to use for training. '
                             'Options = {codraw, iclevr}')
    parser.add_argument('--discriminator_optimizer', type=str, default='adam',
                        help='Optimizer used while training the '
                             'discriminator. Default: adam')
    parser.add_argument('--disc_img_conditioning', type=str,
                        default='subtract',
                        help='Image conditioning for the discriminator: '
                             'channel subtraction or concatenation. '
                             'Options = {concat, subtract}. Default: subtract')
    parser.add_argument('--exp_name', type=str, default='TellDrawRepeat',
                        help='Experiment name that will be used for '
                             'visualization and logging. '
                             'Default: TellDrawRepeat')
    parser.add_argument('--embedding_type', type=str, default='gru',
                        help="Type of sentence encoding. The 'gru' option "
                             'trains a GRU over word embeddings. '
                             'Default: gru')
    parser.add_argument('--gan_type', type=str, default='recurrent_gan',
                        help='GAN type. Options: [recurrent_gan]. '
                             'Default: recurrent_gan')
    parser.add_argument('--generator_optimizer', type=str, default='adam',
                        help='Optimizer used while training the generator. '
                             'Default: adam')
    parser.add_argument('--gen_fusion', type=str, default='concat',
                        help='Method to use when fusing the image features '
                             'in the generator. Options = [concat, gate]')
    parser.add_argument('--gru_optimizer', type=str, default='rmsprop',
                        help='Sentence encoder optimizer type.')
    parser.add_argument('--img_encoder_type', type=str, default='res_blocks',
                        help='Building blocks of the image encoder. '
                             'Default: res_blocks')
    parser.add_argument('--load_snapshot', type=str, default=None,
                        help='Snapshot file to load model and optimizer '
                             'state from.')
    parser.add_argument('--log_path', type=str, default='logs',
                        help='Path where to save logs and image generations.')
    parser.add_argument('--results_path', type=str, default='results/',
                        help='Path where to save the generated samples.')
    parser.add_argument('--rnn_optimizer', type=str, default='rmsprop',
                        help='Optimizer to use for the RNN.')
    parser.add_argument('--test_dataset', type=str,
                        help='Test dataset path key.')
    parser.add_argument('--val_dataset', type=str,
                        help='Validation dataset path key.')
    parser.add_argument('--vis_server', type=str, default='http://localhost',
                        help='Visdom server address.')

    # Booleans
    parser.add_argument('-debug', action='store_true',
                        help='Debugging flag (e.g., do not save weights).')
    parser.add_argument('-disc_sn', action='store_true',
                        help='Whether to use spectral norm in the '
                             'discriminator.')
    parser.add_argument('-generator_sn', action='store_true',
                        help='Whether to use spectral norm in the generator.')
    parser.add_argument('-generator_optim_image', action='store_true',
                        help='Whether to optimize the image encoder '
                             'w.r.t. the generator.')
    parser.add_argument('-inference_save_last_only', action='store_true',
                        help='Whether to only save the last image of each '
                             'dialogue.')
    parser.add_argument('-metric_inception_objects', action='store_true',
                        help='Whether to evaluate and report object '
                             'detection accuracy.')
    parser.add_argument('-teacher_forcing', action='store_true',
                        help='Whether to train using teacher forcing. '
                             'NOTE: with teacher_forcing=False more GPU '
                             'memory will be used.')
    parser.add_argument('-self_attention', action='store_true',
                        help='Whether to use self-attention layers.')
    parser.add_argument('-skip_connect', action='store_true',
                        help='Whether to add a skip connection between the '
                             'GRU output and the LSTM input.')
    parser.add_argument('-use_fd', action='store_true',
                        help='Whether to condition the discriminator on '
                             'image features.')
    parser.add_argument('-use_fg', action='store_true',
                        help='Whether to condition the generator on image '
                             'features.')

    args = parser.parse_args()

    logger = Logger(args.log_path, args.exp_name)
    logger.write_config(str(args))

    return args
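# Because the parser is built with fromfile_prefix_chars='@', flags can be
# collected in a file and passed as '@<path>'; argparse reads one token per
# line. A minimal sketch of that usage (the file name and its contents are
# hypothetical, not part of this repository):
#
#     python train.py @args/codraw.cfg
#
# where args/codraw.cfg contains, one per line:
#
#     --dataset=codraw
#     --batch_size=64
#     -teacher_forcing
#
# The same invocation can be simulated programmatically:
import sys

sys.argv = ['train.py', '@args/codraw.cfg']  # hypothetical config file
cfg = parse_config()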
class Trainer():

    def __init__(self, cfg):
        img_path = os.path.join(cfg.log_path, cfg.exp_name, 'train_images_*')
        if glob.glob(img_path):
            raise Exception('all directories with name train_images_* under '
                            'the experiment directory need to be removed')

        path = os.path.join(cfg.log_path, cfg.exp_name)

        self.model = MODELS[cfg.gan_type](cfg)

        if cfg.load_snapshot is not None:
            self.model.load_model(cfg.load_snapshot)
            print('Load model:', cfg.load_snapshot)

        self.model.save_model(path, 0, 0)
        shuffle = cfg.gan_type != 'recurrent_gan'

        self.dataset = DATASETS[cfg.dataset](path=keys[cfg.dataset],
                                             cfg=cfg,
                                             img_size=cfg.img_size)
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=cfg.batch_size,
                                     shuffle=shuffle,
                                     num_workers=cfg.num_workers,
                                     pin_memory=True,
                                     drop_last=True)

        if cfg.dataset == 'codraw':
            self.dataloader.collate_fn = codraw_dataset.collate_data
        elif cfg.dataset == 'iclevr':
            self.dataloader.collate_fn = clevr_dataset.collate_data

        self.visualizer = VisdomPlotter(env_name=cfg.exp_name,
                                        server=cfg.vis_server)
        self.logger = Logger(cfg.log_path, cfg.exp_name)
        self.cfg = cfg

    def train(self):
        iteration_counter = 0
        for epoch in tqdm(range(self.cfg.epochs), ascii=True):
            if self.cfg.dataset == 'codraw':
                self.dataset.shuffle()

            for batch in self.dataloader:
                if iteration_counter >= 0 and \
                        iteration_counter % self.cfg.save_rate == 0:
                    torch.cuda.empty_cache()
                    evaluator = Evaluator.factory(self.cfg, self.visualizer,
                                                  self.logger)
                    res = evaluator.evaluate(iteration_counter)
                    print('\nIter %d:' % (iteration_counter))
                    print(res)
                    self.logger.write_res(iteration_counter, res)
                    del evaluator

                iteration_counter += 1
                self.model.train_batch(batch, epoch, iteration_counter,
                                       self.visualizer, self.logger)

    def train_with_ctr(self):
        cfg = self.cfg
        if cfg.dataset == 'codraw':
            self.model.ctr.E.load_state_dict(
                torch.load('models/codraw_1.0_e.pt'))
        elif cfg.dataset == 'iclevr':
            self.model.ctr.E.load_state_dict(
                torch.load('models/iclevr_1.0_e.pt'))

        iteration_counter = 0
        for epoch in tqdm(range(self.cfg.epochs), ascii=True):
            if cfg.dataset == 'codraw':
                self.dataset.shuffle()

            for batch in self.dataloader:
                if iteration_counter >= 0 and \
                        iteration_counter % self.cfg.save_rate == 0:
                    torch.cuda.empty_cache()
                    evaluator = Evaluator.factory(self.cfg, self.visualizer,
                                                  self.logger)
                    res = evaluator.evaluate(iteration_counter)
                    print('\nIter %d:' % (iteration_counter))
                    print(res)
                    self.logger.write_res(iteration_counter, res)
                    del evaluator

                iteration_counter += 1
                self.model.train_batch_with_ctr(batch, epoch,
                                                iteration_counter,
                                                self.visualizer, self.logger)

    def train_ctr(self):
        iteration_counter = 0
        with tqdm(range(self.cfg.epochs), ascii=True) as TQ:
            for epoch in TQ:
                if self.cfg.dataset == 'codraw':
                    self.dataset.shuffle()

                for batch in self.dataloader:
                    loss = self.model.train_ctr(batch, epoch,
                                                iteration_counter,
                                                self.visualizer, self.logger)
                    TQ.set_postfix(ls_bh=loss)

                    if iteration_counter > 0 and \
                            (iteration_counter % self.cfg.save_rate) == 0:
                        torch.cuda.empty_cache()
                        print('Iter %d: %f' % (iteration_counter, loss))
                        loss = self.eval_ctr(epoch, iteration_counter)
                        print('Eval: %f' % (loss))
                        print('')
                        self.logger.write_res(iteration_counter, loss)

                    iteration_counter += 1

    def eval_ctr(self, epoch, iteration_counter):
        cfg = self.cfg
        dataset = DATASETS[cfg.dataset](path=keys[cfg.val_dataset],
                                        cfg=cfg,
                                        img_size=cfg.img_size)
        dataloader = DataLoader(dataset,
                                batch_size=cfg.batch_size,
                                shuffle=False,
                                num_workers=cfg.num_workers,
                                pin_memory=True,
                                drop_last=True)
        if cfg.dataset == 'codraw':
            dataloader.collate_fn = codraw_dataset.collate_data
        elif cfg.dataset == 'iclevr':
            dataloader.collate_fn = clevr_dataset.collate_data

        if cfg.dataset == 'codraw':
            dataset.shuffle()

        rec_loss = []
        for batch in dataloader:
            loss = self.model.train_ctr(batch, epoch, iteration_counter,
                                        self.visualizer, self.logger,
                                        is_eval=True)
            rec_loss.append(loss)
        loss = np.average(rec_loss)
        return loss

    def infer_ctr(self):
        cfg = self.cfg
        if cfg.dataset == 'codraw':
            self.model.ctr.E.load_state_dict(
                torch.load('models/codraw_1.0_e.pt'))
        elif cfg.dataset == 'iclevr':
            self.model.ctr.E.load_state_dict(
                torch.load('models/iclevr_1.0_e.pt'))

        dataset = DATASETS[cfg.dataset](path=keys[cfg.val_dataset],
                                        cfg=cfg,
                                        img_size=cfg.img_size)
        dataloader = DataLoader(dataset,
                                batch_size=cfg.batch_size,
                                shuffle=True,
                                num_workers=cfg.num_workers,
                                pin_memory=True,
                                drop_last=True)
        if cfg.dataset == 'codraw':
            dataloader.collate_fn = codraw_dataset.collate_data
        elif cfg.dataset == 'iclevr':
            dataloader.collate_fn = clevr_dataset.collate_data

        glove_key = list(dataset.glove_key.keys())

        for batch in dataloader:
            rec_out, loss = self.model.train_ctr(batch, -1, -1,
                                                 self.visualizer, self.logger,
                                                 is_eval=True, is_infer=True)
            rec_out = np.argmax(rec_out, axis=3)

            os.system('mkdir ins_result')
            for i in range(30):
                os.system('mkdir ins_result/%d' % (i))
                F = open('ins_result/%d/ins.txt' % (i), 'w')
                for j in range(rec_out.shape[1]):
                    print([glove_key[rec_out[i, j, k]]
                           for k in range(rec_out.shape[2])])
                    print([glove_key[int(batch['turn_word'][i, j, k]
                                         .detach().cpu().numpy())]
                           for k in range(rec_out.shape[2])])
                    print()
                    F.write(' '.join([glove_key[rec_out[i, j, k]]
                                      for k in range(rec_out.shape[2])]))
                    F.write('\n')
                    F.write(' '.join(
                        [glove_key[int(batch['turn_word'][i, j, k]
                                       .detach().cpu().numpy())]
                         for k in range(rec_out.shape[2])]))
                    F.write('\n')
                    F.write('\n')
                    TV.utils.save_image(batch['image'][i, j].data,
                                        'ins_result/%d/%d.png' % (i, j),
                                        normalize=True, range=(-1, 1))
                print('\n----------------\n')
                F.close()
            break

        os.system('tar zcvf ins_result.tar.gz ins_result')

    def infer_gen(self):
        cfg = self.cfg
        if cfg.dataset == 'codraw':
            self.model.ctr.E.load_state_dict(
                torch.load('models/codraw_1.0.pt'))
        elif cfg.dataset == 'iclevr':
            self.model.ctr.E.load_state_dict(
                torch.load('models/iclevr_1.0.pt'))

        dataset = DATASETS[cfg.dataset](path=keys[cfg.val_dataset],
                                        cfg=cfg,
                                        img_size=cfg.img_size)
        dataloader = DataLoader(dataset,
                                batch_size=cfg.batch_size,
                                shuffle=True,
                                num_workers=cfg.num_workers,
                                pin_memory=True,
                                drop_last=True)
        if cfg.dataset == 'codraw':
            dataloader.collate_fn = codraw_dataset.collate_data
        elif cfg.dataset == 'iclevr':
            dataloader.collate_fn = clevr_dataset.collate_data

        glove_key = list(dataset.glove_key.keys())

        for batch in dataloader:
            rec_out = self.model.infer_gen(batch)

            os.system('mkdir gen_result')
            for i in range(30):
                os.system('mkdir gen_result/%d' % (i))
                F = open('gen_result/%d/ins.txt' % (i), 'w')
                for j in range(rec_out.shape[1]):
                    F.write(' '.join(
                        [glove_key[int(batch['turn_word'][i, j, k]
                                       .detach().cpu().numpy())]
                         for k in range(batch['turn_word'].shape[2])]))
                    F.write('\n')
                    TV.utils.save_image(batch['image'][i, j].data,
                                        'gen_result/%d/_%d.png' % (i, j),
                                        normalize=True, range=(-1, 1))
                    TV.utils.save_image(torch.from_numpy(rec_out[i, j]).data,
                                        'gen_result/%d/%d.png' % (i, j),
                                        normalize=True, range=(-1, 1))
                F.close()
            break

        os.system('tar zcvf gen_result.tar.gz gen_result')
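# A minimal sketch of the entry point that presumably drives the Trainer
# above (the __main__ wiring is not shown in this file and is an assumption):
if __name__ == '__main__':
    cfg = parse_config()
    trainer = Trainer(cfg)
    trainer.train()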
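# The constructors below index OPTIM[...] with two different positional
# signatures: (params, lr, beta1, beta2, weight_decay) for Adam and
# (params, lr) for RMSprop. The registry itself lives elsewhere in the
# repository; the following is only a plausible shape consistent with both
# call sites, not the repository's actual code:
OPTIM = {
    'adam': lambda params, lr, beta1=0, beta2=0.9, weight_decay=0:
        torch.optim.Adam(params, lr=lr, betas=(beta1, beta2),
                         weight_decay=weight_decay),
    'rmsprop': lambda params, lr:
        torch.optim.RMSprop(params, lr=lr),
}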
class RecurrentGAN():

    def __init__(self, cfg):
        """A recurrent GAN model: at each time step the previously generated
        image (x~_{t-1}) and the current question q_{t} are fed to the RNN to
        produce the conditioning vector for the GAN.

        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x~_{t-1})
            - x~_{t} = G(z | c_{t})
        """
        super(RecurrentGAN, self).__init__()

        # region Models-Instantiation
        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                          cfg.hidden_dim,
                                          batch_first=False), dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()
        # endregion

        # region Optimizers
        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(),
            cfg.generator_lr,
            cfg.generator_beta1,
            cfg.generator_beta2,
            cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(),
            cfg.discriminator_lr,
            cfg.discriminator_beta1,
            cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](
            self.rnn.parameters(), cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params,
            cfg.feature_encoder_lr)
        # endregion

        # region Criterion
        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()
        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

    def train_batch(self, batch, epoch, iteration, visualizer, logger):
        """The training scheme is as follows:

        - The discriminator and the generator are updated at every time step.
        - The RNN, SentenceEncoder and ImageEncoder parameters are updated
          once per sequence.
        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0) \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = \
                self.image_encoder(prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)
            output, hidden = self.rnn(rnn_condition, hidden)

            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = \
                self._forward_generator(batch_size,
                                        output.detach(),
                                        image_feature_map)

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar)

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient, \
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

            visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]
            teller_images.extend(image[:4].data.numpy())
            drawer_images.extend(fake_image[:4].data.cpu().numpy())
            entities = str.join(',', list(batch['entities'][hamming > 0]))
            added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss,
                              iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()

            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []

            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)

            self._draw_images(visualizer, teller_images, drawer_images,
                              nrow=4)
            self.logger.write(epoch, iteration, d_real, d_fake, d_loss,
                              g_loss)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path,
                                self.cfg.exp_name)
            self._save(fake_image[:4], path, epoch, iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()

        fake_images, mu, logvar, sigma = self.generator(noise, condition,
                                                        image_feature_maps)
        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self, real_images, fake_images, prev_image,
                                condition, mask, objects, gp_reg=0,
                                aux_reg=0):
        """The discriminator is updated at every time step, independently of
        the RNN and the generator."""
        wrong_images = torch.cat((real_images[1:], real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition,
                                           wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(d_real, real_images,
                                                         mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() \
            if isinstance(aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()

        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar, \
            grad_norm_scalar

    def _optimize_generator(self, fake_images, prev_image, condition,
                            objects, aux_reg, mask, mu, logvar):
        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask)

        g_loss.backward(retain_graph=True)
        gen_grad_norm = _recurrent_gan.get_grad_norm(
            self.generator.parameters())
        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                       self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        gru_grad_norm = None
        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()

        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(
                    d_real[b], d_fake[b], d_wrong[b],
                    self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * \
                        (self.aux_criterion(aux_real[b], objects[b]).mean() +
                         self.aux_criterion(aux_fake[b], objects[b]).mean())
                    sample_loss += aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()
        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self, d_fake, aux_fake, objects, aux_reg,
                               mu, logvar, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * self.aux_criterion(
                        aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * \
                        kl_penalty(mu[b], logvar[b])
                else:
                    kl_loss = 0

                g_loss.append(sample_loss + aux_loss + kl_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        gp_reg = gradient_penalty(d_real, real_images).mean()
        return gp_reg

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss,
                                    aux_loss, iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc,
                                       gru, ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        _recurrent_gan.draw_images(self, visualizer, real, fake, nrow)

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)
    # endregion
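# The subclass below adds a segmentation auxiliary loss. As a shape sanity
# check for the reshaping its train_batch performs (sizes illustrative):
# CrossEntropyLoss expects (N, C) logits against (N,) integer targets,
# which is exactly what the per-sample slices below provide.
import torch

B, C, H, W = 2, 22, 64, 64
fake = torch.randn(B, C, H, W)                        # generator output
gt_onehot = torch.randn(B, C, H, W)                   # one-hot ground truth
seg_fake = fake.view(B, C, -1).permute(0, 2, 1)       # (B, H*W, C)
seg_gt = torch.argmax(gt_onehot, dim=1).view(B, -1)   # (B, H*W)
loss = torch.nn.CrossEntropyLoss()(seg_fake[0], seg_gt[0])  # scalar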
class RecurrentGAN_Mingyang():

    def __init__(self, cfg):
        """A recurrent GAN model: at each time step the previously generated
        image (x~_{t-1}) and the current question q_{t} are fed to the RNN to
        produce the conditioning vector for the GAN.

        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x~_{t-1})
            - x~_{t} = G(z | c_{t})
        """
        super(RecurrentGAN_Mingyang, self).__init__()

        # region Models-Instantiation
        ########################## Original DataParallel ######################
        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                          cfg.hidden_dim,
                                          batch_first=False), dim=1).cuda()
        # self.rnn = DistributedDataParallel(nn.GRU(cfg.input_dim,
        #                                           cfg.hidden_dim,
        #                                           batch_first=False),
        #                                    dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()
        ########################################################################
        # Single-GPU variants kept for reference:
        # self.generator = GeneratorFactory.create_instance(cfg).cuda()
        # self.discriminator = DiscriminatorFactory.create_instance(cfg).cuda()
        # self.rnn = nn.GRU(cfg.input_dim, cfg.hidden_dim,
        #                   batch_first=False).cuda()
        # self.layer_norm = nn.LayerNorm(cfg.hidden_dim).cuda()
        # self.image_encoder = ImageEncoder(cfg).cuda()
        # self.condition_encoder = ConditionEncoder(cfg).cuda()
        # self.sentence_encoder = SentenceEncoder(cfg).cuda()
        # endregion

        # region Optimizers
        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(),
            cfg.generator_lr,
            cfg.generator_beta1,
            cfg.generator_beta2,
            cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(),
            cfg.discriminator_lr,
            cfg.discriminator_beta1,
            cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](self.rnn.parameters(),
                                                      cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params,
            cfg.feature_encoder_lr)
        # endregion

        # region Criterion
        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        # Added by Mingyang for segmentation loss
        if cfg.balanced_seg:
            label_weights = np.array([
                3.02674201e-01, 1.91545454e-03, 2.90009221e-04,
                7.50949673e-04, 1.08670452e-03, 1.11353785e-01,
                4.00971053e-04, 1.06240113e-02, 1.59590824e-01,
                5.38960105e-02, 3.36431602e-02, 3.99029734e-02,
                1.88888847e-02, 2.06441476e-03, 6.33775290e-02,
                5.81920411e-03, 3.79528817e-03, 7.87975754e-02,
                2.73547355e-03, 1.08308135e-01, 0.00000000e+00,
                8.44408475e-05
            ])
            # Invert the label frequencies to get the weights; label 20 has
            # zero frequency, so its weight is forced back to zero below
            label_weights = 1 / label_weights
            label_weights[20] = 0
            label_weights = label_weights / np.min(label_weights[:20])
            # Convert numpy to tensor
            label_weights = torch.from_numpy(label_weights)
            label_weights = label_weights.type(torch.FloatTensor)
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss(weight=label_weights)).cuda()
        else:
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss()).cuda()
        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

        # Define unorm
        self.unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

    def train_batch(self, batch, epoch, iteration, visualizer, logger,
                    total_iters=0, current_batch_t=0):
        """The training scheme is as follows:

        - The discriminator and the generator are updated at every time step.
        - The RNN, SentenceEncoder and ImageEncoder parameters are updated
          once per sequence.
        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = \
                self.image_encoder(prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)

            # Added by Mingyang to resolve the cuDNN warning
            self.rnn.module.flatten_parameters()
            output, hidden = self.rnn(rnn_condition, hidden)

            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            # Build the segmentation-loss inputs when required; only the
            # img64_seg variant defines them, so default both to None first
            seg_fake = None
            seg_gt = None
            if re.search(r"seg", self.cfg.gan_type):
                assert self.cfg.seg_reg > 0, "seg_reg must be larger than 0"
                if self.cfg.gan_type == "recurrent_gan_mingyang_img64_seg":
                    # seg_fake is reshaped to (batch, N, C)
                    seg_fake = fake_image.view(fake_image.size(0),
                                               fake_image.size(1),
                                               -1).permute(0, 2, 1)
                    # seg_gt is derived from the one-hot ground-truth image
                    seg_gt = torch.argmax(image, dim=1).view(image.size(0),
                                                             -1)
            else:
                assert self.cfg.seg_reg == 0, "seg_reg must be equal to 0"

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar,
                                         self.cfg.seg_reg,
                                         seg_fake,
                                         seg_gt)

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient, \
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

            visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]

            new_teller_images = []
            for x in image[:4].data.cpu():
                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())
                new_teller_images.append(new_x)
            teller_images.extend(new_teller_images)

            new_drawer_images = []
            for x in fake_image[:4].data.cpu():
                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.cpu().numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())
                new_drawer_images.append(new_x)
            drawer_images.extend(new_drawer_images)

            # entities = str.join(',', list(batch['entities'][hamming > 0]))
            # added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss,
                              iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()

            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []

            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)

            self._draw_images(visualizer, teller_images, drawer_images,
                              nrow=4)

            remaining_time = str(
                datetime.timedelta(seconds=current_batch_t *
                                   (total_iters - iteration)))
            self.logger.write(epoch, "{}/{}".format(iteration, total_iters),
                              d_real, d_fake, d_loss, g_loss,
                              expected_finish_time=remaining_time)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            # visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path, self.cfg.exp_name)
            # self._save(fake_image[:4], path, epoch, iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()

        fake_images, mu, logvar, sigma = self.generator(
            noise, condition, image_feature_maps)
        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self, real_images, fake_images, prev_image,
                                condition, mask, objects, gp_reg=0,
                                aux_reg=0):
        """The discriminator is updated at every time step, independently of
        the RNN and the generator."""
        wrong_images = torch.cat((real_images[1:], real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:], prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition,
                                           wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(
                d_real, real_images, mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(
            self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() \
            if isinstance(aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()

        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar, \
            grad_norm_scalar

    def _optimize_generator(self, fake_images, prev_image, condition,
                            objects, aux_reg, mask, mu, logvar, seg_reg=0,
                            seg_fake=None, seg_gt=None):
        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask,
                                             seg_reg, seg_fake, seg_gt)

        g_loss.backward(retain_graph=True)
        gen_grad_norm = _recurrent_gan.get_grad_norm(
            self.generator.parameters())
        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(),
                                       self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        gru_grad_norm = None
        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(),
                                       self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(
            self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()

        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(
                    d_real[b], d_fake[b], d_wrong[b],
                    self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * (
                        self.aux_criterion(aux_real[b], objects[b]).mean() +
                        self.aux_criterion(aux_fake[b], objects[b]).mean())
                    sample_loss += aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()
        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self, d_fake, aux_fake, objects, aux_reg,
                               mu, logvar, mask, seg_reg=0, seg_fake=None,
                               seg_gt=None):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding.

        Appends the segmentation loss to the generator loss.
        seg_fake: (1*C*H*W)
        seg_gt: (1*H*W)
        """
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * \
                        self.aux_criterion(aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * \
                        kl_penalty(mu[b], logvar[b])
                else:
                    kl_loss = 0
                # Append a seg_loss to the total generator loss
                if seg_reg > 0:
                    # By default the criterion returns a mean value
                    seg_loss = seg_reg * self.seg_criterion(seg_fake[b],
                                                            seg_gt[b])
                else:
                    seg_loss = 0

                g_loss.append(sample_loss + aux_loss + kl_loss + seg_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        gp_reg = gradient_penalty(d_real, real_images).mean()
        return gp_reg

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss,
                                    aux_loss, iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc,
                                       gru, ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        # Changed by Mingyang Zhou
        _recurrent_gan.draw_images_gandraw(self, visualizer, real, fake, nrow)

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)
    # endregion

    def unormalize(self, x):
        """Unnormalize a real image and return it as a (3, H, W) array"""
        new_x = self.unorm(x)
        new_x = transforms.ToPILImage()(new_x).convert('RGB')
        new_x = np.moveaxis(np.array(new_x), -1, 0)
        return new_x

    def unormalize_segmentation(self, x):
        new_x = (x + 1) * 127.5
        return new_x

    def unormalize_segmentation_onehot(self, x):
        """Convert a one-hot segmentation into a color image"""
        LABEL2COLOR = {
            0: {"name": "sky", "color": np.array([134, 193, 46])},
            1: {"name": "dirt", "color": np.array([30, 22, 100])},
            2: {"name": "gravel", "color": np.array([163, 164, 153])},
            3: {"name": "mud", "color": np.array([35, 90, 74])},
            4: {"name": "sand", "color": np.array([196, 15, 241])},
            5: {"name": "clouds", "color": np.array([198, 182, 115])},
            6: {"name": "fog", "color": np.array([76, 60, 231])},
            7: {"name": "hill", "color": np.array([190, 128, 82])},
            8: {"name": "mountain", "color": np.array([122, 101, 17])},
            9: {"name": "river", "color": np.array([97, 140, 33])},
            10: {"name": "rock", "color": np.array([90, 90, 81])},
            11: {"name": "sea", "color": np.array([255, 252, 51])},
            12: {"name": "snow", "color": np.array([51, 255, 252])},
            13: {"name": "stone", "color": np.array([106, 107, 97])},
            14: {"name": "water", "color": np.array([0, 255, 0])},
            15: {"name": "bush", "color": np.array([204, 113, 46])},
            16: {"name": "flower", "color": np.array([0, 0, 255])},
            17: {"name": "grass", "color": np.array([255, 0, 0])},
            18: {"name": "straw", "color": np.array([255, 51, 252])},
            19: {"name": "tree", "color": np.array([255, 51, 175])},
            20: {"name": "wood", "color": np.array([66, 18, 120])},
            21: {"name": "road", "color": np.array([255, 255, 0])},
        }
        seg_map = np.argmax(x, axis=0)
        new_x = np.zeros((3, seg_map.shape[0], seg_map.shape[1]),
                         dtype=np.uint8)
        for i in range(seg_map.shape[0]):
            for j in range(seg_map.shape[1]):
                new_x[:, i, j] = LABEL2COLOR[seg_map[i, j]]["color"]
        return new_x