Example #1
0
    def test(self, netG, noise, epoch, stage=1):
        """Generate images for a fixed list of test captions and save them.

        Encodes the captions with the CT model, feeds the resulting sentence
        embedding together with `noise` into the generator, and writes the
        fake images to ``self.test_img_dir``.

        Args:
            netG: generator network, executed via nn.parallel.data_parallel
                across ``self.gpus``.
            noise: latent noise batch fed to the generator.
            epoch: epoch index used when naming the saved image files.
            stage: unused in this method; kept for signature parity with
                train().
        """

        temp_test_captions = [
            'a man standing behind another man helping him with his tie',
            'a polar bear walking on concrete with rocks in the background',
            'a desk with a keyboard, monitor, and laptop sitting on it',
            'a man in a field with sheep and two dogs',
            'a small kid watches as a man makes some food',
            'a bedroom with a vanity and tv on a dresser',
            'a giraffe stands in the middle of its enclosure',
            'a man hitting a tennis ball with a racket',
            'a black-and-white photo of a destroyed kitchen with a fridge',
            'green lights on a street filled with cars'
        ]

        # test_captions = [caption.split() for caption in temp_test_captions]
        # Map caption words to vocabulary indices (one index array per caption).
        test_captions = convert_to_index(temp_test_captions, self.word2idx)
        batch_size = len(test_captions)  # NOTE(review): computed but unused below

        # if cfg.CUDA:
        #     fixed_noise = fixed_noise.cuda()

        #Anything happening in Dataloader that converts it to a one-hot encoding?
        # Pad all captions to the longest length; layout is (time, batch).
        Ylenghts = [Sample.shape[0] for Sample in test_captions]
        paddedArrayY = torch.FloatTensor(max(Ylenghts), len(Ylenghts)).zero_()
        for idx, Sample in enumerate(test_captions):
            paddedArrayY[:Sample.shape[0], idx] = torch.from_numpy(Sample)

        if cfg.CUDA:
            paddedArrayY = Variable(paddedArrayY.type(torch.LongTensor).cuda())

        # Only the sentence embedding is consumed here, so the same padded
        # batch is reused for the CT model's prev/next decoder inputs.
        sent_hidden, _, _ = self.CTallmodel(paddedArrayY, Ylenghts,
                                            paddedArrayY, paddedArrayY)

        #######################################################
        # (2) Generate fake images
        ######################################################
        inputs = (sent_hidden, noise)
        _, fake_imgs, mu, logvar = nn.parallel.data_parallel(
            netG, inputs, self.gpus)

        save_img_results(None, fake_imgs, epoch, self.test_img_dir, None, None)
Example #2
0
    def train(self, data_loader, stage=1):
        """Run the StackGAN training loop for one stage.

        Alternates one discriminator update and one generator update per
        batch, logs scalar summaries every 100 batches, saves sample images
        from a fixed noise vector, and checkpoints the networks every
        ``self.snapshot_interval`` epochs (plus once at the end).

        Args:
            data_loader: iterable yielding (real_image_batch, text_embedding)
                pairs.
            stage: 1 loads the stage-I networks, anything else stage-II.
        """
        if stage == 1:
            netG, netD, start_epoch = self.load_network_stageI()
        else:
            netG, netD, start_epoch = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        # `noise` is resampled every batch; `fixed_noise` stays constant so
        # the sample images saved across epochs are directly comparable.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerG, optimizerD = self.load_optimizers(netG, netD)

        count = 0  # global batch counter used as the summary step
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every `lr_decay_step` epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.float().cuda()

                ######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ######################################################
                # (3) Update D network
                ######################################################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ######################################################
                # (2) Update G network
                ######################################################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                # KL term regularizes the conditioning-augmentation Gaussian
                # toward the standard normal, weighted by cfg.TRAIN.COEFF.KL.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    # lr_fake is the low-resolution output (stage-II only).
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
                save_optimizer(optimizerG, optimizerD, self.model_dir)

        save_model(netG, netD, self.max_epoch, self.model_dir)
        self.summary_writer.close()
Example #3
0
    def train(self, data_loader, stage=1):
        """GAN training loop variant with an aggressive discriminator schedule.

        Per batch the discriminator is updated twice with correct labels and
        once more with real/fake labels swapped (a label-flipping trick),
        then the generator is updated once. Uses the legacy PyTorch 0.3-era
        API (``Variable``, ``volatile=True``, ``.data[0]``).

        Args:
            data_loader: iterable yielding (real_image_batch, text_embedding)
                pairs.
            stage: 1 loads the stage-I networks (generator returns 5 values),
                anything else stage-II (generator returns 4 values).
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # volatile=True: inference-only Variable (pre-0.4 equivalent of
        # torch.no_grad()) — fixed_noise is only used for sample images.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimize generator parameters that require gradients (frozen
        # pretrained parts, if any, are excluded).
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        ####
        #netD_std = 0.1
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every `lr_decay_step` epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5

                #### stand deviation decay
                #netD_std *= 0.5

                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding).float()
                #	print(txt_embedding.size())
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                # Stage-I generator returns an extra (unused) fifth output.
                if stage == 1:
                    _, fake_imgs, mu, logvar, _ = \
                 nn.parallel.data_parallel(netG, inputs, self.gpus)
                else:
                    _, fake_imgs, mu, logvar = \
                                      nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                #### A little noise for images passed to discriminator
                #fake_imgs = fake_imgs + torch.cuda.FloatTensor(fake_imgs.size()).normal_(0,netD_std)

                #### update D twice
                for D_update in range(2):
                    netD.zero_grad()
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                            real_labels, fake_labels,
                                            mu, self.gpus)
                    errD.backward()
                    optimizerD.step()

                #### update D with reversed labels
                # Label-flipping pass: real/fake label tensors are swapped.
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                       compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                fake_labels, real_labels,
                                mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                # KL term regularizes the conditioning-augmentation Gaussian.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                # Summary logging is disabled in this variant; only sample
                # images are saved every 100 batches.
                if i % 100 == 0:
                    #summary_D = summary.scalar('D_loss', errD.data[0])
                    #print(summary_D)
                    #summary_D_r = summary.scalar('D_loss_real', errD_real)
                    #summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    #summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    #summary_G = summary.scalar('G_loss', errG.data[0])
                    #summary_KL = summary.scalar('KL_loss', kl_loss.data[0])

                    #self.summary_writer.add_summary(summary_D, count)
                    #self.summary_writer.add_summary(summary_D_r, count)
                    #self.summary_writer.add_summary(summary_D_w, count)
                    #self.summary_writer.add_summary(summary_D_f, count)
                    #self.summary_writer.add_summary(summary_G, count)
                    #self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    if stage == 1:
                        lr_fake, fake, _, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                    else:
                        lr_fake, fake, _, _ = \
                                          nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.data[0], errG.data[0], kl_loss.data[0],
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
Example #4
0
    def train(self, data_loader, stage=1):
        """Jointly train the CT caption model and the conditional GAN.

        Phase 1 (epochs < CT_update): only the CT sequence model is trained
        on the prev/next-caption reconstruction loss. Phase 2 (epochs >=
        CT_update): the sentence embedding conditions the GAN, which is
        trained with the usual D/G alternation; optionally a captioning
        model provides a cosine-embedding consistency loss on the fakes.

        Args:
            data_loader: yields (real images, sentences, padded prev/curr/next
                caption arrays with masks and current lengths).
            stage: 1 loads the stage-I networks, anything else stage-II.
        """

        logger = Logger('./logs_CS_GAN')
        # Real images are resized to 64x64 to match the generator output.
        image_transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([64, 64]),
            transforms.ToTensor()
        ])

        # Pretrain the CT model for 35 epochs unless a checkpoint was given.
        CT_update = 35 if cfg.CTModel == '' else 0
        print("Training CT model for ", CT_update)

        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        #######
        nz = cfg.Z_DIM if not cfg.CAP.USE else cfg.CAP.Z_DIM
        batch_size = self.batch_size
        # Target of -1 for CosineEmbeddingLoss (push embeddings apart).
        flags = Variable(torch.cuda.FloatTensor([-1.0] * batch_size))
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        # NOTE(review): fixed_noise_test is never moved to the GPU below even
        # when cfg.CUDA is set — confirm self.test() handles the device.
        fixed_noise_test = \
            Variable(torch.FloatTensor(10, nz).normal_(0, 1),
                     requires_grad=False)

        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))

        #Gaussian noise input added to the input images to the disc
        noise_input = Variable(
            torch.zeros(batch_size, 3, cfg.FAKEIMSIZE, cfg.FAKEIMSIZE))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
            noise_input = noise_input.cuda()
            # NOTE(review): .cuda() is not in-place and its result is
            # discarded here; harmless only because `flags` was already
            # constructed as a torch.cuda.FloatTensor above.
            flags.cuda()

        # Probability of injecting noise into D's inputs, decayed each time
        # it fires (instance-noise trick).
        epsilon = 0.999
        epsilon_decay = 0.99
        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR,betas=(0.5, 0.999))
        netG_para = []

        # self.emb_model=EMB(512,128)
        # Only optimize generator parameters that require gradients.
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        ####Optimizers for CT c                             ##########################TODO:PRINT PARAMETERS!!!!
        optimizerCTallmodel = optim.Adam(self.CTallmodel.parameters(),
                                         lr=0.0001,
                                         weight_decay=0.00001,
                                         betas=(0.5, 0.999))
        # NOTE(review): optimizerCTenc is created but never stepped below.
        optimizerCTenc = optim.Adam(self.CTencoder.parameters(),
                                    lr=0.0001,
                                    weight_decay=0.00001,
                                    betas=(0.5, 0.999))
        count = 0
        len_dataset = len(data_loader)

        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both GAN learning rates every `lr_decay_step` epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
            print("Started training for new epoch")
            optimizerCTallmodel.zero_grad()
            ct_epoch_loss = 0
            emb_loss = 0
            epoch_count = 0
            for i, data in enumerate(data_loader):
                ######################################################
                # (1) Prepare training data
                ######################################################

                real_img_cpu, sentences, paddedArrayPrev, maskArrayPrev, paddedArrayCurr, Currlenghts, paddedArrayNext, maskArrayNext = data
                # Re-initialize the encoder hidden state for this batch size.
                self.CTallmodel.encoder.hidden = self.CTallmodel.encoder.hidden_init(
                    paddedArrayCurr.size(1))
                real_imgs = Variable(real_img_cpu)
                paddedArrayCurr = Variable(
                    paddedArrayCurr.type(torch.LongTensor))
                # Decoder inputs drop the final token (teacher forcing);
                # the targets below drop the first token instead.
                paddedArrayNext_input = Variable(paddedArrayNext[:-1, :].type(
                    torch.LongTensor))
                paddedArrayPrev_input = Variable(paddedArrayPrev[:-1, :].type(
                    torch.LongTensor))

                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    paddedArrayCurr = paddedArrayCurr.cuda()
                    paddedArrayNext_input = paddedArrayNext_input.cuda()
                    paddedArrayPrev_input = paddedArrayPrev_input.cuda()
                inputs_CT = (paddedArrayCurr, Currlenghts,
                             paddedArrayPrev_input, paddedArrayNext_input)
                # sent_hidden, logits_prev, logits_next = self.CTallmodel(paddedArrayCurr, Currlenghts, paddedArrayPrev_input, paddedArrayNext_input)
                sent_hidden, logits_prev, logits_next = nn.parallel.data_parallel(
                    self.CTallmodel, inputs_CT, self.gpus)
                #Optimizing over Concurrent model
                # Phase 1: CT-model pretraining on prev/next caption
                # cross-entropy, masked to the valid (non-pad) positions.
                if (epoch < CT_update):
                    logits_prev = logits_prev.contiguous().view(
                        -1,
                        logits_prev.size()[2])
                    logits_next = logits_next.contiguous().view(
                        -1,
                        logits_next.size()[2])

                    # Targets are the decoder inputs shifted by one token.
                    Y_prev = paddedArrayPrev[1:, :]
                    Y_prev = Y_prev.contiguous().view(-1)

                    Y_next = paddedArrayNext[1:, :]
                    Y_next = Y_next.contiguous().view(-1)

                    maskArrayPrev = maskArrayPrev[1:, :]
                    maskArrayPrev = maskArrayPrev.contiguous().view(-1)

                    maskArrayNext = maskArrayNext[1:, :]
                    maskArrayNext = maskArrayNext.contiguous().view(-1)

                    # Indices of non-padding positions from the masks.
                    ind_prev = torch.nonzero(maskArrayPrev, out=None).squeeze()
                    ind_next = torch.nonzero(maskArrayNext, out=None).squeeze()

                    if torch.cuda.is_available():
                        ind_prev = ind_prev.cuda()
                        ind_next = ind_next.cuda()

                    valid_target_prev = torch.index_select(
                        Y_prev, 0,
                        ind_prev.type(torch.LongTensor)).type(torch.LongTensor)
                    valid_output_prev = torch.index_select(
                        logits_prev, 0, Variable(ind_prev))

                    valid_target_next = torch.index_select(
                        Y_next, 0,
                        ind_next.type(torch.LongTensor)).type(torch.LongTensor)
                    valid_output_next = torch.index_select(
                        logits_next, 0, Variable(ind_next))

                    if torch.cuda.is_available():
                        valid_output_prev = valid_output_prev.cuda()
                        valid_output_next = valid_output_next.cuda()

                        valid_target_prev = valid_target_prev.cuda()
                        valid_target_next = valid_target_next.cuda()

                    loss_prev = self.CTloss(valid_output_prev,
                                            Variable(valid_target_prev))
                    loss_next = self.CTloss(valid_output_next,
                                            Variable(valid_target_next))

                    self.CTallmodel.zero_grad()
                    optimizerCTallmodel.zero_grad()
                    loss = loss_prev + loss_next
                    # retain_graph=True keeps the graph alive since
                    # sent_hidden may be reused later in this iteration.
                    loss.backward(retain_graph=True)
                    ct_epoch_loss += loss.data[0]
                    nn.utils.clip_grad_norm(self.CTallmodel.parameters(), 0.25)
                    optimizerCTallmodel.step()

                # Phase 2: GAN training conditioned on the sentence embedding.
                if epoch >= CT_update:
                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    inputs = (sent_hidden, noise)
                    _, fake_imgs, mu, logvar = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus) #### TODO: Check Shapes->Checked

                    # _,fake_imgs,mu,logvar=netG(inputs[0],inputs[1])
                    #######################################################
                    # (2.1) Generate captions for fake images
                    ######################################################
                    if self.cap_model_bool:
                        sents, h_sent = self.eval_utils.captioning_model(
                            fake_imgs, self.cap_model, self.vocab_cap,
                            self.my_resnet, self.eval_kwargs)
                        h_sent_var = Variable(torch.FloatTensor(h_sent)).cuda()
                    # input_layer = tf.stack([preprocess_for_train(i) for i in real_imgs], axis=0)
                    # Resize each real image to the generator's resolution.
                    real_imgs = Variable(
                        torch.stack([
                            image_transform_train(img.data.cpu()).cuda()
                            for img in real_imgs
                        ],
                                    dim=0))

                    ############################
                    # (3) Update D network
                    ###########################

                    # Instance-noise trick: with decaying probability, add
                    # the same Gaussian noise to real and fake images.
                    if random.uniform(0, 1) < epsilon and cfg.GAN.ADD_NOISE:
                        epsilon *= epsilon_decay
                        noise_input.data.normal_(0, 1)
                        fake_imgs = fake_imgs + noise_input
                        real_imgs = real_imgs + noise_input
                    netD.zero_grad()
                    errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                    errD.backward()
                    optimizerD.step()

                    #Label Switching
                    #Trick as of - https://github.com/soumith/ganhacks/issues/14
                    #   if random.uniform(0.1)<epsilon:
                    # netD.zero_grad()
                    # errD, errD_real, errD_wrong, errD_fake = \
                    # compute_discriminator_loss(netD, real_imgs, fake_imgs,
                    #                            fake_labels, real_labels,
                    #                            mu, self.gpus)
                    # errD.backward()
                    # optimizerD.step()
                    ############################
                    # (4) Update G network
                    ###########################
                    # Cosine-embedding loss between the caption embedding
                    # and the embedding of the caption generated from fakes.
                    if self.cap_model_bool:
                        loss_cos = self.cosEmbLoss(sent_hidden, h_sent_var,
                                                   flags)
                    netG.zero_grad()
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  mu, self.gpus)
                    kl_loss = KL_loss(mu, logvar)

                    if self.cap_model_bool:
                        errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + 10 * loss_cos
                        emb_loss += loss_cos.data[0]
                    else:
                        errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                    errG_total.backward()
                    optimizerG.step()

                count = count + 1
                epoch_count += 1
                if i % 200 == 0:
                    print("Loss CT Model: ", ct_epoch_loss / epoch_count)
                    # print("Emb Loss: ", emb_loss)
            # save the image result for each epoch after embedding model has been trained
            if epoch >= CT_update:
                inputs = (sent_hidden, fixed_noise)

                lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                if self.cap_model_bool:
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir,
                                     sentences, sents)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir,
                                         sentences, sents)
                else:
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir,
                                     sentences, None)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir,
                                         sentences, None)
                self.test(netG, fixed_noise_test, epoch)

                end_t = time.time()

                print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.data[0], errG.data[0], kl_loss.data[0],
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
                # logger.scalar_summary('Cosine_loss', emb_loss, epoch+1)
                logger.scalar_summary('errD_loss', errD.data[0] / len_dataset,
                                      epoch + 1)
                logger.scalar_summary('errG_loss', errG.data[0] / len_dataset,
                                      epoch + 1)
                logger.scalar_summary('kl_loss', kl_loss.data[0] / len_dataset,
                                      epoch + 1)

            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, self.CTallmodel, epoch, self.model_dir)
            logger.scalar_summary('CT_loss', ct_epoch_loss / len_dataset,
                                  epoch + 1)

        save_model(netG, netD, self.CTallmodel, self.max_epoch, self.model_dir)
Example #5
0
    def train(self, data_loader, stage=1):
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    self.summary_writer.add_scalars(main_tag="loss", tag_scalar_dict={
                        'D_loss':errD.cpu().item(),
                        'G_loss':errG_total.cpu().item()
                    }, global_step=count)
                    self.summary_writer.add_scalars(main_tag="D_loss", tag_scalar_dict={
                        "D_loss_real":errD_real,
                        "D_loss_wrong":errD_wrong,
                        "D_loss_fake":errD_fake
                    }, global_step=count)
                    self.summary_writer.add_scalars(main_tag="G_loss", tag_scalar_dict={
                        "G_loss":errG.cpu().item(),
                        "KL_loss":kl_loss.cpu().item()
                    }, global_step=count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    self.summary_writer.add_image(tag="fake_image", 
                        img_tensor=vutils.make_grid(fake_imgs, normalize=True, range=(-1,1)),
                        global_step=count
                    )
                    self.summary_writer.add_image(tag="real_image", 
                        img_tensor=vutils.make_grid(real_img_cpu, normalize=True, range=(-1,1)),
                        global_step=count
                    )
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.cpu().item(), errG.cpu().item(), kl_loss.cpu().item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
Exemplo n.º 6
0
    def train(self, data_loader, stage=1, max_objects=3):
        """Run the full GAN training loop for stage I or stage II.

        Alternates discriminator and generator updates over the batches
        in ``data_loader``, halves both learning rates every
        ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs, logs scalar summaries every
        500 iterations, saves sample images, and checkpoints the
        networks every ``self.snapshot_interval`` epochs.

        Args:
            data_loader: iterable yielding
                ``(real_img, bbox, label, txt_embedding)`` batches.
            stage: 1 loads the stage-I networks, otherwise stage II.
            max_objects: number of per-image object slots (rows of the
                bbox/label tensors).
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # NOTE(review): fixed_noise is never used below -- the periodic
        # samples are generated with `noise` instead. Confirm whether the
        # fixed batch was meant to be used for comparable visualizations.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only optimize generator parameters that require gradients, so
        # frozen sub-modules are left untouched.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        # Stage II carries one bbox tensor per resolution.
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                # Build (batch, max_objects, 2, 3) affine matrices mapping
                # between image space and each object's bbox region.
                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(bbox)
                    transf_matrices = transf_matrices.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting; presumably slot 80
                # is the extra "no object" class of the 81-way encoding.
                _labels[_labels < 0] = 80
                label_one_hot = torch.FloatTensor(noise.shape[0], max_objects,
                                                  81).fill_(0)
                if cfg.CUDA:
                    # Fix: scatter_ requires the index and target tensors to
                    # live on the same device, and `_labels` was moved to the
                    # GPU above. Previously this stayed on the CPU.
                    label_one_hot = label_one_hot.cuda()
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()

                if cfg.STAGE == 1:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices, transf_matrices_inv,
                                                   mu, self.gpus)
                elif cfg.STAGE == 2:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices_s2, transf_matrices_inv_s2,
                                                   mu, self.gpus)
                # retain_graph so the generator update below can backprop
                # through the same fake_imgs graph.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                if cfg.STAGE == 1:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices,
                                                  transf_matrices_inv, mu,
                                                  self.gpus)
                elif cfg.STAGE == 2:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices_s2,
                                                  transf_matrices_inv_s2, mu,
                                                  self.gpus)
                # KL regularizer from the conditioning-augmentation head.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count += 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    with torch.no_grad():
                        if cfg.STAGE == 1:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, label_one_hot)
                        elif cfg.STAGE == 2:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, transf_matrices_s2,
                                      transf_matrices_inv_s2, label_one_hot)
                        lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
            # End-of-epoch sample, conditioned on the last batch seen.
            with torch.no_grad():
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)
        # Final checkpoint, tagged with the last epoch index.
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()
Exemplo n.º 7
0
    def infer_and_save_imgs(self,
                            netG,
                            real_imgs,
                            mask_tensor,
                            edge_tensor,
                            epoch,
                            img_stamp=None,
                            webpage=None,
                            pre_Z_collect=None):
        """
            Infer and save images.
            'mask_img' here represents the input condition.

            For every image in the batch, cfg.SAMPLE_NUM samples are
            generated by netG (optionally slerp-interpolating between two
            preset latent vectors in pre_Z_collect), written to disk as
            grids via save_img_results, optionally published to `webpage`,
            and -- in training mode -- logged to the summary writer.
        """

        if cfg.TRAIN.FLAG:
            # During training the incoming tensors are split; the *_to_save
            # variants keep the display versions.
            real_imgs, mask_tensor, edge_tensor, mask_tensor_to_save, \
                    edge_tensor_to_save = \
                    self.__split_tensor(real_imgs, mask_tensor, edge_tensor)
        else:
            mask_tensor_to_save, edge_tensor_to_save = mask_tensor, edge_tensor

        # If no extra image stamp, using the epoch as the stamp
        if img_stamp is None: img_stamp = epoch

        batch_size = int(real_imgs.size(0))

        # Initialize the image collection: one list of samples per image.
        image_collect = [
            None,
        ] * batch_size
        for j in range(batch_size):
            image_collect[j] = []

        for i in range(cfg.SAMPLE_NUM):
            # Prepare preset random latent vector if required
            if pre_Z_collect is None: pre_Z = None
            else:
                # NOTE(review): pre_Z is only assigned when exactly two
                # preset vectors are given; any other size reaches the
                # unsqueeze below with pre_Z unbound -- confirm callers
                # always pass two vectors (or None).
                if pre_Z_collect.size(0) == 2:  # to do linear interpolation
                    sample_miu = i / float(cfg.SAMPLE_NUM)
                    # Spherical interpolation (slerp) between the two
                    # preset latent vectors, angle self.slerp_theta.
                    pre_Z = (math.sin((1 - sample_miu) * self.slerp_theta)/\
                            math.sin(self.slerp_theta)) * pre_Z_collect[0] + \
                            (math.sin(sample_miu * self.slerp_theta)/\
                            math.sin(self.slerp_theta))  * pre_Z_collect[1]

                pre_Z = pre_Z.unsqueeze(dim=0)

            inputs = (mask_tensor, edge_tensor, pre_Z)

            gen_imgs, _ = self._parallel_wrapper(netG, inputs)

            # Collect the i-th sample of every batch element on the CPU.
            for j in range(batch_size):
                image_collect[j].append(
                    gen_imgs.data.cpu()[j].unsqueeze(dim=0))

        # One (SAMPLE_NUM, C, H, W) tensor of samples per source image.
        gen_imgs = [torch.cat(x, dim=0) for x in image_collect]

        for j, single_gen_imgs in enumerate(gen_imgs):
            image_data_input = OrderedDict([
                ('real_img', real_imgs[j].unsqueeze(dim=0)),
                ('gen_img', single_gen_imgs),
                ('mask_img', mask_tensor_to_save[j].unsqueeze(dim=0)),
            ])

            if edge_tensor is not None:
                # Fix: use the j-th edge map. This was hard-coded to index
                # 0, which saved the first image's edges for every batch
                # element.
                image_data_input['edge_img'] =\
                        edge_tensor_to_save[j].unsqueeze(dim=0)

            if type(img_stamp) == list: img_stamp_now = img_stamp[j]
            elif type(img_stamp) == int: img_stamp_now = img_stamp

            output_grids, img_names \
                    = save_img_results(image_data_input, img_stamp_now,
                    self.image_dir, nrow=cfg.SAMPLE_NUM)

            # Save inferred results on the webpage; epoch_compare guards
            # against re-adding the same stamp twice.
            if webpage is not None and self.epoch_compare != img_stamp_now:
                webpage.add_images(img_names, img_names, img_names,
                                   [256, 256 * cfg.SAMPLE_NUM, 256])
                webpage.save()
                self.epoch_compare = img_stamp_now

        if cfg.TRAIN.FLAG:
            # Log the grids of the last batch element to TensorBoard.
            image_collection = OrderedDict([
                ('real_image', output_grids[0]),
                ('gen_image', output_grids[1]),
                ('mask_image', output_grids[2]),
            ])

            if len(output_grids) == 4:
                image_collection['edge_image'] = output_grids[3]

            for name, img in image_collection.items():
                self.summary_writer.add_image(name, img, epoch)
Exemplo n.º 8
0
    def train(self, data_loader, stage=1):
        """Train the text-to-image GAN with an auxiliary detection loss.

        Standard StackGAN-style alternating D/G updates, plus a
        Detectron-based term: object labels detected in the generated
        images are compared (SmoothL1 on one-hot encodings) against the
        objects mentioned in the caption and added to the generator loss.

        Args:
            data_loader: iterable yielding
                ``(real_img, txt_embedding, caption)`` batches.
            stage: 1 loads the stage-I networks, otherwise stage II.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM  # 100
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fix: `volatile` was removed in PyTorch 0.4 and is a silent
        # no-op; requires_grad=False expresses the same inference-only
        # intent and matches the sibling train() in this file.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimize generator parameters that require gradients.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        detectron = Detectron()
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
            for i, data in enumerate(data_loader):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding, caption = data
                # Presumably reorders (tokens, batch) -> (batch, tokens);
                # confirm against the dataset's collate function.
                caption = np.moveaxis(np.array(caption), 1, 0)

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                # KL regularizer from the conditioning-augmentation head.
                kl_loss = KL_loss(mu, logvar)

                # Detection loss: run the detector on the generated images
                # and compare detected labels against the caption's objects.
                fake_img = fake_imgs.cpu().detach().numpy()

                det_obj_list = detectron.get_labels(fake_img)

                fake_l = Variable(get_ohe(det_obj_list)).cuda()
                real_l = Variable(get_ohe(caption)).cuda()

                det_loss = nn.SmoothL1Loss()(fake_l, real_l)
                errG_total = det_loss + errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    summary_DET = summary.scalar('det_loss', det_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    self.summary_writer.add_summary(summary_DET, count)

                    # Save sample images conditioned on the current batch,
                    # using the fixed noise for comparability across steps.
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        # Final checkpoint after the last epoch.
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
Exemplo n.º 9
0
    def train(self, data_loader):
        """Train the stage-I layout GAN conditioned on bboxes and labels.

        Alternates discriminator and generator updates over batches of
        (image, bbox, one-hot label) triples, transforming each object's
        bbox into affine matrices for the spatial transformers. Logs
        scalar summaries every 500 iterations, saves padded sample images
        each epoch, and checkpoints models and optimizers every
        ``self.snapshot_interval`` epochs.

        Args:
            data_loader: iterable yielding ``(real_img, bbox, label)``
                batches.
        """
        netG, netD = self.load_network_stageI()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))

        # NOTE(review): fixed_noise is never used below -- the samples are
        # generated with `noise` instead. Confirm whether the fixed batch
        # was meant to be used for comparable visualizations.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1), requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))

        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only optimize generator parameters that require gradients.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))

        print("Starting training...")
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label = data

                real_imgs = Variable(real_img_cpu)
                # Fix: label_one_hot was previously assigned only inside
                # the cfg.CUDA branch, raising NameError on CPU-only runs.
                label_one_hot = label.float()
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    bbox = bbox.cuda()
                    label_one_hot = label_one_hot.cuda()

                # Build (batch, max_objects, 2, 3) affine matrices mapping
                # between image space and each object's bbox region.
                bbox = bbox.view(-1, 4)
                transf_matrices_inv = compute_transformation_matrix_inverse(bbox).float()
                transf_matrices_inv = transf_matrices_inv.view(real_imgs.shape[0], self.max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox).float()
                transf_matrices = transf_matrices.view(real_imgs.shape[0], self.max_objects, 2, 3)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, transf_matrices_inv, label_one_hot)
                _, fake_imgs = netG(noise, transf_matrices_inv, label_one_hot)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, transf_matrices, transf_matrices_inv, self.gpus)
                # retain_graph so the generator update below can backprop
                # through the same fake_imgs graph.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, label_one_hot,
                                              transf_matrices, transf_matrices_inv, self.gpus)
                errG_total = errG
                errG_total.backward()
                optimizerG.step()

                ############################
                # (5) Log results
                ###########################
                count += 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)

                    # save the image result for each epoch
                    with torch.no_grad():
                        inputs = (noise, transf_matrices_inv, label_one_hot)
                        lr_fake, fake = nn.parallel.data_parallel(netG, inputs, self.gpus)
                        real_img_cpu = pad_imgs(real_img_cpu)
                        fake = pad_imgs(fake)
                        save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch, self.image_dir)
            # End-of-epoch sample, conditioned on the last batch seen.
            with torch.no_grad():
                inputs = (noise, transf_matrices_inv, label_one_hot)
                lr_fake, fake = nn.parallel.data_parallel(netG, inputs, self.gpus)
                real_img_cpu = pad_imgs(real_img_cpu)
                fake = pad_imgs(fake)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        # Final checkpoint, tagged with the last epoch index.
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()
    def train(self, data_loader, stage=1):
        """Train the GAN (stage I or II) with VGG-based perceptual losses.

        Args:
            data_loader: iterable yielding (real_img_cpu, txt_embedding)
                batches.
            stage: 1 -> load the stage-I networks, otherwise stage-II.

        Side effects: logs scalar summaries, periodically saves sample
        images to ``self.image_dir`` and checkpoints to ``self.model_dir``,
        and closes ``self.summary_writer`` on completion.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise is only used for periodic visualisation and never
        # needs gradients.  (`volatile=True` was removed in PyTorch 0.4;
        # the old Variable(..., volatile=True) form no longer has any
        # effect and only emits a warning.)
        fixed_noise = torch.FloatTensor(batch_size, nz).normal_(0, 1)
        fixed_noise.requires_grad_(False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        # Only optimise generator parameters that require gradients
        # (frozen parts of netG are skipped).
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        if cfg.TRAIN.ADAM:
            optimizerD = optim.Adam(netD.parameters(),
                                    lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                    betas=(0.5, 0.999))
            optimizerG = optim.Adam(netG_para,
                                    lr=cfg.TRAIN.GENERATOR_LR,
                                    betas=(0.5, 0.999))
        else:
            optimizerD = optim.RMSprop(netD.parameters(),
                                       lr=cfg.TRAIN.DISCRIMINATOR_LR)
            optimizerG = optim.RMSprop(netG_para,
                                       lr=cfg.TRAIN.GENERATOR_LR)

        # VGG-19 feature extractor (first 28 layers) used by the
        # activation (perceptual) and texture (Gram-matrix) losses.
        cnn = models.vgg19(pretrained=True).features
        cnn = nn.Sequential(*list(cnn.children())[0:28])
        gram = GramMatrix()
        if cfg.CUDA:
            cnn.cuda()
            gram.cuda()
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                if cfg.CUDA:
                    _, fake_imgs, mu, logvar = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                else:
                    _, fake_imgs, mu, logvar = netG(txt_embedding, noise)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus, cfg.CUDA)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus, cfg.CUDA)
                kl_loss = KL_loss(mu, logvar)
                pixel_loss = PIXEL_loss(real_imgs, fake_imgs)
                # NOTE(review): the CUDA path detaches fake_imgs before the
                # VGG pass while the CPU path does not — the two branches
                # back-propagate differently through the generator; confirm
                # which behaviour is intended before unifying them.
                if cfg.CUDA:
                    fake_features = nn.parallel.data_parallel(
                        cnn, fake_imgs.detach(), self.gpus)
                    real_features = nn.parallel.data_parallel(
                        cnn, real_imgs.detach(), self.gpus)
                else:
                    fake_features = cnn(fake_imgs)
                    real_features = cnn(real_imgs)
                active_loss = ACT_loss(fake_features, real_features)
                text_loss = TEXT_loss(gram, fake_features, real_features,
                                      cfg.TRAIN.COEFF.TEXT)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + \
                    pixel_loss * cfg.TRAIN.COEFF.PIX + \
                    active_loss * cfg.TRAIN.COEFF.ACT + \
                    text_loss
                errG_total.backward()
                optimizerG.step()
                count = count + 1
                if i % 100 == 0:
                    # .item() replaces the old .data[0] idiom, which raises
                    # IndexError on 0-dim loss tensors in PyTorch >= 0.5.
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    summary_Pix = summary.scalar('Pixel_loss',
                                                 pixel_loss.item())
                    summary_Act = summary.scalar('Act_loss',
                                                 active_loss.item())
                    summary_Text = summary.scalar('Text_loss',
                                                  text_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    self.summary_writer.add_summary(summary_Pix, count)
                    self.summary_writer.add_summary(summary_Act, count)
                    self.summary_writer.add_summary(summary_Text, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    if cfg.CUDA:
                        lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                    else:
                        lr_fake, fake, _, _ = netG(txt_embedding, fixed_noise)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_Pixel: %.4f
                                     Loss_Activ: %.4f Loss_Text: %.4f
                                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                                     Total Time: %.2fsec
                                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), pixel_loss.item(),
                 active_loss.item(), text_loss.item(), errD_real, errD_wrong,
                 errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
Exemplo n.º 11
0
    def train(self, data_loader, stage=1):
        """Train the basic StackGAN (stage I or II) generator/discriminator.

        Args:
            data_loader: iterable yielding (real_img_cpu, txt_embedding)
                batches.
            stage: 1 -> load the stage-I networks, otherwise stage-II.

        Side effects: logs scalar summaries, periodically saves sample
        images to ``self.image_dir`` and checkpoints to ``self.model_dir``,
        and closes ``self.summary_writer`` on completion.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise is only used for visualisation; create it without
        # gradient tracking (replaces the removed `volatile=True` flag).
        with torch.no_grad():
            fixed_noise = \
                torch.FloatTensor(batch_size, nz).normal_(0, 1)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimise generator parameters that require gradients.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            num_batches = len(data_loader)
            print('Number of batches: ' + str(len(data_loader)))
            for i, data in enumerate(data_loader, 0):

                print('Epoch number: ' + str(epoch) + '\tBatches: ' + str(i) + '/' + str(num_batches), end='\r')
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()
                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    # .item() converts the 0-dim loss tensor to a Python
                    # float; passing the raw tensor (errD.data) hands a
                    # Tensor to summary.scalar instead of a number.
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
                    del inputs
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            print('################EPOCH COMPLETED###########')
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)

        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
Exemplo n.º 12
0
    def train(self, data_loader, stage=1):
        """Train the basic StackGAN (stage I or II) generator/discriminator.

        Args:
            data_loader: iterable yielding (real_img_cpu, txt_embedding)
                batches.
            stage: 1 -> load the stage-I networks, otherwise stage-II.

        Side effects: logs scalar summaries, periodically saves sample
        images to ``self.image_dir`` and checkpoints to ``self.model_dir``,
        and closes ``self.summary_writer`` on completion.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise is only used for visualisation and never needs
        # gradients.  (`volatile=True` was removed in PyTorch 0.4 and the
        # old Variable(..., volatile=True) form no longer has any effect.)
        fixed_noise = torch.FloatTensor(batch_size, nz).normal_(0, 1)
        fixed_noise.requires_grad_(False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimise generator parameters that require gradients.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    # .item() replaces the old .data[0] idiom, which raises
                    # IndexError on 0-dim loss tensors in PyTorch >= 0.5.
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
Exemplo n.º 13
0
    def train(self, data_loader, stage=1, max_objects=3):
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # with torch.no_grad():
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        ####
        startpoint = -1
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            optimizerD.load_state_dict(state_dict["optimD"])
            optimizerG.load_state_dict(state_dict["optimG"])
            startpoint = state_dict["epoch"]
            print(startpoint)
            print('Load Optim and optimizers as : ', cfg.NET_G)
        ####

        count = 0
        drive_count = 0
        for epoch in range(startpoint + 1, self.max_epoch):
            print('epoch : ', epoch, ' drive_count : ', drive_count)
            epoch_start_time = time.time()
            print(epoch)
            start_t = time.time()
            start_t500 = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            time_to_i = time.time()
            for i, data in enumerate(data_loader, 0):
                # if i >= 3360 :
                #     print ('Last Batches : ' , i)
                # if i < 10 :
                #     print ('first Batches : ' , i)
                # if i == 0 :
                #     print ('Startig! Batch ',i,'from total of 2070' )
                # if i % 10 == 0 and i!=0:
                #     end_t500 = time.time()
                #     print ('Batch Number : ' , i ,' |||||  Toatal Time : ' , (end_t500 - start_t500))
                #     start_t500 = time.time()
                ######################################################
                # (1) Prepare training data
                # if i < 10 :
                #     print (" (1) Prepare training data for batch : " , i)
                ######################################################
                #print ("Prepare training data for batch : " , i)
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(bbox)
                    transf_matrices = transf_matrices.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting
                _labels[_labels < 0] = 80
                if cfg.CUDA:
                    label_one_hot = torch.cuda.FloatTensor(
                        noise.shape[0], max_objects, 81).fill_(0)
                else:
                    label_one_hot = torch.FloatTensor(noise.shape[0],
                                                      max_objects, 81).fill_(0)
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

                #######################################################
                # # (2) Generate fake images
                # if i < 10 :
                #     print ("(2)Generate fake images")
                ######################################################

                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                if cfg.CUDA:
                    _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)
                else:
                    print('Hiiiiiiiiiiii')
                    _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise,
                                                       transf_matrices_inv,
                                                       label_one_hot)
                # _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise, transf_matrices_inv, label_one_hot)

                ############################
                # # (3) Update D network
                # if i < 10 :
                #     print("(3) Update D network")
                ###########################
                netD.zero_grad()

                if cfg.STAGE == 1:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices, transf_matrices_inv,
                                                   mu, self.gpus)
                elif cfg.STAGE == 2:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices_s2, transf_matrices_inv_s2,
                                                   mu, self.gpus)
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # # (4) Update G network
                # if i < 10 :
                #     print ("(4) Update G network")
                ###########################
                netG.zero_grad()
                # if i < 10 :
                #     print ("netG.zero_grad")
                if cfg.STAGE == 1:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices,
                                                  transf_matrices_inv, mu,
                                                  self.gpus)
                elif cfg.STAGE == 2:
                    # if i < 10 :
                    #     print ("cgf.STAGE = " , cfg.STAGE)
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices_s2,
                                                  transf_matrices_inv_s2, mu,
                                                  self.gpus)
                    # if i < 10 :
                    #     print("errG : ",errG)
                # KL regularizer computed from the (mu, logvar) pair returned
                # by the generator pass earlier in this batch iteration
                # (conditioning-augmentation style term — confirm against KL_loss def).
                kl_loss = KL_loss(mu, logvar)
                # if i < 10 :
                #     print ("kl_loss = " , kl_loss)
                # Total generator objective: adversarial loss plus the KL term
                # weighted by the configured coefficient.
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                # if i < 10 :
                #     print (" errG_total = " , errG_total )
                errG_total.backward()
                # if i < 10 :
                #     print ("errG_total.backward() ")
                optimizerG.step()
                # if i < 10 :
                #     print ("optimizerG.step() " )

                #print (" i % 500 == 0 :  " , i % 500 == 0 )
                end_t = time.time()
                #print ("batch time : " , (end_t - start_t))
                # Every 500 batches: build TensorBoard scalar summaries, print a
                # progress report, and dump sample generations to disk.
                if i % 500 == 0:
                    #print (" i % 500 == 0" , i % 500 == 0 )
                    count += 1
                    summary_D = summary.scalar('D_loss', errD.item())
                    # NOTE(review): errD_real/errD_wrong/errD_fake are logged
                    # without .item() while errD/errG use it — presumably these
                    # are already Python floats upstream; confirm.
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    print('epoch     :  ', epoch)
                    print('count     :  ', count)
                    print('  i       :  ', i)
                    print('Time to i : ', time.time() - time_to_i)
                    time_to_i = time.time()
                    print('D_loss : ', errD.item())
                    print('D_loss_real : ', errD_real)
                    print('D_loss_wrong : ', errD_wrong)
                    print('D_loss_fake : ', errD_fake)
                    print('G_loss : ', errG.item())
                    print('KL_loss : ', kl_loss.item())
                    print('generator_lr : ', generator_lr)
                    print('discriminator_lr : ', discriminator_lr)
                    print('lr_decay_step : ', lr_decay_step)

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    # Sample from the generator without tracking gradients; the
                    # input tuple shape depends on which StackGAN stage is active.
                    with torch.no_grad():
                        if cfg.STAGE == 1:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, label_one_hot)
                        elif cfg.STAGE == 2:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, transf_matrices_s2,
                                      transf_matrices_inv_s2, label_one_hot)
                        # NOTE(review): if cfg.STAGE is neither 1 nor 2,
                        # `inputs` is unbound here and the CUDA branch below
                        # raises NameError — verify STAGE is validated upstream.

                        if cfg.CUDA:
                            lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                                netG, inputs, self.gpus)
                        else:
                            # NOTE(review): this CPU fallback always passes the
                            # stage-1 argument list, ignoring `inputs`; it looks
                            # wrong for cfg.STAGE == 2 — confirm intended.
                            lr_fake, fake, _, _, _ = netG(
                                txt_embedding, noise, transf_matrices_inv,
                                label_one_hot)

                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        # lr_fake (low-res output) is only produced by some
                        # generator configurations; save it when present.
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
                # Every 100 batches: mirror the latest summaries to a second
                # ("drive") writer.
                # NOTE(review): summary_D..summary_KL are only rebound in the
                # i % 500 branch above, so this block re-logs values up to 400
                # batches stale; it would also raise NameError if it ever ran
                # before the 500-branch (in practice i == 0 triggers both
                # first). Consider recomputing the scalars here — TODO confirm.
                if i % 100 == 0:
                    drive_count += 1
                    self.drive_summary_writer.add_summary(
                        summary_D, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_r, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_w, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_f, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_G, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_KL, drive_count)

            #print (" with torch.no_grad(): "  )
            # End-of-epoch sampling: generate and save images from the last
            # batch's conditioning inputs.
            with torch.no_grad():
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    #print (" cfg.STAGE == 2: " , cfg.STAGE == 2 )
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                    #print (" inputs " , inputs )
                # NOTE(review): unlike the in-loop sampling, this call is not
                # guarded by cfg.CUDA — data_parallel on a CPU-only setup may
                # fail; confirm training is CUDA-only.
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                #print (" lr_fake, fake " , lr_fake, fake )
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                #print (" save_img_results(real_img_cpu, fake, epoch, self.image_dir) " , )

                #print (" lr_fake is not None: " , lr_fake is not None )
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
                    #print (" save_img_results(None, lr_fake, epoch, self.image_dir) " )
                    #end_t = time.time()
                    #print ("batch time : " , (end_t - start_t))
            end_t = time.time()
            # Epoch summary line: losses from the final batch of the epoch.
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            # Periodic checkpoint of both networks and their optimizers.
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)

            print("keyTime |||||||||||||||||||||||||||||||")
            print("epoch_time : ", time.time() - epoch_start_time)
            print("KeyTime |||||||||||||||||||||||||||||||")

        #
        # Final checkpoint after the last epoch, then release the writer.
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()