예제 #1
0
    def __init__(self, cfg):
        """Set up model, training data, visualizer and logger from ``cfg``.

        Refuses to start while ``train_images_*`` directories exist under the
        experiment directory (``cfg.log_path/cfg.exp_name``).

        Args:
            cfg: configuration namespace (e.g. the result of the project's
                argument parser).

        Raises:
            Exception: if any ``train_images_*`` directory is already present.
        """
        experiment_dir = os.path.join(cfg.log_path, cfg.exp_name)
        leftover_pattern = os.path.join(cfg.log_path, cfg.exp_name,
                                        'train_images_*')
        # Presumably output of an earlier run; it must not be mixed with this one.
        if glob.glob(leftover_pattern):
            raise Exception('all directories with name train_images_* under '
                            'the experiment directory need to be removed')

        self.model = MODELS[cfg.gan_type](cfg)

        snapshot_file = cfg.load_snapshot
        if snapshot_file is not None:
            self.model.load_model(snapshot_file)
            print('Load model:', snapshot_file)
        # Write the freshly built (or restored) model into the experiment dir.
        self.model.save_model(experiment_dir, 0, 0)

        dataset_cls = DATASETS[cfg.dataset]
        self.dataset = dataset_cls(path=keys[cfg.dataset], cfg=cfg,
                                   img_size=cfg.img_size)

        # Only the recurrent GAN keeps the loader unshuffled.
        loader = DataLoader(self.dataset,
                            batch_size=cfg.batch_size,
                            shuffle=(cfg.gan_type != 'recurrent_gan'),
                            num_workers=cfg.num_workers,
                            pin_memory=True,
                            drop_last=True)
        # Each supported dataset module ships its own batch collation.
        if cfg.dataset == 'codraw':
            loader.collate_fn = codraw_dataset.collate_data
        elif cfg.dataset == 'iclevr':
            loader.collate_fn = clevr_dataset.collate_data
        self.dataloader = loader

        self.visualizer = VisdomPlotter(env_name=cfg.exp_name,
                                        server=cfg.vis_server)
        self.logger = Logger(cfg.log_path, cfg.exp_name)
        self.cfg = cfg
예제 #2
0
파일: config.py 프로젝트: zmykevin/GanDraw
def build_parser():
    """Return the ``argparse.ArgumentParser`` holding every experiment option.

    Arguments of the form ``@file`` are expanded from the named file
    (``fromfile_prefix_chars='@'``). Boolean switches use a single dash
    (e.g. ``-debug``) and are ``store_true`` flags.
    """
    parser = argparse.ArgumentParser(fromfile_prefix_chars='@')

    # Integers

    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='Batch size used in training. Default: 64')

    parser.add_argument('--conditioning_dim',
                        type=int,
                        default=128,
                        help='Dimensionality of the projected text '
                             'representation after the conditional '
                             'augmentation. Default: 128')

    parser.add_argument('--disc_cond_channels',
                        type=int,
                        default=256,
                        help='For `conditioning` == concat, this flag '
                             'decides the number of channels to be '
                             'concatenated. Default: 256')

    parser.add_argument('--embedding_dim',
                        type=int,
                        default=1024,
                        help='The dimensionality of the text representation. '
                             'Default: 1024')

    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        help='Number of epochs for the experiment. '
                             'Default: 300')

    parser.add_argument('--hidden_dim',
                        type=int,
                        default=256,
                        help='Dimensionality of the RNN hidden state which is '
                             'used as condition for the generator. '
                             'Default: 256')

    parser.add_argument('--img_size',
                        type=int,
                        default=128,
                        help='Image size to use for training. '
                             'Options = {128}. Default: 128')

    parser.add_argument('--image_feat_dim',
                        type=int,
                        default=512,
                        help='Image encoding number of channels for the '
                             'recurrent setup. Default: 512')

    parser.add_argument('--input_dim',
                        type=int,
                        default=1024,
                        help='RNN condition dimension, the dimensionality of '
                             'the image encoder and question projector as '
                             'well. Default: 1024')

    parser.add_argument('--inception_count',
                        type=int,
                        default=5000,
                        help='Number of images to use for inception score. '
                             'Default: 5000')

    parser.add_argument('--noise_dim',
                        type=int,
                        default=100,
                        help='Dimensionality of the noise vector that is used '
                             'as input to the generator. Default: 100')

    parser.add_argument('--num_workers',
                        type=int,
                        default=8,
                        help='Degree of parallelism to use. Default: 8')

    parser.add_argument('--num_objects',
                        type=int,
                        default=58,
                        help='Number of objects the auxiliary object '
                             'detection objective is trained on. Default: 58')

    parser.add_argument('--projected_text_dim',
                        type=int,
                        default=1024,
                        help='Pre-fusion text projection dimension for the '
                             'recurrent setup. Default: 1024')

    parser.add_argument('--save_rate',
                        type=int,
                        default=5000,
                        help='Number of iterations between saving current '
                             'model sample generations. Default: 5000')

    parser.add_argument('--sentence_embedding_dim',
                        type=int,
                        default=1024,
                        help='Dimensionality of the sentence encoding trained '
                             'on top of GloVe word embeddings. Default: 1024')

    parser.add_argument('--vis_rate',
                        type=int,
                        default=20,
                        help='Number of iterations between each '
                             'visualization. Default: 20')

    # Floats

    parser.add_argument('--aux_reg',
                        type=float,
                        default=5,
                        help='Weighting factor for the aux loss. Default: 5')

    parser.add_argument('--cond_kl_reg',
                        type=float,
                        default=None,
                        help='CA Net KL penalty regularization weighting '
                             'factor. Default: None, means no regularization.')

    parser.add_argument('--discriminator_lr',
                        type=float,
                        default=0.0004,
                        help='Learning rate used for optimizing the '
                             'discriminator. Default: 0.0004')

    parser.add_argument('--discriminator_beta1',
                        type=float,
                        default=0,
                        help='Beta1 value for the discriminator Adam '
                             'optimizer. Default: 0')

    parser.add_argument('--discriminator_beta2',
                        type=float,
                        default=0.9,
                        help='Beta2 value for the discriminator Adam '
                             'optimizer. Default: 0.9')

    parser.add_argument('--discriminator_weight_decay',
                        type=float,
                        default=0,
                        help='Weight decay for the discriminator. Default: 0')

    parser.add_argument('--feature_encoder_lr',
                        type=float,
                        default=2e-3,
                        help='Learning rate for the image encoder and the '
                             'condition encoder. Default: 2e-3')

    parser.add_argument('--generator_lr',
                        type=float,
                        default=0.0001,
                        help='Learning rate used for optimizing the '
                             'generator. Default: 0.0001')

    parser.add_argument('--generator_beta1',
                        type=float,
                        default=0,
                        help='Beta1 value for the generator Adam optimizer. '
                             'Default: 0')

    parser.add_argument('--generator_beta2',
                        type=float,
                        default=0.9,
                        help='Beta2 value for the generator Adam optimizer. '
                             'Default: 0.9')

    parser.add_argument('--generator_weight_decay',
                        type=float,
                        default=0,
                        help='Weight decay for the generator. Default: 0')

    parser.add_argument('--gp_reg',
                        type=float,
                        default=None,
                        help='Gradient penalty regularization weighting '
                             'factor. Default: None, means no regularization.')

    parser.add_argument('--grad_clip',
                        type=float,
                        default=4,
                        help='Gradient clipping threshold for RNN and GRU. '
                             'Default: 4')

    parser.add_argument('--gru_lr',
                        type=float,
                        default=0.0001,
                        help='Sentence encoder optimizer learning rate. '
                             'Default: 0.0001')

    parser.add_argument('--rnn_lr',
                        type=float,
                        default=0.0005,
                        help='RNN optimizer learning rate. Default: 0.0005')

    parser.add_argument('--wrong_fake_ratio',
                        type=float,
                        default=0.5,
                        help='Ratio of wrong:fake losses. Default: 0.5')

    # Strings

    parser.add_argument('--activation',
                        type=str,
                        default='relu',
                        help='Activation function to use. '
                             'Options = [relu, leaky_relu, selu]. '
                             'Default: relu')

    parser.add_argument('--arch',
                        type=str,
                        default='resblocks',
                        help='Network architecture to use. '
                             'Options = {resblocks}. Default: resblocks')

    parser.add_argument('--conditioning',
                        type=str,
                        default=None,
                        help='Method of conditioning text. '
                             'Options: {concat, projection}. Default: None')

    parser.add_argument('--condition_encoder_optimizer',
                        type=str,
                        default='adam',
                        help='Image encoder and text projection optimizer. '
                             'Default: adam')

    parser.add_argument('--criterion',
                        type=str,
                        default='hinge',
                        help='Loss function to use. '
                             'Options: {classical, hinge}. Default: hinge')

    parser.add_argument('--dataset',
                        type=str,
                        default='codraw',
                        help='Dataset to use for training. '
                             'Options = {codraw, iclevr}. Default: codraw')

    parser.add_argument('--discriminator_optimizer',
                        type=str,
                        default='adam',
                        help='Optimizer used while training the '
                             'discriminator. Default: adam')

    parser.add_argument('--disc_img_conditioning',
                        type=str,
                        default='subtract',
                        help='Image conditioning for the discriminator, '
                             'either channel subtraction or concatenation. '
                             'Options = {concat, subtract}. Default: subtract')

    parser.add_argument('--exp_name',
                        type=str,
                        default='TellDrawRepeat',
                        help='Experiment name that will be used for '
                             'visualization and logging. '
                             'Default: TellDrawRepeat')

    parser.add_argument('--embedding_type',
                        type=str,
                        default='gru',
                        help="Type of sentence encoding. Train a GRU over "
                             "word embeddings with the 'gru' option. "
                             "Default: gru")

    parser.add_argument('--gan_type',
                        type=str,
                        default='recurrent_gan',
                        help='GAN type. Options: [recurrent_gan]. '
                             'Default: recurrent_gan')

    parser.add_argument('--generator_optimizer',
                        type=str,
                        default='adam',
                        help='Optimizer used while training the generator. '
                             'Default: adam')

    parser.add_argument('--gen_fusion',
                        type=str,
                        default='concat',
                        help='Method to use when fusing the image features in '
                             'the generator. Options = [concat, gate]. '
                             'Default: concat')

    parser.add_argument('--gru_optimizer',
                        type=str,
                        default='rmsprop',
                        help='Sentence encoder optimizer type. '
                             'Default: rmsprop')

    parser.add_argument('--img_encoder_type',
                        type=str,
                        default='res_blocks',
                        help='Building blocks of the image encoder. '
                             'Default: res_blocks')

    parser.add_argument('--load_snapshot',
                        type=str,
                        default=None,
                        help='Snapshot file to load model and optimizer '
                             'state from.')

    parser.add_argument('--log_path',
                        type=str,
                        default='logs',
                        help='Path where to save logs and image generations. '
                             'Default: logs')

    parser.add_argument('--results_path',
                        type=str,
                        default='results/',
                        help='Path where to save the generated samples. '
                             'Default: results/')

    parser.add_argument('--rnn_optimizer',
                        type=str,
                        default='rmsprop',
                        help='Optimizer to use for the RNN. Default: rmsprop')

    parser.add_argument('--test_dataset',
                        type=str,
                        help='Test dataset path key.')

    parser.add_argument('--val_dataset',
                        type=str,
                        help='Validation dataset path key.')

    parser.add_argument('--vis_server',
                        type=str,
                        default='http://localhost',
                        help='Visdom server address. '
                             'Default: http://localhost')

    # Boolean switches (single dash, store_true)

    parser.add_argument('-debug',
                        action='store_true',
                        help='Debugging flag (e.g. do not save weights).')

    parser.add_argument('-disc_sn',
                        action='store_true',
                        help='A flag that decides whether to use spectral '
                             'norm in the discriminator.')

    parser.add_argument('-generator_sn',
                        action='store_true',
                        help='A flag that decides whether to use spectral '
                             'norm in the generator.')

    parser.add_argument('-generator_optim_image',
                        action='store_true',
                        help='A flag that decides whether to optimize the '
                             'image encoder w.r.t. the generator.')

    parser.add_argument('-inference_save_last_only',
                        action='store_true',
                        help='A flag that decides whether to only save the '
                             'last image for each dialogue.')

    parser.add_argument('-metric_inception_objects',
                        action='store_true',
                        help='A flag that decides whether to evaluate and '
                             'report object detection accuracy.')

    parser.add_argument('-teacher_forcing',
                        action='store_true',
                        help='A flag that decides whether to train using '
                             'teacher forcing. NOTE: With '
                             'teacher_forcing=False more GPU memory will be '
                             'used.')

    parser.add_argument('-self_attention',
                        action='store_true',
                        help='A flag that decides whether to use '
                             'self-attention layers.')

    parser.add_argument('-skip_connect',
                        action='store_true',
                        help='A flag that decides whether to have a skip '
                             'connection between the GRU output and the '
                             'LSTM input.')

    parser.add_argument('-use_fd',
                        action='store_true',
                        help='A flag that decides whether to use image '
                             'features conditioning in the discriminator.')

    parser.add_argument('-use_fg',
                        action='store_true',
                        help='A flag that decides whether to use image '
                             'features conditioning in the generator.')

    return parser


def parse_config(argv=None):
    """Parse the experiment options and write them to the experiment log.

    Args:
        argv: argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with one attribute per option of ``build_parser``.
    """
    args = build_parser().parse_args(argv)

    # Persist the full configuration next to the experiment's other logs.
    logger = Logger(args.log_path, args.exp_name)
    logger.write_config(str(args))

    return args
예제 #3
0
class Trainer():
    """Drive training, evaluation and qualitative inference of a GAN model.

    Everything is configured through ``cfg``: ``cfg.gan_type`` selects the
    model from ``MODELS`` and ``cfg.dataset`` selects the dataset class from
    ``DATASETS``. The constructor builds the model, the training data loader,
    a Visdom plotter and a logger.
    """

    def __init__(self, cfg):
        """Build model, training loader, visualizer and logger.

        Args:
            cfg: configuration namespace (e.g. from the project's parser).

        Raises:
            RuntimeError: if ``train_images_*`` directories already exist
                under ``cfg.log_path/cfg.exp_name``. RuntimeError subclasses
                Exception, so existing ``except Exception`` handlers still work.
        """
        img_path = os.path.join(cfg.log_path, cfg.exp_name, 'train_images_*')
        if glob.glob(img_path):
            raise RuntimeError('all directories with name train_images_* under '
                               'the experiment directory need to be removed')
        path = os.path.join(cfg.log_path, cfg.exp_name)

        self.model = MODELS[cfg.gan_type](cfg)

        if cfg.load_snapshot is not None:
            self.model.load_model(cfg.load_snapshot)
            print('Load model:', cfg.load_snapshot)
        self.model.save_model(path, 0, 0)

        # The recurrent GAN keeps the loader unshuffled; for codraw the
        # training loops call ``dataset.shuffle()`` once per epoch.
        shuffle = cfg.gan_type != 'recurrent_gan'
        self.dataset, self.dataloader = self._build_loader(cfg, cfg.dataset, shuffle)

        self.visualizer = VisdomPlotter(env_name=cfg.exp_name, server=cfg.vis_server)
        self.logger = Logger(cfg.log_path, cfg.exp_name)
        self.cfg = cfg

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _build_loader(cfg, dataset_key, shuffle):
        """Return ``(dataset, dataloader)`` for the data at ``keys[dataset_key]``.

        The dataset class is always ``DATASETS[cfg.dataset]``. Only
        ``dataset_key`` picks which split/path it is loaded from.
        """
        dataset = DATASETS[cfg.dataset](path=keys[dataset_key], cfg=cfg,
                                        img_size=cfg.img_size)
        dataloader = DataLoader(dataset,
                                batch_size=cfg.batch_size,
                                shuffle=shuffle,
                                num_workers=cfg.num_workers,
                                pin_memory=True,
                                drop_last=True)
        # Each supported dataset module provides its own collate function.
        if cfg.dataset == 'codraw':
            dataloader.collate_fn = codraw_dataset.collate_data
        elif cfg.dataset == 'iclevr':
            dataloader.collate_fn = clevr_dataset.collate_data
        return dataset, dataloader

    @staticmethod
    def _ctr_encoder_path(dataset_name, suffix=''):
        """Return the ctr encoder checkpoint path, or None for other datasets."""
        if dataset_name not in ('codraw', 'iclevr'):
            return None
        return 'models/%s_1.0%s.pt' % (dataset_name, suffix)

    def _load_ctr_encoder(self, suffix=''):
        """Load ``self.model.ctr.E`` weights for the configured dataset, if any."""
        path = self._ctr_encoder_path(self.cfg.dataset, suffix)
        if path is not None:
            self.model.ctr.E.load_state_dict(torch.load(path))

    @staticmethod
    def _tokens(vocab, ids):
        """Map token ids (ints, numpy/tensor scalars or floats) to vocab words."""
        return [vocab[int(t)] for t in ids]

    def _evaluate(self, iteration_counter):
        """Run the configured Evaluator, then print and log its result."""
        torch.cuda.empty_cache()
        evaluator = Evaluator.factory(self.cfg, self.visualizer, self.logger)
        res = evaluator.evaluate(iteration_counter)

        print('\nIter %d:' % (iteration_counter))
        print(res)
        self.logger.write_res(iteration_counter, res)
        # ``evaluator`` goes out of scope on return, releasing it as the
        # original explicit ``del`` did.

    def _run_training(self, train_step):
        """Shared epoch/batch loop used by ``train`` and ``train_with_ctr``.

        Evaluation runs every ``cfg.save_rate`` iterations, including
        iteration 0, before the batch is trained. ``train_step`` receives the
        already-incremented counter.
        """
        iteration_counter = 0
        for epoch in tqdm(range(self.cfg.epochs), ascii=True):
            # Uses self.cfg: the original read an undefined global ``cfg`` here.
            if self.cfg.dataset == 'codraw':
                self.dataset.shuffle()

            for batch in self.dataloader:
                if iteration_counter % self.cfg.save_rate == 0:
                    self._evaluate(iteration_counter)

                iteration_counter += 1
                train_step(batch,
                           epoch,
                           iteration_counter,
                           self.visualizer,
                           self.logger)

    # ------------------------------------------------------------- entry points

    def train(self):
        """Train the model with ``model.train_batch`` for ``cfg.epochs`` epochs."""
        self._run_training(self.model.train_batch)

    def train_with_ctr(self):
        """Load the pretrained ctr encoder, then train with ``train_batch_with_ctr``."""
        self._load_ctr_encoder('_e')
        self._run_training(self.model.train_batch_with_ctr)

    def train_ctr(self):
        """Train the ctr component, evaluating on the validation split periodically."""
        iteration_counter = 0

        with tqdm(range(self.cfg.epochs), ascii=True) as progress:
            for epoch in progress:
                # Uses self.cfg: the original read an undefined global ``cfg`` here.
                if self.cfg.dataset == 'codraw':
                    self.dataset.shuffle()

                for batch in self.dataloader:
                    loss = self.model.train_ctr(batch, epoch, iteration_counter,
                                                self.visualizer, self.logger)
                    progress.set_postfix(ls_bh=loss)

                    if iteration_counter > 0 and iteration_counter % self.cfg.save_rate == 0:
                        torch.cuda.empty_cache()

                        print('Iter %d: %f' % (iteration_counter, loss))
                        eval_loss = self.eval_ctr(epoch, iteration_counter)
                        print('Eval: %f' % (eval_loss))
                        print('')
                        self.logger.write_res(iteration_counter, eval_loss)

                    iteration_counter += 1

    def eval_ctr(self, epoch, iteration_counter):
        """Return the mean ctr loss over the validation split (``cfg.val_dataset``)."""
        cfg = self.cfg
        dataset, dataloader = self._build_loader(cfg, cfg.val_dataset, shuffle=False)

        if cfg.dataset == 'codraw':
            dataset.shuffle()
        rec_loss = [self.model.train_ctr(batch, epoch, iteration_counter,
                                         self.visualizer, self.logger, is_eval=True)
                    for batch in dataloader]

        return np.average(rec_loss)

    def infer_ctr(self, num_samples=30):
        """Dump ctr reconstructions for one validation batch into ``ins_result/``.

        For each of the first ``num_samples`` samples (capped at the batch
        size) it writes ``ins_result/<i>/ins.txt`` with predicted and
        reference token lines and one image per turn. It then archives the
        folder as ``ins_result.tar.gz``.
        """
        cfg = self.cfg
        self._load_ctr_encoder('_e')
        dataset, dataloader = self._build_loader(cfg, cfg.val_dataset, shuffle=True)
        glove_key = list(dataset.glove_key.keys())

        # Only the first batch is used.
        batch = next(iter(dataloader), None)
        if batch is not None:
            rec_out, _ = self.model.train_ctr(batch, -1, -1, self.visualizer, self.logger,
                                              is_eval=True, is_infer=True)
            rec_out = np.argmax(rec_out, axis=3)
            turn_words = batch['turn_word'].detach().cpu().numpy()
            seq_len = rec_out.shape[2]

            os.makedirs('ins_result', exist_ok=True)
            # Guard against batches smaller than num_samples (was IndexError).
            for i in range(min(num_samples, rec_out.shape[0])):
                os.makedirs('ins_result/%d' % (i), exist_ok=True)

                with open('ins_result/%d/ins.txt' % (i), 'w') as out:
                    for j in range(rec_out.shape[1]):
                        predicted = self._tokens(glove_key, rec_out[i, j])
                        reference = self._tokens(glove_key, turn_words[i, j, :seq_len])
                        print(predicted)
                        print(reference)
                        print()

                        out.write(' '.join(predicted))
                        out.write('\n')
                        out.write(' '.join(reference))
                        out.write('\n')
                        out.write('\n')

                        # NOTE(review): torchvision >= 0.10 renames ``range`` to ``value_range``.
                        TV.utils.save_image(batch['image'][i, j].data,
                                            'ins_result/%d/%d.png' % (i, j),
                                            normalize=True, range=(-1, 1))

                print('\n----------------\n')

        os.system('tar zcvf ins_result.tar.gz ins_result')

    def infer_gen(self, num_samples=30):
        """Dump generated images for one validation batch into ``gen_result/``.

        For each of the first ``num_samples`` samples (capped at the batch
        size) it writes the instruction tokens to ``gen_result/<i>/ins.txt``,
        the ground-truth image as ``_<j>.png`` and the generated image as
        ``<j>.png``. It then archives the folder as ``gen_result.tar.gz``.
        """
        cfg = self.cfg
        self._load_ctr_encoder('')
        dataset, dataloader = self._build_loader(cfg, cfg.val_dataset, shuffle=True)
        glove_key = list(dataset.glove_key.keys())

        # Only the first batch is used.
        batch = next(iter(dataloader), None)
        if batch is not None:
            rec_out = self.model.infer_gen(batch)
            turn_words = batch['turn_word'].detach().cpu().numpy()

            os.makedirs('gen_result', exist_ok=True)
            # Guard against batches smaller than num_samples (was IndexError).
            for i in range(min(num_samples, rec_out.shape[0])):
                os.makedirs('gen_result/%d' % (i), exist_ok=True)

                with open('gen_result/%d/ins.txt' % (i), 'w') as out:
                    for j in range(rec_out.shape[1]):
                        out.write(' '.join(self._tokens(glove_key, turn_words[i, j])))
                        out.write('\n')

                        # NOTE(review): torchvision >= 0.10 renames ``range`` to ``value_range``.
                        TV.utils.save_image(batch['image'][i, j].data,
                                            'gen_result/%d/_%d.png' % (i, j),
                                            normalize=True, range=(-1, 1))
                        TV.utils.save_image(torch.from_numpy(rec_out[i, j]).data,
                                            'gen_result/%d/%d.png' % (i, j),
                                            normalize=True, range=(-1, 1))

        os.system('tar zcvf gen_result.tar.gz gen_result')
예제 #4
0
    def __init__(self, cfg):
        """Build the networks, optimizers and losses of the recurrent GAN.

        At each time step a generated image (x'_{t-1}) and the current
        question q_{t} are fed to the RNN to produce the conditioning vector
        for the GAN:

            - c_{t} = RNN(h_{t-1}, q_{t}, x^{~}_{t-1})
            - x^{~}_{t} = G(z | c_{t})

        Args:
            cfg: configuration namespace; supplies layer sizes, optimizer
                names/hyper-parameters, ``use_fg``, ``criterion`` and the
                logging location.
        """
        super(RecurrentGAN, self).__init__()

        # ---- Networks (all wrapped for multi-GPU and moved to CUDA) ----
        # NOTE: construction order is kept unchanged; modules that initialise
        # weights randomly presumably draw from the global RNG in this order.
        generator_net = GeneratorFactory.create_instance(cfg)
        self.generator = DataParallel(generator_net).cuda()

        discriminator_net = DiscriminatorFactory.create_instance(cfg)
        self.discriminator = DataParallel(discriminator_net).cuda()

        # batch_first=False means (seq, batch, feature) inputs, hence the
        # DataParallel scatter along dim=1 (the batch axis).
        gru = nn.GRU(cfg.input_dim, cfg.hidden_dim, batch_first=False)
        self.rnn = nn.DataParallel(gru, dim=1).cuda()

        norm = nn.LayerNorm(cfg.hidden_dim)
        self.layer_norm = nn.DataParallel(norm).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()
        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()
        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()

        # ---- Optimizers ----
        gen_hparams = (cfg.generator_lr,
                       cfg.generator_beta1,
                       cfg.generator_beta2,
                       cfg.generator_weight_decay)
        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(), *gen_hparams)

        disc_hparams = (cfg.discriminator_lr,
                        cfg.discriminator_beta1,
                        cfg.discriminator_beta2,
                        cfg.discriminator_weight_decay)
        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(), *disc_hparams)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](
            self.rnn.parameters(), cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        # The condition encoder is always optimised here; the image encoder
        # joins it only when image-feature conditioning of G (use_fg) is on.
        self.use_image_encoder = cfg.use_fg
        encoder_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            encoder_params.extend(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            encoder_params,
            cfg.feature_encoder_lr
        )

        # ---- Losses ----
        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)
예제 #5
0
class RecurrentGAN():
    """Recurrent GAN that generates one image per dialogue turn.

    At each turn a sentence encoder embeds the instruction, an image encoder
    embeds the previous canvas, a condition encoder fuses both, and a GRU
    turns the result into the conditioning vector of the generator.
    """

    def __init__(self, cfg):
        """A recurrent GAN model, each time step a generated image
        (x'_{t-1}) and the current question q_{t} are fed to the RNN
        to produce the conditioning vector for the GAN.
        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x^{~}_{t-1})
            - x^{~}_{t} = G(z | c_{t})

        Args:
            cfg: experiment configuration. Besides model hyper-parameters it
                selects the optimizers (``OPTIM[...]``), the adversarial
                criterion (``LOSSES[cfg.criterion]``) and whether the image
                encoder is trained (``cfg.use_fg``).
        """
        super(RecurrentGAN, self).__init__()

        # region Models-Instantiation
        # Every sub-network is wrapped in DataParallel and moved to the GPU.

        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        # The GRU is time-major (batch_first=False), so DataParallel must
        # scatter along dim=1, the batch dimension.
        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                          cfg.hidden_dim,
                                          batch_first=False), dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()

        # endregion

        # region Optimizers

        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(),
            cfg.generator_lr,
            cfg.generator_beta1,
            cfg.generator_beta2,
            cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(),
            cfg.discriminator_lr,
            cfg.discriminator_beta1,
            cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](
            self.rnn.parameters(),
            cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(),
            cfg.gru_lr)

        # The image encoder is only optimized when cfg.use_fg is set; the
        # condition encoder is always optimized.
        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params,
            cfg.feature_encoder_lr
        )

        # endregion

        # region Criterion

        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

    def train_batch(self, batch, epoch, iteration, visualizer, logger):
        """Run one training step over a batch of dialogues.

        The training scheme follows the following:
            - Discriminator and Generator are updated every time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are
              updated once per sequence, after the last time step.

        Args:
            batch: collated batch dict; this method reads 'image',
                'background', 'turn_word_embedding', 'turn_lengths',
                'objects', 'dialog_length', 'entities' and 'turn'. Note that
                a list-valued 'turn' is replaced in place by a transposed
                numpy array on visualization iterations.
            epoch: current epoch, used for logging and snapshots.
            iteration: global iteration counter; visualization happens when
                it is a multiple of ``cfg.vis_rate`` and snapshots when it is
                a multiple of ``cfg.save_rate``.
            visualizer: plotting backend (track_sigma/track/histogram/write).
            logger: unused; ``self.logger`` is used instead. Kept for
                interface compatibility.
        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        # Every dialogue starts from the same background canvas.
        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image.unsqueeze(0) \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        # Gradient norms of the per-sequence optimizers. They are only filled
        # after the last time step, so they are created once, here, instead of
        # being re-created on every iteration of the time loop.
        rnn_grads = []
        gru_grads = []
        condition_encoder_grads = []
        img_encoder_grads = []

        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            # True for dialogues with fewer than t + 1 turns (padding).
            seq_ended = t > (batch['dialog_length'] - 1)

            # Encode the previous canvas (conditioning input) and the
            # ground-truth image of the current turn.
            image_feature_map, image_vec, _ = self.image_encoder(prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            # Single GRU step: add a length-1 time axis, then remove it.
            rnn_condition = rnn_condition.unsqueeze(0)
            output, hidden = self.rnn(rnn_condition,
                                      hidden)

            output = output.squeeze(0)
            output = self.layer_norm(output)

            # The generator gets a detached RNN output, so the generator loss
            # does not reach the RNN through the condition vector.
            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            visualizer.track_sigma(sigma)

            # Objects present now but not at the previous turn (negative
            # differences, i.e. removals, are clamped to 0).
            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            # The discriminator gets the non-detached RNN output, so its loss
            # also accumulates gradients in the RNN and encoders; these are
            # applied once per sequence by _optimize_rnn.
            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar)

            # The next step is conditioned on the ground truth (teacher
            # forcing) or on the model's own generation.
            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            # The discriminator is always conditioned on the real image.
            disc_prev_image = image
            prev_objects = objects

            # Cut the autograd chain through prev_image every two steps.
            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient, \
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

                visualizer.track(d_real, d_fake)

            # Keep the first four samples of each turn, plus the entities newly
            # added in the first dialogue, for visualization.
            hamming = hamming.data.cpu().numpy()[0]
            teller_images.extend(image[:4].data.cpu().numpy())
            drawer_images.extend(fake_image[:4].data.cpu().numpy())
            entities = str.join(',', list(batch['entities'][hamming > 0]))
            added_entities.append(entities)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss, iteration)
            # NOTE: img_encoder_grads stays empty when cfg.use_fg is False, in
            # which case its mean is NaN.
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)
            self._draw_images(visualizer, teller_images, drawer_images, nrow=4)
            self.logger.write(epoch, iteration, d_real, d_fake, d_loss, g_loss)

            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            visualizer.write(added_entities, var_name='entities')

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path,
                                self.cfg.exp_name)

            self._save(fake_image[:4], path, epoch,
                       iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        """Sample standard-normal noise on the GPU and run the generator.

        Returns:
            The generator's ``(fake_images, mu, logvar, sigma)`` tuple.
        """
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()

        fake_images, mu, logvar, sigma = self.generator(noise, condition,
                                                        image_feature_maps)

        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self, real_images, fake_images, prev_image,
                                condition, mask, objects, gp_reg=0, aux_reg=0):
        """Run one discriminator update for the current time step.

        The discriminator scores real, fake and "wrong" pairs under the same
        condition. The wrong pairs are the real images and previous images
        rolled by one position along the batch, so each condition is scored
        against another sample's images.

        Args:
            real_images: ground-truth images for this turn.
            fake_images: generated images (detached by the caller).
            prev_image: previous-turn images the discriminator conditions on.
            condition: conditioning vector (RNN output).
            mask: per-sample flags, True where the dialogue already ended.
            objects: targets for the auxiliary classifier.
            gp_reg: gradient-penalty weight; 0 disables the penalty.
            aux_reg: auxiliary-loss weight; 0 disables the auxiliary loss.

        Returns:
            ``(d_loss, d_real, d_fake, aux_loss, grad_norm)``. ``d_loss`` and
            ``grad_norm`` are Python floats; ``d_real`` and ``d_fake`` are
            numpy arrays; ``aux_loss`` is a float, or 0 if no auxiliary loss
            was computed.
        """
        wrong_images = torch.cat((real_images[1:],
                                  real_images[0:1]), dim=0)
        wrong_prev = torch.cat((prev_image[1:],
                                prev_image[0:1]), dim=0)

        self.discriminator.zero_grad()
        # Needed so the gradient penalty can differentiate d_real w.r.t. the
        # real images.
        real_images.requires_grad_()

        d_real, aux_real, _ = self.discriminator(real_images, condition,
                                                 prev_image)
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        d_wrong, _, _ = self.discriminator(wrong_images, condition,
                                           wrong_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(d_real,
                                                           d_fake,
                                                           d_wrong,
                                                           aux_real,
                                                           aux_fake, objects,
                                                           aux_reg, mask)

        # The graph is retained because later backward calls (gradient
        # penalty, generator step, later time steps) traverse shared parts.
        d_loss.backward(retain_graph=True)
        if gp_reg:
            reg = gp_reg * self._masked_gradient_penalty(d_real, real_images,
                                                         mask)
            reg.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(self.discriminator.parameters())
        self.discriminator_optimizer.step()

        d_loss_scalar = d_loss.item()
        d_real_np = d_real.cpu().data.numpy()
        d_fake_np = d_fake.cpu().data.numpy()
        aux_loss_scalar = aux_loss.item() if isinstance(aux_loss, torch.Tensor) else aux_loss
        grad_norm_scalar = grad_norm.item()
        del d_loss
        del d_real
        del d_fake
        del aux_loss
        del grad_norm
        gc.collect()

        return d_loss_scalar, d_real_np, d_fake_np, aux_loss_scalar, grad_norm_scalar

    def _optimize_generator(self, fake_images, prev_image, condition, objects, aux_reg,
                            mask, mu, logvar):
        """Run one generator update for the current time step.

        Returns:
            ``(g_loss, grad_norm)`` as Python floats.
        """
        self.generator.zero_grad()
        d_fake, aux_fake, _ = self.discriminator(fake_images, condition,
                                                 prev_image)
        g_loss = self._generator_masked_loss(d_fake, aux_fake, objects,
                                             aux_reg, mu, logvar, mask)

        g_loss.backward(retain_graph=True)
        gen_grad_norm = _recurrent_gan.get_grad_norm(self.generator.parameters())

        self.generator_optimizer.step()

        g_loss_scalar = g_loss.item()
        gen_grad_norm_scalar = gen_grad_norm.item()

        del g_loss
        del gen_grad_norm
        gc.collect()

        return g_loss_scalar, gen_grad_norm_scalar

    def _optimize_rnn(self):
        """Apply the gradients accumulated over a whole sequence.

        The RNN and sentence-encoder gradients are clipped to
        ``cfg.grad_clip`` before their optimizers step. The feature-encoder
        optimizer steps without clipping. All three groups are zeroed afterwards.

        Returns:
            Gradient norms ``(rnn, sentence_encoder, condition_encoder,
            image_encoder)``. The first two are measured after clipping.
        """
        torch.nn.utils.clip_grad_norm_(self.rnn.parameters(), self.cfg.grad_clip)
        rnn_grad_norm = _recurrent_gan.get_grad_norm(self.rnn.parameters())
        self.rnn_optimizer.step()
        self.rnn.zero_grad()

        torch.nn.utils.clip_grad_norm_(self.sentence_encoder.parameters(), self.cfg.grad_clip)
        gru_grad_norm = _recurrent_gan.get_grad_norm(self.sentence_encoder.parameters())
        self.sentence_encoder_optimizer.step()
        self.sentence_encoder.zero_grad()

        ce_grad_norm = _recurrent_gan.get_grad_norm(self.condition_encoder.parameters())
        ie_grad_norm = _recurrent_gan.get_grad_norm(self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()
        return rnn_grad_norm, gru_grad_norm, ce_grad_norm, ie_grad_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real, aux_fake,
                                   objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding.

        Returns:
            ``(d_loss, aux_loss)``: the mean discriminator loss over active
            samples (including the weighted auxiliary term), and the mean
            auxiliary loss, or 0 when no auxiliary loss was computed.
        """
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(d_real[b], d_fake[b], d_wrong[b],
                                                           self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * (self.aux_criterion(aux_real[b], objects[b]).mean() +
                                          self.aux_criterion(aux_fake[b], objects[b]).mean())
                    # Out-of-place add: an in-place update of a tensor that
                    # autograd may have saved could break backward().
                    sample_loss = sample_loss + aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()

        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self, d_fake, aux_fake, objects, aux_reg,
                               mu, logvar, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding.

        Each active sample contributes its adversarial loss, plus the weighted
        auxiliary loss when ``aux_reg > 0``, plus the weighted KL term when
        ``mu`` is given. Returns the mean over active samples.
        """
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * self.aux_criterion(aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * kl_penalty(mu[b], logvar[b])
                else:
                    kl_loss = 0

                g_loss.append(sample_loss + aux_loss + kl_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        """Mean gradient penalty of d_real w.r.t. the real images.

        NOTE(review): ``mask`` is currently not applied, so ended (padded)
        sequences also contribute to the penalty.
        """
        gp_reg = gradient_penalty(d_real, real_images).mean()
        return gp_reg

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss,
                     iteration):
        _recurrent_gan._plot_losses(self, visualizer, g_loss, d_loss,
                                    aux_loss, iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce,
                        ie, iteration):
        _recurrent_gan._plot_gradients(self, visualizer, rnn, gen, disc,
                                       gru, ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        _recurrent_gan.draw_images(self, visualizer, real, fake, nrow)

    def _save(self, fake, path, epoch, iteration):
        _recurrent_gan._save(self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        _recurrent_gan.save_model(self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        _recurrent_gan.load_model(self, snapshot_path)
예제 #6
0
    def __init__(self, cfg):
        """A recurrent GAN model, each time step a generated image
        (x'_{t-1}) and the current question q_{t} are fed to the RNN
        to produce the conditioning vector for the GAN.
        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x^{~}_{t-1})
            - x^{~}_{t} = G(z | c_{t})

        In addition to the models, optimizers and adversarial/auxiliary
        criteria, this variant builds a cross-entropy segmentation criterion
        (class-balanced when ``cfg.balanced_seg`` is set) and an
        ``UnNormalize`` helper with mean = std = 0.5 per channel.

        Args:
            cfg: experiment configuration (model sizes, optimizer choices and
                hyper-parameters, ``use_fg``, ``balanced_seg``, logging paths).
        """
        super(RecurrentGAN_Mingyang, self).__init__()

        # region Models-Instantiation
        # Every sub-network is wrapped in DataParallel and moved to the GPU.

        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        # The GRU is time-major (batch_first=False), so DataParallel must
        # scatter along dim=1, the batch dimension.
        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                          cfg.hidden_dim,
                                          batch_first=False),
                                   dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()

        # endregion

        # region Optimizers

        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(), cfg.generator_lr, cfg.generator_beta1,
            cfg.generator_beta2, cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(), cfg.discriminator_lr,
            cfg.discriminator_beta1, cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](self.rnn.parameters(),
                                                      cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        # The image encoder is only optimized when cfg.use_fg is set; the
        # condition encoder is always optimized.
        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params, cfg.feature_encoder_lr)

        # endregion

        # region Criterion

        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        # Added by Mingyang for segmentation loss
        if cfg.balanced_seg:
            # Per-class distribution over the 22 segmentation labels (entries
            # sum to ~1); label 20 has probability 0.
            label_freq = np.array([
                3.02674201e-01, 1.91545454e-03, 2.90009221e-04, 7.50949673e-04,
                1.08670452e-03, 1.11353785e-01, 4.00971053e-04, 1.06240113e-02,
                1.59590824e-01, 5.38960105e-02, 3.36431602e-02, 3.99029734e-02,
                1.88888847e-02, 2.06441476e-03, 6.33775290e-02, 5.81920411e-03,
                3.79528817e-03, 7.87975754e-02, 2.73547355e-03, 1.08308135e-01,
                0.00000000e+00, 8.44408475e-05
            ])
            # Inverse-frequency weights. Zero-frequency labels get weight 0
            # directly, instead of going through 1/0 (a divide-by-zero
            # RuntimeWarning and an intermediate inf).
            label_weights = np.divide(1.0, label_freq,
                                      out=np.zeros_like(label_freq),
                                      where=label_freq > 0)
            label_weights[20] = 0
            # Rescale so the smallest weight among labels 0-19 equals 1.
            label_weights = label_weights / np.min(label_weights[:20])
            # CrossEntropyLoss expects a float32 weight tensor.
            label_weights = torch.from_numpy(label_weights).float()
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss(weight=label_weights)).cuda()
        else:
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss()).cuda()

        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

        # define unorm
        self.unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
예제 #7
0
class RecurrentGAN_Mingyang():
    def __init__(self, cfg):
        """A recurrent GAN model, each time step a generated image
        (x'_{t-1}) and the current question q_{t} are fed to the RNN
        to produce the conditioning vector for the GAN.
        The following equations describe this model:

            - c_{t} = RNN(h_{t-1}, q_{t}, x^{~}_{t-1})
            - x^{~}_{t} = G(z | c_{t})

        In addition to the models, optimizers and adversarial/auxiliary
        criteria, this variant builds a cross-entropy segmentation criterion
        (class-balanced when ``cfg.balanced_seg`` is set) and an
        ``UnNormalize`` helper with mean = std = 0.5 per channel.

        Args:
            cfg: experiment configuration (model sizes, optimizer choices and
                hyper-parameters, ``use_fg``, ``balanced_seg``, logging paths).
        """
        super(RecurrentGAN_Mingyang, self).__init__()

        # region Models-Instantiation
        # Every sub-network is wrapped in DataParallel and moved to the GPU.

        self.generator = DataParallel(
            GeneratorFactory.create_instance(cfg)).cuda()

        self.discriminator = DataParallel(
            DiscriminatorFactory.create_instance(cfg)).cuda()

        # The GRU is time-major (batch_first=False), so DataParallel must
        # scatter along dim=1, the batch dimension.
        self.rnn = nn.DataParallel(nn.GRU(cfg.input_dim,
                                          cfg.hidden_dim,
                                          batch_first=False),
                                   dim=1).cuda()

        self.layer_norm = nn.DataParallel(nn.LayerNorm(cfg.hidden_dim)).cuda()

        self.image_encoder = DataParallel(ImageEncoder(cfg)).cuda()

        self.condition_encoder = DataParallel(ConditionEncoder(cfg)).cuda()

        self.sentence_encoder = nn.DataParallel(SentenceEncoder(cfg)).cuda()

        # endregion

        # region Optimizers

        self.generator_optimizer = OPTIM[cfg.generator_optimizer](
            self.generator.parameters(), cfg.generator_lr, cfg.generator_beta1,
            cfg.generator_beta2, cfg.generator_weight_decay)

        self.discriminator_optimizer = OPTIM[cfg.discriminator_optimizer](
            self.discriminator.parameters(), cfg.discriminator_lr,
            cfg.discriminator_beta1, cfg.discriminator_beta2,
            cfg.discriminator_weight_decay)

        self.rnn_optimizer = OPTIM[cfg.rnn_optimizer](self.rnn.parameters(),
                                                      cfg.rnn_lr)

        self.sentence_encoder_optimizer = OPTIM[cfg.gru_optimizer](
            self.sentence_encoder.parameters(), cfg.gru_lr)

        # The image encoder is only optimized when cfg.use_fg is set; the
        # condition encoder is always optimized.
        self.use_image_encoder = cfg.use_fg
        feature_encoding_params = list(self.condition_encoder.parameters())
        if self.use_image_encoder:
            feature_encoding_params += list(self.image_encoder.parameters())

        self.feature_encoders_optimizer = OPTIM['adam'](
            feature_encoding_params, cfg.feature_encoder_lr)

        # endregion

        # region Criterion

        self.criterion = LOSSES[cfg.criterion]()
        self.aux_criterion = DataParallel(torch.nn.BCELoss()).cuda()

        # Added by Mingyang for segmentation loss
        if cfg.balanced_seg:
            # Per-class distribution over the 22 segmentation labels (entries
            # sum to ~1); label 20 has probability 0.
            label_freq = np.array([
                3.02674201e-01, 1.91545454e-03, 2.90009221e-04, 7.50949673e-04,
                1.08670452e-03, 1.11353785e-01, 4.00971053e-04, 1.06240113e-02,
                1.59590824e-01, 5.38960105e-02, 3.36431602e-02, 3.99029734e-02,
                1.88888847e-02, 2.06441476e-03, 6.33775290e-02, 5.81920411e-03,
                3.79528817e-03, 7.87975754e-02, 2.73547355e-03, 1.08308135e-01,
                0.00000000e+00, 8.44408475e-05
            ])
            # Inverse-frequency weights. Zero-frequency labels get weight 0
            # directly, instead of going through 1/0 (a divide-by-zero
            # RuntimeWarning and an intermediate inf).
            label_weights = np.divide(1.0, label_freq,
                                      out=np.zeros_like(label_freq),
                                      where=label_freq > 0)
            label_weights[20] = 0
            # Rescale so the smallest weight among labels 0-19 equals 1.
            label_weights = label_weights / np.min(label_weights[:20])
            # CrossEntropyLoss expects a float32 weight tensor.
            label_weights = torch.from_numpy(label_weights).float()
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss(weight=label_weights)).cuda()
        else:
            self.seg_criterion = DataParallel(
                torch.nn.CrossEntropyLoss()).cuda()

        # endregion

        self.cfg = cfg
        self.logger = Logger(cfg.log_path, cfg.exp_name)

        # define unorm
        self.unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

    def train_batch(self,
                    batch,
                    epoch,
                    iteration,
                    visualizer,
                    logger,
                    total_iters=0,
                    current_batch_t=0):
        """
        The training scheme follows the following:
            - Discriminator and Generator is updated every time step.
            - RNN, SentenceEncoder and ImageEncoder parameters are
            updated every sequence
        """
        batch_size = len(batch['image'])
        max_seq_len = batch['image'].size(1)

        prev_image = torch.FloatTensor(batch['background'])
        prev_image = prev_image \
            .repeat(batch_size, 1, 1, 1)
        disc_prev_image = prev_image
        # print("disc_prev_image size is: {}".format(disc_prev_image.shape))

        # Initial inputs for the RNN set to zeros
        hidden = torch.zeros(1, batch_size, self.cfg.hidden_dim)
        prev_objects = torch.zeros(batch_size, self.cfg.num_objects)

        teller_images = []
        drawer_images = []
        added_entities = []

        #print("max sequence length of current batch: {}".format(max_seq_len))
        for t in range(max_seq_len):
            image = batch['image'][:, t]
            turns_word_embedding = batch['turn_word_embedding'][:, t]
            turns_lengths = batch['turn_lengths'][:, t]
            objects = batch['objects'][:, t]
            seq_ended = t > (batch['dialog_length'] - 1)

            image_feature_map, image_vec, object_detections = \
                self.image_encoder(prev_image)
            _, current_image_feat, _ = self.image_encoder(image)

            # print("[image_encoder] image_feature_map shape is: {}".format(image_feature_map.shape))
            # print("[image_encoder] image_vec shape is: {}".format(image_vec.shape))

            turn_embedding = self.sentence_encoder(turns_word_embedding,
                                                   turns_lengths)
            rnn_condition, current_image_feat = \
                self.condition_encoder(turn_embedding,
                                       image_vec,
                                       current_image_feat)

            rnn_condition = rnn_condition.unsqueeze(0)
            # self.rnn.flatten_parameters()  # Added by Mingyang to Resolve the
            # Warning
            self.rnn.module.flatten_parameters()
            output, hidden = self.rnn(rnn_condition, hidden)

            output = output.squeeze(0)
            output = self.layer_norm(output)

            fake_image, mu, logvar, sigma = self._forward_generator(
                batch_size, output.detach(), image_feature_map)

            #print("[image_generator] fake_image size is: {}".format(fake_image.shape))
            #print("[image_generator] fake_image_one_pixel is: {}".format(fake_image[0,:,0,0]))

            visualizer.track_sigma(sigma)

            hamming = objects - prev_objects
            hamming = torch.clamp(hamming, min=0)

            # print(image.shape)
            # print(disc_prev_image.shape)
            d_loss, d_real, d_fake, aux_loss, discriminator_gradient = \
                self._optimize_discriminator(image,
                                             fake_image.detach(),
                                             disc_prev_image,
                                             output,
                                             seq_ended,
                                             hamming,
                                             self.cfg.gp_reg,
                                             self.cfg.aux_reg)

            # append the segmentation loss accordingly
            if re.search(r"seg", self.cfg.gan_type):
                assert self.cfg.seg_reg > 0, "the sge_reg must be larger than 0"
                if self.cfg.gan_type == "recurrent_gan_mingyang_img64_seg":
                    #The size of seg_fake is adjusted to (Batch, N, C)
                    seg_fake = fake_image.view(fake_image.size(0),
                                               fake_image.size(1),
                                               -1).permute(0, 2, 1)
                    #The size of the seg_gt is obtained from image
                    seg_gt = torch.argmax(image, dim=1).view(image.size(0), -1)

            else:
                assert self.cfg.seg_reg == 0, "the sge_reg must be equal to 0"
                seg_fake = None
                seg_gt = None


            g_loss, generator_gradient = \
                self._optimize_generator(fake_image,
                                         disc_prev_image.detach(),
                                         output.detach(),
                                         objects,
                                         self.cfg.aux_reg,
                                         seq_ended,
                                         mu,
                                         logvar,
                                         self.cfg.seg_reg,
                                         seg_fake,
                                         seg_gt)
            #return

            if self.cfg.teacher_forcing:
                prev_image = image
            else:
                prev_image = fake_image

            disc_prev_image = image
            prev_objects = objects

            if (t + 1) % 2 == 0:
                prev_image = prev_image.detach()

            rnn_grads = []
            gru_grads = []
            condition_encoder_grads = []
            img_encoder_grads = []

            if t == max_seq_len - 1:
                rnn_gradient, gru_gradient, condition_gradient,\
                    img_encoder_gradient = self._optimize_rnn()

                rnn_grads.append(rnn_gradient.data.cpu().numpy())
                gru_grads.append(gru_gradient.data.cpu().numpy())
                condition_encoder_grads.append(
                    condition_gradient.data.cpu().numpy())

                if self.use_image_encoder:
                    img_encoder_grads.append(
                        img_encoder_gradient.data.cpu().numpy())

                visualizer.track(d_real, d_fake)

            hamming = hamming.data.cpu().numpy()[0]
            # teller_images.extend(image[:4].data.cpu().numpy())
            # drawer_images.extend(fake_image[:4].data.cpu().numpy())
            new_teller_images = []
            for x in image[:4].data.cpu():
                # print(x.shape)
                # new_x = self.unorm(x)
                # new_x = transforms.ToPILImage()(new_x).convert('RGB')
                # # new_x = np.array(new_x)[..., ::-1]
                # new_x = np.moveaxis(np.array(new_x), -1, 0)

                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    #TODO: Implement the functino to convert new_x to colored_image
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())
                    #print(new_x.shape)
                    #return

                # print(new_x.shape)
                new_teller_images.append(new_x)
            teller_images.extend(new_teller_images)

            new_drawer_images = []
            for x in fake_image[:4].data.cpu():
                # print(x.shape)
                # new_x = self.unorm(x)
                # new_x = transforms.ToPILImage()(new_x).convert('RGB')
                # # new_x = np.array(new_x)[..., ::-1]
                # new_x = np.moveaxis(np.array(new_x), -1, 0)

                if self.cfg.image_gen_mode == "real":
                    new_x = self.unormalize(x)
                elif self.cfg.image_gen_mode == "segmentation":
                    new_x = self.unormalize_segmentation(x.data.cpu().numpy())
                elif self.cfg.image_gen_mode == "segmentation_onehot":
                    #TODO: Implement the functino to convert new_x to colored_image
                    new_x = self.unormalize_segmentation_onehot(
                        x.data.cpu().numpy())

                # print(new_x.shape)
                new_drawer_images.append(new_x)
            drawer_images.extend(new_drawer_images)
            # drawer_images.extend(fake_image[:4].data.cpu().numpy())
            # print(drawer_images[0].shape)

            # entities = str.join(',', list(batch['entities'][hamming > 0]))
            # added_entities.append(entities)
        # print(iteration)

        if iteration % self.cfg.vis_rate == 0:
            visualizer.histogram()
            self._plot_losses(visualizer, g_loss, d_loss, aux_loss, iteration)
            rnn_gradient = np.array(rnn_grads).mean()
            gru_gradient = np.array(gru_grads).mean()
            condition_gradient = np.array(condition_encoder_grads).mean()
            img_encoder_gradient = np.array(img_encoder_grads).mean()
            rnn_grads, gru_grads = [], []
            condition_encoder_grads, img_encoder_grads = [], []
            self._plot_gradients(visualizer, rnn_gradient, generator_gradient,
                                 discriminator_gradient, gru_gradient,
                                 condition_gradient, img_encoder_gradient,
                                 iteration)

            self._draw_images(visualizer, teller_images, drawer_images, nrow=4)
            # self.logger.write(epoch, "{}/{}".format(iteration,total_iters),
            # d_real, d_fake, d_loss, g_loss)
            remaining_time = str(
                datetime.timedelta(seconds=current_batch_t *
                                   (total_iters - iteration)))
            self.logger.write(epoch,
                              "{}/{}".format(iteration, total_iters),
                              d_real,
                              d_fake,
                              d_loss,
                              g_loss,
                              expected_finish_time=remaining_time)
            if isinstance(batch['turn'], list):
                batch['turn'] = np.array(batch['turn']).transpose()

            visualizer.write(batch['turn'][0])
            # visualizer.write(added_entities, var_name='entities')
            teller_images = []
            drawer_images = []

        if iteration % self.cfg.save_rate == 0:
            path = os.path.join(self.cfg.log_path, self.cfg.exp_name)

            # self._save(fake_image[:4], path, epoch,
            #            iteration)
            if not self.cfg.debug:
                self.save_model(path, epoch, iteration)

    def _forward_generator(self, batch_size, condition, image_feature_maps):
        noise = torch.FloatTensor(batch_size,
                                  self.cfg.noise_dim).normal_(0, 1).cuda()

        fake_images, mu, logvar, sigma = self.generator(
            noise, condition, image_feature_maps)

        return fake_images, mu, logvar, sigma

    def _optimize_discriminator(self,
                                real_images,
                                fake_images,
                                prev_image,
                                condition,
                                mask,
                                objects,
                                gp_reg=0,
                                aux_reg=0):
        """Perform one discriminator optimizer step for the current time step.

        The discriminator is stepped on every call, separately from the RNN
        and generator updates. Three discriminator passes are made: real
        images and fake images (both paired with ``prev_image``), plus
        "wrong" pairs in which each sample's ``condition`` is matched with
        the *next* sample's real image and previous image (cyclic shift by
        one along the batch dimension).

        Args:
            real_images: ground-truth images; ``requires_grad_()`` is called
                on this tensor in place.
            fake_images: generated images.
            prev_image: previous-step images paired with each sample.
            condition: conditioning input for each sample.
            mask: per-sample flags; truthy entries (ended sequences) are
                excluded from the loss.
            objects: auxiliary targets for the auxiliary loss.
            gp_reg: gradient-penalty weight; falsy disables the penalty.
            aux_reg: auxiliary-loss weight.

        Returns:
            ``(d_loss, d_real, d_fake, aux_loss, grad_norm)`` where ``d_loss``
            and ``grad_norm`` are Python numbers, ``d_real``/``d_fake`` are
            NumPy arrays of the discriminator scores, and ``aux_loss`` is a
            Python number (``0`` when no auxiliary loss was computed).
        """
        # Built before requires_grad_() below, so these shifted copies do
        # not become part of real_images' autograd graph.
        mismatched_images = torch.roll(real_images, shifts=-1, dims=0)
        mismatched_prev = torch.roll(prev_image, shifts=-1, dims=0)

        disc = self.discriminator
        disc.zero_grad()
        # Gradients w.r.t. the real images are needed by the gradient penalty.
        real_images.requires_grad_()

        d_real, aux_real, _ = disc(real_images, condition, prev_image)
        d_fake, aux_fake, _ = disc(fake_images, condition, prev_image)
        d_wrong, _, _ = disc(mismatched_images, condition, mismatched_prev)

        d_loss, aux_loss = self._discriminator_masked_loss(
            d_real, d_fake, d_wrong, aux_real, aux_fake, objects, aux_reg,
            mask)

        # The graph is retained so the penalty below can back-propagate
        # through d_real again.
        d_loss.backward(retain_graph=True)
        if gp_reg:
            penalty = gp_reg * self._masked_gradient_penalty(
                d_real, real_images, mask)
            penalty.backward(retain_graph=True)

        grad_norm = _recurrent_gan.get_grad_norm(disc.parameters())
        self.discriminator_optimizer.step()

        # Extract plain values before dropping the tensors.
        loss_value = d_loss.item()
        real_scores = d_real.cpu().data.numpy()
        fake_scores = d_fake.cpu().data.numpy()
        if isinstance(aux_loss, torch.Tensor):
            aux_value = aux_loss.item()
        else:
            aux_value = aux_loss
        norm_value = grad_norm.item()

        # Release graph references explicitly to keep peak memory down.
        del d_loss, d_real, d_fake, aux_loss, grad_norm
        gc.collect()

        return loss_value, real_scores, fake_scores, aux_value, norm_value

    def _optimize_generator(self,
                            fake_images,
                            prev_image,
                            condition,
                            objects,
                            aux_reg,
                            mask,
                            mu,
                            logvar,
                            seg_reg=0,
                            seg_fake=None,
                            seg_gt=None):
        """Perform one generator optimizer step for the current time step.

        The fake images are scored by the discriminator and the masked
        generator loss (adversarial + optional auxiliary, KL and
        segmentation terms) is back-propagated. Then the generator optimizer
        steps.

        Args:
            fake_images: generated images to score.
            prev_image: previous-step images given to the discriminator.
            condition: conditioning input given to the discriminator.
            objects: auxiliary targets.
            aux_reg: auxiliary-loss weight.
            mask: per-sample "sequence ended" flags (truthy entries skipped).
            mu, logvar: conditioning-augmentation statistics, or ``None``.
            seg_reg: segmentation-loss weight; ``0`` disables it.
            seg_fake, seg_gt: segmentation prediction/targets used when
                ``seg_reg > 0``.

        Returns:
            ``(g_loss, grad_norm)`` as Python numbers.
        """
        gen = self.generator
        gen.zero_grad()

        scores, aux_scores, _ = self.discriminator(fake_images, condition,
                                                   prev_image)
        loss = self._generator_masked_loss(scores, aux_scores, objects,
                                           aux_reg, mu, logvar, mask,
                                           seg_reg, seg_fake, seg_gt)

        # The graph is retained, so later backward passes in the same
        # sequence can still traverse shared upstream tensors.
        loss.backward(retain_graph=True)
        norm = _recurrent_gan.get_grad_norm(gen.parameters())
        self.generator_optimizer.step()

        loss_value, norm_value = loss.item(), norm.item()

        # Release graph references explicitly to keep peak memory down.
        del loss, norm
        gc.collect()

        return loss_value, norm_value

    def _optimize_rnn(self):
        """Step the recurrent and encoder modules after a full sequence.

        The RNN and the sentence encoder each get their gradients clipped to
        ``cfg.grad_clip`` and measured, then their optimizer steps and their
        gradients are cleared. The condition and image encoders are measured
        without clipping and stepped through ``feature_encoders_optimizer``.

        Returns:
            ``(rnn_norm, sentence_encoder_norm, condition_encoder_norm,
            image_encoder_norm)`` as returned by
            ``_recurrent_gan.get_grad_norm``.
        """
        clipped_norms = []
        for module, optimizer in ((self.rnn, self.rnn_optimizer),
                                  (self.sentence_encoder,
                                   self.sentence_encoder_optimizer)):
            torch.nn.utils.clip_grad_norm_(module.parameters(),
                                           self.cfg.grad_clip)
            clipped_norms.append(
                _recurrent_gan.get_grad_norm(module.parameters()))
            optimizer.step()
            module.zero_grad()
        rnn_norm, sentence_norm = clipped_norms

        # These two encoders are not clipped; presumably both are owned by
        # feature_encoders_optimizer (confirm where it is constructed).
        ce_norm = _recurrent_gan.get_grad_norm(
            self.condition_encoder.parameters())
        ie_norm = _recurrent_gan.get_grad_norm(
            self.image_encoder.parameters())
        self.feature_encoders_optimizer.step()
        self.condition_encoder.zero_grad()
        self.image_encoder.zero_grad()

        return rnn_norm, sentence_norm, ce_norm, ie_norm

    def _discriminator_masked_loss(self, d_real, d_fake, d_wrong, aux_real,
                                   aux_fake, objects, aux_reg, mask):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding"""
        d_loss = []
        aux_losses = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.discriminator(
                    d_real[b], d_fake[b], d_wrong[b],
                    self.cfg.wrong_fake_ratio)
                if aux_reg > 0:
                    aux_loss = aux_reg * (
                        self.aux_criterion(aux_real[b], objects[b]).mean() +
                        self.aux_criterion(aux_fake[b], objects[b]).mean())
                    sample_loss += aux_loss
                    aux_losses.append(aux_loss)

                d_loss.append(sample_loss)

        d_loss = torch.stack(d_loss).mean()

        if len(aux_losses) > 0:
            aux_losses = torch.stack(aux_losses).mean()
        else:
            aux_losses = 0

        return d_loss, aux_losses

    def _generator_masked_loss(self,
                               d_fake,
                               aux_fake,
                               objects,
                               aux_reg,
                               mu,
                               logvar,
                               mask,
                               seg_reg=0,
                               seg_fake=None,
                               seg_gt=None):
        """Accumulates losses only for sequences that have not ended
        to avoid back-propagation through padding
        Append the segmentation loss to the model.
        seg_fake: (1*C*H*W)
        seg_gt: (1*H*W)
        """
        g_loss = []
        for b, ended in enumerate(mask):
            if not ended:
                sample_loss = self.criterion.generator(d_fake[b])
                if aux_reg > 0:
                    aux_loss = aux_reg * \
                        self.aux_criterion(aux_fake[b], objects[b]).mean()
                else:
                    aux_loss = 0
                if mu is not None:
                    kl_loss = self.cfg.cond_kl_reg * \
                        kl_penalty(mu[b], logvar[b])
                else:
                    kl_loss = 0
                #Append a seg_loss to the total generator loss
                if seg_reg > 0:
                    #TODO: Implement the Segmentation Loss here
                    seg_loss = seg_reg * self.seg_criterion(
                        seg_fake[b], seg_gt[b]
                    )  #By default it should just give a mean number
                    #print(seg_loss)
                else:
                    seg_loss = 0

                g_loss.append(sample_loss + aux_loss + kl_loss + seg_loss)

        g_loss = torch.stack(g_loss)
        return g_loss.mean()

    def _masked_gradient_penalty(self, d_real, real_images, mask):
        """Return the mean of ``gradient_penalty(d_real, real_images)``.

        NOTE(review): despite the name, ``mask`` is currently ignored. The
        penalty is averaged over every sample, including ended sequences.
        """
        penalty = gradient_penalty(d_real, real_images)
        return penalty.mean()

    # region Helpers
    def _plot_losses(self, visualizer, g_loss, d_loss, aux_loss, iteration):
        """Plot the current losses via ``_recurrent_gan._plot_losses``."""
        _recurrent_gan._plot_losses(
            self, visualizer, g_loss, d_loss, aux_loss, iteration)

    def _plot_gradients(self, visualizer, rnn, gen, disc, gru, ce, ie,
                        iteration):
        """Plot gradient norms via ``_recurrent_gan._plot_gradients``.

        The arguments are the RNN, generator, discriminator, sentence-encoder
        (``gru``), condition-encoder and image-encoder gradient values, in
        that order.
        """
        _recurrent_gan._plot_gradients(
            self, visualizer, rnn, gen, disc, gru, ce, ie, iteration)

    def _draw_images(self, visualizer, real, fake, nrow):
        """Draw real/fake image grids via ``_recurrent_gan.draw_images_gandraw``.

        Uses the GanDraw-specific drawing helper (change by Mingyang Zhou).
        """
        _recurrent_gan.draw_images_gandraw(
            self, visualizer, real, fake, nrow)

    def _save(self, fake, path, epoch, iteration):
        """Delegate to ``_recurrent_gan._save`` with this model instance."""
        _recurrent_gan._save(
            self, fake, path, epoch, iteration)

    def save_model(self, path, epoch, iteration):
        """Delegate model saving to ``_recurrent_gan.save_model``."""
        _recurrent_gan.save_model(
            self, path, epoch, iteration)

    def load_model(self, snapshot_path):
        """Delegate loading ``snapshot_path`` to ``_recurrent_gan.load_model``."""
        _recurrent_gan.load_model(
            self, snapshot_path)

    def unormalize(self, x):
        """Convert a normalized image into a channel-first uint8 RGB array.

        ``self.unorm`` undoes the input normalization. The result is turned
        into an RGB PIL image, and the channel axis is moved to the front.

        Args:
            x: a single image taken from a batch tensor (presumably a
                ``(C, H, W)`` float tensor — TODO confirm).

        Returns:
            ``numpy`` array of shape ``(3, H, W)``.
        """
        new_x = self.unorm(x)
        # ToPILImage converts float tensors with ``mul(255).byte()`` and does
        # not saturate, so values slightly outside [0, 1] would overflow the
        # uint8 range and show up as wrong pixel colors. Clamp first.
        if isinstance(new_x, torch.Tensor) and new_x.is_floating_point():
            new_x = new_x.clamp(0, 1)
        new_x = transforms.ToPILImage()(new_x).convert('RGB')
        # PIL gives (H, W, 3); move channels first.
        return np.moveaxis(np.array(new_x), -1, 0)

    def unormalize_segmentation(self, x):
        """Rescale a segmentation image from ``[-1, 1]`` to ``[0, 255]``.

        Computes ``(x + 1) * 127.5`` element-wise. No clipping or dtype cast
        is applied.
        """
        return (x + 1) * 127.5

    def unormalize_segmentation_onehot(self, x):
        """Render a per-class score map as a color image.

        Every pixel gets the color of its highest-scoring class
        (``np.argmax`` over axis 0, so ties resolve to the lowest class
        index). A palette lookup via NumPy integer indexing replaces the
        previous per-pixel Python loop.

        Args:
            x: ``numpy`` array of shape ``(num_classes, H, W)``. Argmax class
                indices must fall in ``0..21`` (the palette below).

        Returns:
            C-contiguous ``numpy.uint8`` array of shape ``(3, H, W)``.

        Raises:
            KeyError: if some pixel's argmax class has no palette entry.
        """
        # Entry ``i`` holds the color for class label ``i``.
        # NOTE(review): channel order is presumably RGB — confirm against the
        # dataset's label-color definition.
        label_colors = (
            ("sky", (134, 193, 46)),
            ("dirt", (30, 22, 100)),
            ("gravel", (163, 164, 153)),
            ("mud", (35, 90, 74)),
            ("sand", (196, 15, 241)),
            ("clouds", (198, 182, 115)),
            ("fog", (76, 60, 231)),
            ("hill", (190, 128, 82)),
            ("mountain", (122, 101, 17)),
            ("river", (97, 140, 33)),
            ("rock", (90, 90, 81)),
            ("sea", (255, 252, 51)),
            ("snow", (51, 255, 252)),
            ("stone", (106, 107, 97)),
            ("water", (0, 255, 0)),
            ("bush", (204, 113, 46)),
            ("flower", (0, 0, 255)),
            ("grass", (255, 0, 0)),
            ("straw", (255, 51, 252)),
            ("tree", (255, 51, 175)),
            ("wood", (66, 18, 120)),
            ("road", (255, 255, 0)),
        )
        palette = np.array([color for _, color in label_colors],
                           dtype=np.uint8)

        seg_map = np.argmax(x, axis=0)
        if seg_map.size and seg_map.max() >= len(palette):
            # Same exception type as the former dict lookup for unknown labels.
            raise KeyError(int(seg_map.max()))

        # Integer indexing yields (H, W, 3); move channels first.
        colored = palette[seg_map]
        return np.ascontiguousarray(colored.transpose(2, 0, 1))