Example #1
0
    def train(self, data_loader, stage=1):
        """Adversarial training loop for a two-stage text-to-image GAN.

        For each epoch: optionally halves both learning rates, then for each
        batch performs one discriminator update followed by one generator
        update (adversarial loss plus a KL regularizer on the conditioning
        distribution).  Scalar summaries are written every 100 iterations and
        model/optimizer snapshots every ``self.snapshot_interval`` epochs.

        Args:
            data_loader: iterable yielding ``(real_img_cpu, txt_embedding)``
                batches.
            stage: 1 loads/trains the Stage-I networks; any other value the
                Stage-II networks.
        """
        if stage == 1:
            netG, netD, start_epoch = self.load_network_stageI()
        else:
            netG, netD, start_epoch = self.load_network_stageII()

        nz = cfg.Z_DIM  # dimensionality of the latent noise vector z
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is sampled once and reused for the periodic image dumps,
        # so samples from different epochs are directly comparable.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerG, optimizerD = self.load_optimizers(netG, netD)

        count = 0  # global step counter; x-axis for all summaries
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.float().cuda()

                ######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)  # fresh z for every batch
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ######################################################
                # (3) Update D network
                ######################################################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ######################################################
                # (4) Update G network
                ######################################################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                # KL term regularizes the conditioning distribution (mu, logvar).
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    # NOTE(review): errD_real/wrong/fake are logged without
                    # .item(), unlike errD/errG — presumably they are already
                    # Python floats; confirm in compute_discriminator_loss.
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # Dump sample images generated from the fixed noise.
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            # NOTE(review): `i` and the batch losses are read after the loop —
            # assumes data_loader yields at least one batch per epoch.
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
                save_optimizer(optimizerG, optimizerD, self.model_dir)

        # Final snapshot after the last epoch.
        save_model(netG, netD, self.max_epoch, self.model_dir)
        self.summary_writer.close()
Example #2
0
	def train(self, imageloader, storyloader, testloader):
		"""Adversarial training with separate image and story discriminators.

		Each step samples an image batch alongside the story batch, updates
		both discriminators on fakes generated under ``torch.no_grad()``, then
		runs two generator updates on freshly generated samples.  Scalars are
		pushed to TensorBoard when ``self.writer`` is set, sample stories are
		dumped every 100 iterations, and models are snapshotted every
		``self.snapshot_interval`` epochs.

		Args:
			imageloader: loader for single-image batches (stored on ``self``
				and consumed via ``self.sample_real_image_batch()``).
			storyloader: loader driving the epoch; yields story batches.
			testloader: loader used for periodic test-sample generation.
		"""
		self.imageloader = imageloader
		self.testloader = testloader
		self.imagedataset = None
		self.testdataset = None
		netG, netD_im, netD_st = self.load_networks()

		im_real_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(1))
		im_fake_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(0))
		st_real_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(1))
		st_fake_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(0))
		if cfg.CUDA:
			im_real_labels, im_fake_labels = im_real_labels.cuda(), im_fake_labels.cuda()
			st_real_labels, st_fake_labels = st_real_labels.cuda(), st_fake_labels.cuda()

		generator_lr = cfg.TRAIN.GENERATOR_LR
		discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR

		lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
		im_optimizerD = \
			optim.Adam(netD_im.parameters(),
					   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))

		st_optimizerD = \
			optim.Adam(netD_st.parameters(),
					   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))

		# Only optimize generator parameters that require gradients
		# (any frozen sub-modules are skipped).
		netG_para = []
		for p in netG.parameters():
			if p.requires_grad:
				netG_para.append(p)
		optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR,
								betas=(0.5, 0.999))
		if self.tensorboard:
			self.build_tensorboard()
		loss = {}  # scalar name -> latest value; flushed to the writer each step
		step = 0
		# Persist the bare module objects once up front; presumably used to
		# restore the architecture when loading checkpoints — confirm.
		torch.save({
			'netG': netG, 
			'netD_im': netD_im,
			'netD_st': netD_st,
		}, os.path.join(self.model_dir, 'barebone.pth'))

		for epoch in range(self.max_epoch):
			start_t = time.time()
			# Halve all learning rates every lr_decay_step epochs.
			if epoch % lr_decay_step == 0 and epoch > 0:
				generator_lr *= 0.5
				for param_group in optimizerG.param_groups:
					param_group['lr'] = generator_lr
				discriminator_lr *= 0.5
				for param_group in st_optimizerD.param_groups:
					param_group['lr'] = discriminator_lr
				for param_group in im_optimizerD.param_groups:
					param_group['lr'] = discriminator_lr
				loss.update({
					'D/lr': discriminator_lr,
					'G/lr': generator_lr,
				})

			print('Epoch [{}/{}]:'.format(epoch, self.max_epoch))
			with tqdm(total=len(storyloader), dynamic_ncols=True) as pbar:
				for i, data in enumerate(storyloader, 0):
					######################################################
					# (1) Prepare training data
					######################################################
					im_batch = self.sample_real_image_batch()
					st_batch = data

					im_real_cpu = im_batch['images']
					im_motion_input = im_batch['description']
					im_content_input = im_batch['content']
					# Average content vectors over dim 1 into a single vector.
					im_content_input = im_content_input.mean(1).squeeze()
					im_catelabel = im_batch['label']
					im_real_imgs = Variable(im_real_cpu)
					im_motion_input = Variable(im_motion_input)
					im_content_input = Variable(im_content_input)

					st_real_cpu = st_batch['images']
					st_motion_input = st_batch['description']
					st_content_input = st_batch['description']
					st_catelabel = st_batch['label']
					st_real_imgs = Variable(st_real_cpu)
					st_motion_input = Variable(st_motion_input)
					st_content_input = Variable(st_content_input)

					if cfg.CUDA:
						st_real_imgs = st_real_imgs.cuda()
						im_real_imgs = im_real_imgs.cuda()
						st_motion_input = st_motion_input.cuda()
						im_motion_input = im_motion_input.cuda()
						st_content_input = st_content_input.cuda()
						im_content_input = im_content_input.cuda()
						im_catelabel = im_catelabel.cuda()
						st_catelabel = st_catelabel.cuda()
					#######################################################
					# (2) Generate fake stories and images
					######################################################
					# Fakes for the D step are generated without gradients;
					# the G step below regenerates them with grad enabled.
					with torch.no_grad():
						im_inputs = (im_motion_input, im_content_input)
						_, im_fake, im_mu, im_logvar = netG.sample_images(*im_inputs)

						st_inputs = (st_motion_input, st_content_input)
						_, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(*st_inputs)


					############################
					# (3) Update D networks
					###########################
					netD_im.zero_grad()
					netD_st.zero_grad()

					im_errD, im_errD_real, im_errD_wrong, im_errD_fake, accD = \
						compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
												im_real_labels, im_fake_labels, im_catelabel,
												im_mu, self.gpus)

					st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _ = \
						compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
												st_real_labels, st_fake_labels, st_catelabel,
												c_mu, self.gpus)

					loss.update({
						'D/story/loss': st_errD.data,
						'D/story/real_loss': st_errD_real.data,
						'D/story/fake_loss': st_errD_fake.data,
						'D/image/accuracy': accD,
						'D/image/loss': im_errD.data,
						'D/image/real_loss': im_errD_real.data,
						'D/image/fake_loss': im_errD_fake.data,
					})

					im_errD.backward()
					st_errD.backward()

					im_optimizerD.step()
					st_optimizerD.step()


					############################
					# (4) Update G network
					###########################
					# Two generator steps per discriminator step.
					for g_iter in range(2):
						netG.zero_grad()

						_, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
							st_motion_input, st_content_input)

						_, im_fake, im_mu, im_logvar = netG.sample_images(im_motion_input, im_content_input)

						im_errG, accG = compute_generator_loss(netD_im, im_fake,
													im_real_labels, im_catelabel, im_mu, self.gpus)
						st_errG, _ = compute_generator_loss(netD_st, st_fake,
													st_real_labels, st_catelabel, c_mu, self.gpus)
						im_kl_loss = KL_loss(im_mu, im_logvar)
						st_kl_loss = KL_loss(m_mu, m_logvar)
						# NOTE(review): errG is unused — errG_total below
						# recomputes the same sum inline.
						errG = im_errG + self.ratio * st_errG

						kl_loss = im_kl_loss + self.ratio * st_kl_loss
						loss.update({
							'G/loss': im_errG.data,
							'G/kl': kl_loss.data,
						})
						errG_total = im_errG + self.ratio * st_errG + kl_loss
						errG_total.backward()
						optimizerG.step()
					if self.writer:
						for key, value in loss.items():
							self.writer.add_scalar(key,  value,  step)

					step += 1
					pbar.update(1)

					if i % 100 == 0:
						# Dump generated stories for visual inspection.
						lr_fake, fake, _, _, _, _ = netG.sample_videos(st_motion_input, st_content_input)
						save_story_results(st_real_cpu, fake, epoch, self.image_dir, writer=self.writer, steps=step)
						if lr_fake is not None:
							save_story_results(None, lr_fake, epoch, self.image_dir, writer=self.writer, steps=step)

			end_t = time.time()
			print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
					 Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
					 accG: %.4f accD: %.4f
					 Total Time: %.2fsec
				  '''
				  % (epoch, self.max_epoch, i, len(storyloader),
					 st_errD.data, st_errG.data,
					 st_errD_real, st_errD_wrong, st_errD_fake, accG, accD,
					 (end_t - start_t)))

			if epoch % self.snapshot_interval == 0:
				save_model(netG, netD_im, netD_st, epoch, self.model_dir)
				save_test_samples(netG, self.testloader, self.test_dir, writer=self.writer, steps=step)
		# Final snapshot after the last epoch.
		save_model(netG, netD_im, netD_st, self.max_epoch, self.model_dir)
    def train(self, data_loader, stage=1):
        """Train a sentence (CT) model and a caption-conditioned GAN in two phases.

        Phase 1 (``epoch < CT_update``): only ``self.CTallmodel`` is optimized,
        using a previous/next-sentence reconstruction loss over the padded,
        masked token arrays.  Phase 2 (``epoch >= CT_update``): the encoder's
        sentence hidden state conditions the GAN, which is trained with
        alternating D/G updates, a KL term, and (optionally) a cosine
        embedding loss against captions generated for the fake images.

        Args:
            data_loader: yields ``(real_img_cpu, sentences, paddedArrayPrev,
                maskArrayPrev, paddedArrayCurr, Currlenghts, paddedArrayNext,
                maskArrayNext)``.
            stage: 1 loads the Stage-I networks, otherwise Stage-II.

        NOTE(review): this block uses legacy PyTorch (<0.4) idioms —
        ``loss.data[0]`` indexing and ``nn.utils.clip_grad_norm`` — which
        fail/warn on modern PyTorch (use ``.item()`` and
        ``clip_grad_norm_``); see also the ``flags.cuda()`` no-op below.
        """

        logger = Logger('./logs_CS_GAN')
        # Real images are resized to 64x64 before being fed to D.
        image_transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize([64, 64]),
            transforms.ToTensor()
        ])

        # Train the CT model for 35 epochs unless a pretrained one is supplied.
        CT_update = 35 if cfg.CTModel == '' else 0
        print("Training CT model for ", CT_update)

        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        #######
        nz = cfg.Z_DIM if not cfg.CAP.USE else cfg.CAP.Z_DIM
        batch_size = self.batch_size
        # -1 targets for the cosine embedding loss used in the G update;
        # with target -1 that loss penalizes similarity — confirm intent.
        flags = Variable(torch.cuda.FloatTensor([-1.0] * batch_size))
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise vectors reused for per-epoch sample dumps / tests so
        # outputs are comparable across epochs.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        fixed_noise_test = \
            Variable(torch.FloatTensor(10, nz).normal_(0, 1),
                     requires_grad=False)

        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))

        #Gaussian noise input added to the input images to the disc
        noise_input = Variable(
            torch.zeros(batch_size, 3, cfg.FAKEIMSIZE, cfg.FAKEIMSIZE))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
            noise_input = noise_input.cuda()
            # NOTE(review): return value discarded — .cuda() is not in-place,
            # so this line is a no-op; harmless only because flags was already
            # created as a CUDA tensor above.
            flags.cuda()

        # Probability of injecting Gaussian noise into D's inputs; decays
        # multiplicatively each time it fires.
        epsilon = 0.999
        epsilon_decay = 0.99
        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR,betas=(0.5, 0.999))
        netG_para = []

        # Only optimize generator parameters that require gradients.
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        # Optimizers for the CT (sentence context) model.
        # NOTE(review): optimizerCTenc is created but never stepped below.
        optimizerCTallmodel = optim.Adam(self.CTallmodel.parameters(),
                                         lr=0.0001,
                                         weight_decay=0.00001,
                                         betas=(0.5, 0.999))
        optimizerCTenc = optim.Adam(self.CTencoder.parameters(),
                                    lr=0.0001,
                                    weight_decay=0.00001,
                                    betas=(0.5, 0.999))
        count = 0
        len_dataset = len(data_loader)

        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both GAN learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
            print("Started training for new epoch")
            optimizerCTallmodel.zero_grad()
            ct_epoch_loss = 0
            emb_loss = 0
            epoch_count = 0
            for i, data in enumerate(data_loader):
                ######################################################
                # (1) Prepare training data
                ######################################################

                real_img_cpu, sentences, paddedArrayPrev, maskArrayPrev, paddedArrayCurr, Currlenghts, paddedArrayNext, maskArrayNext = data
                # Re-initialize the encoder hidden state for this batch size.
                self.CTallmodel.encoder.hidden = self.CTallmodel.encoder.hidden_init(
                    paddedArrayCurr.size(1))
                real_imgs = Variable(real_img_cpu)
                paddedArrayCurr = Variable(
                    paddedArrayCurr.type(torch.LongTensor))
                # Decoder inputs drop the last token (teacher forcing);
                # targets below drop the first.
                paddedArrayNext_input = Variable(paddedArrayNext[:-1, :].type(
                    torch.LongTensor))
                paddedArrayPrev_input = Variable(paddedArrayPrev[:-1, :].type(
                    torch.LongTensor))

                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    paddedArrayCurr = paddedArrayCurr.cuda()
                    paddedArrayNext_input = paddedArrayNext_input.cuda()
                    paddedArrayPrev_input = paddedArrayPrev_input.cuda()
                inputs_CT = (paddedArrayCurr, Currlenghts,
                             paddedArrayPrev_input, paddedArrayNext_input)
                # sent_hidden, logits_prev, logits_next = self.CTallmodel(paddedArrayCurr, Currlenghts, paddedArrayPrev_input, paddedArrayNext_input)
                sent_hidden, logits_prev, logits_next = nn.parallel.data_parallel(
                    self.CTallmodel, inputs_CT, self.gpus)
                # Phase 1: optimize the CT model on prev/next-sentence
                # reconstruction, restricted to non-padding positions.
                if (epoch < CT_update):
                    logits_prev = logits_prev.contiguous().view(
                        -1,
                        logits_prev.size()[2])
                    logits_next = logits_next.contiguous().view(
                        -1,
                        logits_next.size()[2])

                    Y_prev = paddedArrayPrev[1:, :]
                    Y_prev = Y_prev.contiguous().view(-1)

                    Y_next = paddedArrayNext[1:, :]
                    Y_next = Y_next.contiguous().view(-1)

                    maskArrayPrev = maskArrayPrev[1:, :]
                    maskArrayPrev = maskArrayPrev.contiguous().view(-1)

                    maskArrayNext = maskArrayNext[1:, :]
                    maskArrayNext = maskArrayNext.contiguous().view(-1)

                    # Indices of real (unmasked) token positions.
                    ind_prev = torch.nonzero(maskArrayPrev, out=None).squeeze()
                    ind_next = torch.nonzero(maskArrayNext, out=None).squeeze()

                    if torch.cuda.is_available():
                        ind_prev = ind_prev.cuda()
                        ind_next = ind_next.cuda()

                    valid_target_prev = torch.index_select(
                        Y_prev, 0,
                        ind_prev.type(torch.LongTensor)).type(torch.LongTensor)
                    valid_output_prev = torch.index_select(
                        logits_prev, 0, Variable(ind_prev))

                    valid_target_next = torch.index_select(
                        Y_next, 0,
                        ind_next.type(torch.LongTensor)).type(torch.LongTensor)
                    valid_output_next = torch.index_select(
                        logits_next, 0, Variable(ind_next))

                    if torch.cuda.is_available():
                        valid_output_prev = valid_output_prev.cuda()
                        valid_output_next = valid_output_next.cuda()

                        valid_target_prev = valid_target_prev.cuda()
                        valid_target_next = valid_target_next.cuda()

                    loss_prev = self.CTloss(valid_output_prev,
                                            Variable(valid_target_prev))
                    loss_next = self.CTloss(valid_output_next,
                                            Variable(valid_target_next))

                    self.CTallmodel.zero_grad()
                    optimizerCTallmodel.zero_grad()
                    loss = loss_prev + loss_next
                    loss.backward(retain_graph=True)
                    # NOTE(review): legacy .data[0] — .item() on PyTorch >= 0.4.
                    ct_epoch_loss += loss.data[0]
                    # NOTE(review): deprecated — clip_grad_norm_ on modern PyTorch.
                    nn.utils.clip_grad_norm(self.CTallmodel.parameters(), 0.25)
                    optimizerCTallmodel.step()

                # Phase 2: GAN training conditioned on the sentence embedding.
                if epoch >= CT_update:
                    #######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    inputs = (sent_hidden, noise)
                    _, fake_imgs, mu, logvar = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus) #### TODO: Check Shapes->Checked

                    # _,fake_imgs,mu,logvar=netG(inputs[0],inputs[1])
                    #######################################################
                    # (2.1) Generate captions for fake images
                    ######################################################
                    if self.cap_model_bool:
                        sents, h_sent = self.eval_utils.captioning_model(
                            fake_imgs, self.cap_model, self.vocab_cap,
                            self.my_resnet, self.eval_kwargs)
                        h_sent_var = Variable(torch.FloatTensor(h_sent)).cuda()
                    # input_layer = tf.stack([preprocess_for_train(i) for i in real_imgs], axis=0)
                    # Resize each real image on CPU, then re-stack on GPU.
                    real_imgs = Variable(
                        torch.stack([
                            image_transform_train(img.data.cpu()).cuda()
                            for img in real_imgs
                        ],
                                    dim=0))

                    ############################
                    # (3) Update D network
                    ###########################

                    # With decaying probability, add the same Gaussian noise
                    # to both real and fake D inputs (training stabilizer).
                    if random.uniform(0, 1) < epsilon and cfg.GAN.ADD_NOISE:
                        epsilon *= epsilon_decay
                        noise_input.data.normal_(0, 1)
                        fake_imgs = fake_imgs + noise_input
                        real_imgs = real_imgs + noise_input
                    netD.zero_grad()
                    errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                    errD.backward()
                    optimizerD.step()

                    #Label Switching
                    #Trick as of - https://github.com/soumith/ganhacks/issues/14
                    #   if random.uniform(0.1)<epsilon:
                    # netD.zero_grad()
                    # errD, errD_real, errD_wrong, errD_fake = \
                    # compute_discriminator_loss(netD, real_imgs, fake_imgs,
                    #                            fake_labels, real_labels,
                    #                            mu, self.gpus)
                    # errD.backward()
                    # optimizerD.step()
                    ############################
                    # (4) Update G network
                    ###########################
                    if self.cap_model_bool:
                        # Cosine embedding loss between the sentence embedding
                        # and the embedding of the generated caption.
                        loss_cos = self.cosEmbLoss(sent_hidden, h_sent_var,
                                                   flags)
                    netG.zero_grad()
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  mu, self.gpus)
                    kl_loss = KL_loss(mu, logvar)

                    if self.cap_model_bool:
                        errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + 10 * loss_cos
                        emb_loss += loss_cos.data[0]
                    else:
                        errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                    errG_total.backward()
                    optimizerG.step()

                count = count + 1
                epoch_count += 1
                if i % 200 == 0:
                    print("Loss CT Model: ", ct_epoch_loss / epoch_count)
                    # print("Emb Loss: ", emb_loss)
            # Save the image results for this epoch once GAN training started.
            # NOTE(review): uses sent_hidden/sents from the *last* batch —
            # assumes the loop ran at least once.
            if epoch >= CT_update:
                inputs = (sent_hidden, fixed_noise)

                lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                if self.cap_model_bool:
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir,
                                     sentences, sents)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir,
                                         sentences, sents)
                else:
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir,
                                     sentences, None)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir,
                                         sentences, None)
                self.test(netG, fixed_noise_test, epoch)

                end_t = time.time()

                print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.data[0], errG.data[0], kl_loss.data[0],
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
                # logger.scalar_summary('Cosine_loss', emb_loss, epoch+1)
                logger.scalar_summary('errD_loss', errD.data[0] / len_dataset,
                                      epoch + 1)
                logger.scalar_summary('errG_loss', errG.data[0] / len_dataset,
                                      epoch + 1)
                logger.scalar_summary('kl_loss', kl_loss.data[0] / len_dataset,
                                      epoch + 1)

            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, self.CTallmodel, epoch, self.model_dir)
            logger.scalar_summary('CT_loss', ct_epoch_loss / len_dataset,
                                  epoch + 1)

        # Final snapshot after the last epoch.
        save_model(netG, netD, self.CTallmodel, self.max_epoch, self.model_dir)
Example #4
0
    def train(self, imageloader, storyloader, testloader, stage=1):
        c_time = time.time()
        self.imageloader = imageloader
        self.imagedataset = None

        netG, netD_im, netD_st, netD_se = self.load_network_stageI()
        start = time.time()
        # Initial Labels
        im_real_labels = Variable(
            torch.FloatTensor(self.imbatch_size).fill_(1))
        im_fake_labels = Variable(
            torch.FloatTensor(self.imbatch_size).fill_(0))
        st_real_labels = Variable(
            torch.FloatTensor(self.stbatch_size).fill_(1))
        st_fake_labels = Variable(
            torch.FloatTensor(self.stbatch_size).fill_(0))
        if cfg.CUDA:
            im_real_labels, im_fake_labels = im_real_labels.cuda(
            ), im_fake_labels.cuda()
            st_real_labels, st_fake_labels = st_real_labels.cuda(
            ), st_fake_labels.cuda()

        use_segment = cfg.SEGMENT_LEARNING
        segment_weight = cfg.SEGMENT_RATIO
        image_weight = cfg.IMAGE_RATIO

        # Optimizer and Scheduler
        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        im_optimizerD = optim.Adam(netD_im.parameters(),
                                   lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                   betas=(0.5, 0.999))
        st_optimizerD = optim.Adam(netD_st.parameters(),
                                   lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                   betas=(0.5, 0.999))
        if use_segment:
            se_optimizerD = optim.Adam(netD_se.parameters(),
                                       lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                       betas=(0.5, 0.999))
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        mse_loss = nn.MSELoss()

        scheduler_imD = ReduceLROnPlateau(im_optimizerD,
                                          'min',
                                          verbose=True,
                                          factor=0.5,
                                          min_lr=1e-7,
                                          patience=0)
        scheduler_stD = ReduceLROnPlateau(st_optimizerD,
                                          'min',
                                          verbose=True,
                                          factor=0.5,
                                          min_lr=1e-7,
                                          patience=0)
        if use_segment:
            scheduler_seD = ReduceLROnPlateau(se_optimizerD,
                                              'min',
                                              verbose=True,
                                              factor=0.5,
                                              min_lr=1e-7,
                                              patience=0)
        scheduler_G = ReduceLROnPlateau(optimizerG,
                                        'min',
                                        verbose=True,
                                        factor=0.5,
                                        min_lr=1e-7,
                                        patience=0)
        count = 0

        # Start training
        if not self.con_ckpt:
            start_epoch = 0
        else:
            start_epoch = int(self.con_ckpt)
        # self.calculate_vfid(netG, 0, testloader)

        print('LR DECAY EPOCH: {}'.format(lr_decay_step))
        for epoch in range(start_epoch, self.max_epoch):
            l = self.ratio * (2. / (1. + np.exp(-10. * epoch)) - 1)
            start_t = time.time()

            # Adjust lr
            num_step = len(storyloader)
            stats = {}

            with tqdm(total=len(storyloader), dynamic_ncols=True) as pbar:
                for i, data in enumerate(storyloader):
                    ######################################################
                    # (1) Prepare training data
                    ######################################################
                    im_batch = self.sample_real_image_batch()
                    st_batch = data
                    im_real_cpu = im_batch['images']
                    im_motion_input = im_batch[
                        'description'][:, :cfg.TEXT.
                                       DIMENSION]  # description vector and arrtibute (60, 356)
                    im_content_input = im_batch[
                        'content'][:, :, :cfg.TEXT.
                                   DIMENSION]  # description vector and attribute for every story (60,5,356)
                    im_real_imgs = Variable(im_real_cpu)
                    im_motion_input = Variable(im_motion_input)
                    im_content_input = Variable(im_content_input)
                    im_labels = Variable(im_batch['labels'])

                    st_real_cpu = st_batch['images']
                    st_motion_input = st_batch[
                        'description'][:, :, :cfg.TEXT.DIMENSION]  #(12,5,356)
                    st_content_input = st_batch[
                        'description'][:, :, :cfg.TEXT.DIMENSION]  # (12,5,356)
                    st_texts = None
                    if 'text' in st_batch:
                        st_texts = st_batch['text']
                    st_real_imgs = Variable(st_real_cpu)
                    st_motion_input = Variable(st_motion_input)
                    st_content_input = Variable(st_content_input)
                    st_labels = Variable(st_batch['labels'])  # (12,5,9)
                    if use_segment:
                        se_real_cpu = im_batch['images_seg']
                        se_real_imgs = Variable(se_real_cpu)

                    if cfg.CUDA:
                        st_real_imgs = st_real_imgs.cuda()  # (12,3,5,64,64)
                        im_real_imgs = im_real_imgs.cuda()
                        st_motion_input = st_motion_input.cuda()
                        im_motion_input = im_motion_input.cuda()
                        st_content_input = st_content_input.cuda()
                        im_content_input = im_content_input.cuda()
                        im_labels = im_labels.cuda()
                        st_labels = st_labels.cuda()
                        if use_segment:
                            se_real_imgs = se_real_imgs.cuda()
                    im_motion_input = torch.cat((im_motion_input, im_labels),
                                                1)  # 356+9=365 (60,365)
                    st_motion_input = torch.cat((st_motion_input, st_labels),
                                                2)  # (12,5,365)

                    #######################################################
                    # (2) Generate fake stories and images
                    ######################################################
                    # print(st_motion_input.shape, im_motion_input.shape)

                    with torch.no_grad():
                        _, st_fake, m_mu, m_logvar, c_mu, c_logvar, _ = \
                            netG.sample_videos(st_motion_input, st_content_input) # m_mu (60,365), c_mu (12,124)

                        _, im_fake, im_mu, im_logvar, cim_mu, cim_logvar, se_fake = \
                            netG.sample_images(im_motion_input, im_content_input, seg=use_segment) # im_mu (60,489), cim_mu (60,124)

                    characters_mu = (
                        st_labels.mean(1) > 0
                    ).type(torch.FloatTensor).cuda(
                    )  # which character exists in the full story (5 descriptions)
                    st_mu = torch.cat(
                        (c_mu, st_motion_input[:, :, :cfg.TEXT.DIMENSION].mean(
                            1).squeeze(), characters_mu), 1)
                    #  124 + 356 + 9 = 489 (12,489), get character info form whole story

                    im_mu = torch.cat((im_motion_input, cim_mu), 1)
                    # (60,489)
                    ############################
                    # (3) Update D network
                    ###########################

                    netD_im.zero_grad()
                    netD_st.zero_grad()
                    se_accD = 0
                    if use_segment:
                        netD_se.zero_grad()
                        se_errD, se_errD_real, se_errD_wrong, se_errD_fake, se_accD, _ = \
                            compute_discriminator_loss(netD_se, se_real_imgs, se_fake,
                                                im_real_labels, im_fake_labels, im_labels,
                                                im_mu, self.gpus)

                    im_errD, im_errD_real, im_errD_wrong, im_errD_fake, im_accD, _ = \
                        compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                               im_real_labels, im_fake_labels, im_labels,
                                               im_mu, self.gpus)

                    st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _, order_consistency  = \
                        compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                               st_real_labels, st_fake_labels, st_labels,
                                               st_mu, self.gpus)

                    if use_segment:
                        se_errD.backward()
                        se_optimizerD.step()
                        stats.update({
                            'seg_D/loss': se_errD.data,
                            'seg_D/real': se_errD_real,
                            'seg_D/fake': se_errD_fake,
                        })

                    im_errD.backward()
                    st_errD.backward()

                    im_optimizerD.step()
                    st_optimizerD.step()

                    stats.update({
                        'img_D/loss': im_errD.data,
                        'img_D/real': im_errD_real,
                        'img_D/fake': im_errD_fake,
                        'Accuracy/im_D': im_accD,
                        'Accuracy/se_D': se_accD,
                    })

                    step = i + num_step * epoch
                    self._logger.add_scalar('st_D/loss', st_errD.data, step)
                    self._logger.add_scalar('st_D/real', st_errD_real, step)
                    self._logger.add_scalar('st_D/fake', st_errD_fake, step)
                    self._logger.add_scalar('st_D/order', order_consistency,
                                            step)

                    ############################
                    # (2) Update G network
                    ###########################
                    netG.zero_grad()
                    video_latents, st_fake, m_mu, m_logvar, c_mu, c_logvar, _ = netG.sample_videos(
                        st_motion_input, st_content_input)
                    image_latents, im_fake, im_mu, im_logvar, cim_mu, cim_logvar, se_fake = netG.sample_images(
                        im_motion_input, im_content_input, seg=use_segment)
                    encoder_decoder_loss = 0
                    if video_latents is not None:
                        ((h_seg1, h_seg2, h_seg3, h_seg4),
                         (g_seg1, g_seg2, g_seg3, g_seg4)) = video_latents

                        video_latent_loss = mse_loss(
                            g_seg1,
                            h_seg1) + mse_loss(g_seg2, h_seg2) + mse_loss(
                                g_seg3, h_seg3) + mse_loss(g_seg4, h_seg4)
                        ((h_seg1, h_seg2, h_seg3, h_seg4),
                         (g_seg1, g_seg2, g_seg3, g_seg4)) = image_latents
                        image_latent_loss = mse_loss(
                            g_seg1,
                            h_seg1) + mse_loss(g_seg2, h_seg2) + mse_loss(
                                g_seg3, h_seg3) + mse_loss(g_seg4, h_seg4)
                        encoder_decoder_loss = (image_latent_loss +
                                                video_latent_loss) / 2

                        reconstruct_img = netG.train_autoencoder(se_real_imgs)
                        reconstruct_fake = netG.train_autoencoder(se_fake)
                        reconstruct_loss = (
                            mse_loss(reconstruct_img, se_real_imgs) +
                            mse_loss(reconstruct_fake, se_fake)) / 2.0

                        self._logger.add_scalar('G/image_vae_loss',
                                                image_latent_loss.data, step)
                        self._logger.add_scalar('G/video_vae_loss',
                                                video_latent_loss.data, step)
                        self._logger.add_scalar('G/reconstruct_loss',
                                                reconstruct_loss.data, step)

                    characters_mu = (st_labels.mean(1) > 0).type(
                        torch.FloatTensor).cuda()
                    st_mu = torch.cat(
                        (c_mu, st_motion_input[:, :, :cfg.TEXT.DIMENSION].mean(
                            1).squeeze(), characters_mu), 1)

                    im_mu = torch.cat((im_motion_input, cim_mu), 1)
                    se_errG, se_errG, se_accG = 0, 0, 0
                    if use_segment:
                        se_errG, se_accG, _ = compute_generator_loss(
                            netD_se, se_fake, se_real_imgs, im_real_labels,
                            im_labels, im_mu, self.gpus)

                    im_errG, im_accG, _ = compute_generator_loss(
                        netD_im, im_fake, im_real_imgs, im_real_labels,
                        im_labels, im_mu, self.gpus)

                    st_errG, st_accG, G_consistency = compute_generator_loss(
                        netD_st, st_fake, st_real_imgs, st_real_labels,
                        st_labels, st_mu, self.gpus)
                    ######
                    # Sample Image Loss and Sample Video Loss
                    im_kl_loss = KL_loss(cim_mu, cim_logvar)
                    st_kl_loss = KL_loss(c_mu, c_logvar)

                    errG = im_errG + self.ratio * (
                        image_weight * st_errG + se_errG * segment_weight
                    )  # for record
                    kl_loss = im_kl_loss + self.ratio * st_kl_loss  # for record

                    # Total Loss
                    errG_total = im_errG + im_kl_loss * cfg.TRAIN.COEFF.KL \
                        + self.ratio * (se_errG*segment_weight + st_errG*image_weight + st_kl_loss * cfg.TRAIN.COEFF.KL)

                    if video_latents is not None:
                        errG_total += (video_latent_loss +
                                       reconstruct_loss) * cfg.RECONSTRUCT_LOSS

                    errG_total.backward()
                    optimizerG.step()
                    stats.update({
                        'G/loss': errG_total.data,
                        'G/im_KL': im_kl_loss.data,
                        'G/st_KL': st_kl_loss.data,
                        'G/KL': kl_loss.data,
                        'G/consistency': G_consistency,
                        'Accuracy/im_G': im_accG,
                        'Accuracy/se_G': se_accG,
                        'Accuracy/st_G': st_accG,
                        'G/gan_loss': errG.data,
                    })

                    count = count + 1
                    pbar.update(1)

                    if i % 20 == 0:
                        step = i + num_step * epoch
                        for key, value in stats.items():
                            self._logger.add_scalar(key, value, step)

            with torch.no_grad():
                lr_fake, fake, _, _, _, _, se_fake = netG.sample_videos(
                    st_motion_input, st_content_input, seg=use_segment)
                st_result = save_story_results(st_real_cpu, fake, st_texts,
                                               epoch, self.image_dir, i)
                if use_segment and se_fake is not None:
                    se_result = save_image_results(None, se_fake)
            self._logger.add_image("pororo",
                                   st_result.transpose(2, 0, 1) / 255, epoch)
            if use_segment:
                self._logger.add_image("segment",
                                       se_result.transpose(2, 0, 1) / 255,
                                       epoch)

            # Adjust lr
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in st_optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
                for param_group in im_optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
                lr_decay_step *= 2

            g_lr, im_lr, st_lr = 0, 0, 0
            for param_group in optimizerG.param_groups:
                g_lr = param_group['lr']
            for param_group in st_optimizerD.param_groups:
                st_lr = param_group['lr']
            for param_group in im_optimizerD.param_groups:
                im_lr = param_group['lr']
            self._logger.add_scalar('learning/generator', g_lr, epoch)
            self._logger.add_scalar('learning/st_discriminator', st_lr, epoch)
            self._logger.add_scalar('learning/im_discriminator', im_lr, epoch)

            if cfg.EVALUATE_FID_SCORE:
                self.calculate_vfid(netG, epoch, testloader)

            #self.calculate_ssim(netG, epoch, testloader)
            time_mins = int((time.time() - c_time) / 60)
            time_hours = int(time_mins / 60)
            epoch_mins = int((time.time() - start_t) / 60)
            epoch_hours = int(epoch_mins / 60)

            print(
                "----[{}/{}]Epoch time:{} hours {} mins, Total time:{} hours----"
                .format(epoch, self.max_epoch, epoch_hours, epoch_mins,
                        time_hours))

            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD_im, netD_st, netD_se, epoch,
                           self.model_dir)
                #save_test_samples(netG, testloader, self.test_dir)
        save_model(netG, netD_im, netD_st, netD_se, self.max_epoch,
                   self.model_dir)
Exemple #5
0
    def train(self, imageloader, storyloader, testloader):
        """Jointly train the single-image GAN and the story (video) GAN.

        Each batch performs one discriminator update (``netD_im`` on images,
        ``netD_st`` on stories) followed by two generator updates. The image
        and story generator losses are mixed with ``self.ratio``; checkpoints
        and test samples are written every ``self.snapshot_interval`` epochs.

        Args:
            imageloader: loader for single-image batches; resampled every step
                via ``self.sample_real_image_batch()``.
            storyloader: loader that drives the epoch; yields story batches
                with 'images', 'description' and 'label' entries.
            testloader: loader used by ``save_test_samples`` at snapshot time.
        """
        self.imageloader = imageloader
        self.testloader = testloader
        self.imagedataset = None
        self.testdataset = None
        netG, netD_im, netD_st = self.load_networks()

        # Fixed real/fake targets for the GAN losses.
        im_real_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(1))
        im_fake_labels = Variable(torch.FloatTensor(self.imbatch_size).fill_(0))
        st_real_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(1))
        st_fake_labels = Variable(torch.FloatTensor(self.stbatch_size).fill_(0))
        if cfg.CUDA:
            im_real_labels, im_fake_labels = im_real_labels.cuda(), im_fake_labels.cuda()
            st_real_labels, st_fake_labels = st_real_labels.cuda(), st_fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR

        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        im_optimizerD = \
            optim.Adam(netD_im.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))

        st_optimizerD = \
            optim.Adam(netD_st.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))

        # Only optimize generator parameters that require gradients
        # (frozen submodules, e.g. pretrained encoders, are excluded).
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in st_optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
                for param_group in im_optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(storyloader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                im_batch = self.sample_real_image_batch()
                st_batch = data

                im_real_cpu = im_batch['images']
                im_motion_input = im_batch['description']
                im_content_input = im_batch['content']
                # Average the per-frame content vectors into a single vector.
                im_content_input = im_content_input.mean(1).squeeze()
                im_catelabel = im_batch['label']
                im_real_imgs = Variable(im_real_cpu)
                im_motion_input = Variable(im_motion_input)
                im_content_input = Variable(im_content_input)

                st_real_cpu = st_batch['images']
                # For stories the description doubles as the content input.
                st_motion_input = st_batch['description']
                st_content_input = st_batch['description']
                st_catelabel = st_batch['label']
                st_real_imgs = Variable(st_real_cpu)
                st_motion_input = Variable(st_motion_input)
                st_content_input = Variable(st_content_input)

                if cfg.CUDA:
                    st_real_imgs = st_real_imgs.cuda()
                    im_real_imgs = im_real_imgs.cuda()
                    st_motion_input = st_motion_input.cuda()
                    im_motion_input = im_motion_input.cuda()
                    st_content_input = st_content_input.cuda()
                    im_content_input = im_content_input.cuda()
                    im_catelabel = im_catelabel.cuda()
                    st_catelabel = st_catelabel.cuda()
                #######################################################
                # (2) Generate fake stories and images
                ######################################################
                _, im_fake, im_mu, im_logvar = netG.sample_images(im_motion_input, im_content_input)
                _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(st_motion_input, st_content_input)

                ############################
                # (3) Update D network
                ###########################
                netD_im.zero_grad()
                netD_st.zero_grad()

                im_errD, im_errD_real, im_errD_wrong, im_errD_fake, accD = \
                    compute_discriminator_loss(netD_im, im_real_imgs, im_fake,
                                               im_real_labels, im_fake_labels, im_catelabel,
                                               im_mu, self.gpus)

                st_errD, st_errD_real, st_errD_wrong, st_errD_fake, _ = \
                    compute_discriminator_loss(netD_st, st_real_imgs, st_fake,
                                               st_real_labels, st_fake_labels, st_catelabel,
                                               c_mu, self.gpus)

                im_errD.backward()
                st_errD.backward()

                im_optimizerD.step()
                st_optimizerD.step()

                ############################
                # (4) Update G network, twice per D step
                ###########################
                for g_iter in range(2):
                    netG.zero_grad()

                    # Re-sample fakes so each G step uses fresh generator output.
                    _, st_fake, c_mu, c_logvar, m_mu, m_logvar = netG.sample_videos(
                        st_motion_input, st_content_input)
                    _, im_fake, im_mu, im_logvar = netG.sample_images(im_motion_input, im_content_input)

                    im_errG, accG = compute_generator_loss(netD_im, im_fake,
                                                  im_real_labels, im_catelabel, im_mu, self.gpus)
                    st_errG, _ = compute_generator_loss(netD_st, st_fake,
                                                  st_real_labels, st_catelabel, c_mu, self.gpus)
                    im_kl_loss = KL_loss(im_mu, im_logvar)
                    st_kl_loss = KL_loss(m_mu, m_logvar)

                    # Mix image and story terms with self.ratio; KL regularizes
                    # both conditioning distributions.
                    errG = im_errG + self.ratio * st_errG
                    kl_loss = im_kl_loss + self.ratio * st_kl_loss
                    errG_total = errG + kl_loss
                    errG_total.backward()
                    optimizerG.step()

                if i % 100 == 0:
                    # Save an intermediate image result for this epoch.
                    lr_fake, fake, _, _, _, _ = netG.sample_videos(st_motion_input, st_content_input)
                    save_story_results(st_real_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_story_results(None, lr_fake, epoch, self.image_dir)

            # NOTE(review): the summary below reads the last batch's losses;
            # it would raise NameError if storyloader were empty.
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     accG: %.4f accD: %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(storyloader),
                     st_errD.data, st_errG.data,
                     st_errD_real, st_errD_wrong, st_errD_fake, accG, accD,
                     (end_t - start_t)))

            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD_im, netD_st, epoch, self.model_dir)
                save_test_samples(netG, self.testloader, self.test_dir)
        #
        save_model(netG, netD_im, netD_st, self.max_epoch, self.model_dir)
Exemple #6
0
    def train(self, data_loader, stage=1, max_objects=3):
        """Train the object-conditioned GAN (stage I or II).

        Alternates one discriminator and one generator update per batch,
        conditioning on per-object bounding-box transformation matrices and
        one-hot object labels. Logs scalar summaries every 500 steps and
        saves image samples and model checkpoints per epoch.

        Args:
            data_loader: yields (real_img_cpu, bbox, label, txt_embedding).
            stage: 1 or 2 — selects which network pair to load; the branches
                below additionally key off cfg.STAGE.
            max_objects: maximum number of objects (boxes/labels) per image.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        # Reusable noise tensor; refilled in-place each iteration below.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # with torch.no_grad():
        # fixed_noise is created but not used in this loop (sampling uses
        # `noise`); kept for parity with sibling trainers.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only optimize generator parameters that require gradients.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        count = 0  # global step counter for summary logging
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        # Stage II carries boxes at two resolutions.
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                # Build per-object affine sampling grids from the boxes.
                # Assumes bbox rows are 4-vectors (x, y, w, h) — TODO confirm
                # against compute_transformation_matrix.
                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(bbox)
                    transf_matrices = transf_matrices.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting
                # (padding slots map to the extra class index 80)
                _labels[_labels < 0] = 80
                # label_one_hot = torch.cuda.FloatTensor(noise.shape[0], max_objects, 81).fill_(0)
                # NOTE(review): this one-hot tensor is built on CPU while the
                # other inputs may be on CUDA — confirm downstream ops accept
                # the mix (the CUDA variant above is commented out).
                label_one_hot = torch.FloatTensor(noise.shape[0], max_objects,
                                                  81).fill_(0)
                # assumes _labels carries a trailing index dim compatible with
                # scatter_ along dim 2 — TODO confirm loader output shape
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                # _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise, transf_matrices_inv, label_one_hot)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()

                if cfg.STAGE == 1:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices, transf_matrices_inv,
                                                   mu, self.gpus)
                elif cfg.STAGE == 2:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices_s2, transf_matrices_inv_s2,
                                                   mu, self.gpus)
                # retain_graph so the same generator graph can back-propagate
                # again in the G update below without re-sampling.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                if cfg.STAGE == 1:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices,
                                                  transf_matrices_inv, mu,
                                                  self.gpus)
                elif cfg.STAGE == 2:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices_s2,
                                                  transf_matrices_inv_s2, mu,
                                                  self.gpus)
                # KL term regularizes the conditioning-augmentation posterior.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count += 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    with torch.no_grad():
                        if cfg.STAGE == 1:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, label_one_hot)
                        elif cfg.STAGE == 2:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, transf_matrices_s2,
                                      transf_matrices_inv_s2, label_one_hot)
                        lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
            # End-of-epoch sample using the last batch's conditioning inputs.
            with torch.no_grad():
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)
        #
        # NOTE(review): sibling trainers save the final checkpoint tagged
        # self.max_epoch; this one tags it with the last epoch index — confirm
        # which is intended.
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()
Example #7
0
    def train(self, data_loader, stage=1):
        """Run the adversarial training loop for a StackGAN-style model.

        Per batch: generates fake images from text embeddings, updates the
        discriminator three times (twice normally, once with reversed
        real/fake labels), then updates the generator with an adversarial
        loss plus a KL regularizer on the conditioning distribution.
        Learning rates for both optimizers are halved every
        `cfg.TRAIN.LR_DECAY_EPOCH` epochs; sample images are saved every
        100 iterations and model snapshots every `self.snapshot_interval`
        epochs.

        Args:
            data_loader: iterable yielding (real_images, text_embedding)
                batches of size `self.batch_size`.
            stage: 1 trains the Stage-I networks, otherwise Stage-II.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        # `noise` is re-sampled every batch; `fixed_noise` is sampled once
        # so that saved sample images are comparable across epochs.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)  # NOTE(review): `volatile` is pre-0.4 PyTorch; modern code would use torch.no_grad()
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimize the generator parameters that require gradients
        # (Stage-II freezes the Stage-I generator inside netG).
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0  # global iteration counter (kept for summary logging)
        ####
        #netD_std = 0.1
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every `lr_decay_step` epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5

                #### stand deviation decay
                #netD_std *= 0.5

                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding).float()
                #	print(txt_embedding.size())
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                # Stage-I netG returns one extra value than Stage-II.
                if stage == 1:
                    _, fake_imgs, mu, logvar, _ = \
                 nn.parallel.data_parallel(netG, inputs, self.gpus)
                else:
                    _, fake_imgs, mu, logvar = \
                                      nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                #### A little noise for images passed to discriminator
                #fake_imgs = fake_imgs + torch.cuda.FloatTensor(fake_imgs.size()).normal_(0,netD_std)

                #### update D twice
                for D_update in range(2):
                    netD.zero_grad()
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                            real_labels, fake_labels,
                                            mu, self.gpus)
                    errD.backward()
                    optimizerD.step()

                #### update D with reversed labels
                # NOTE(review): deliberate label-flip step (real/fake labels
                # swapped) — presumably a regularization trick; confirm intent.
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                       compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                fake_labels, real_labels,
                                mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                # KL term regularizes the conditioning-augmentation Gaussian.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    #summary_D = summary.scalar('D_loss', errD.data[0])
                    #print(summary_D)
                    #summary_D_r = summary.scalar('D_loss_real', errD_real)
                    #summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    #summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    #summary_G = summary.scalar('G_loss', errG.data[0])
                    #summary_KL = summary.scalar('KL_loss', kl_loss.data[0])

                    #self.summary_writer.add_summary(summary_D, count)
                    #self.summary_writer.add_summary(summary_D_r, count)
                    #self.summary_writer.add_summary(summary_D_w, count)
                    #self.summary_writer.add_summary(summary_D_f, count)
                    #self.summary_writer.add_summary(summary_G, count)
                    #self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    # (fixed_noise keeps samples comparable across epochs)
                    inputs = (txt_embedding, fixed_noise)
                    if stage == 1:
                        lr_fake, fake, _, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                    else:
                        lr_fake, fake, _, _ = \
                                          nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            # NOTE(review): `.data[0]` is pre-0.4 PyTorch; modern code uses .item()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.data[0], errG.data[0], kl_loss.data[0],
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
Example #8
0
    def train(self, data_loader, dataset, stage=1):
        """Jointly train text/image encoders, generators and discriminators.

        Per batch:
          1. encode the real captions (text encoder) and real images
             (image encoder),
          2. generate fake images from the text code and re-encode them,
          3. compute text/image auto-encoding losses and the
             image->text cycle loss,
          4. update the image and latent discriminators,
          5. update the encoders/generators with adversarial + cycle +
             auto-encoding + KL losses.

        Every 100 iterations, losses, sample images and captions are pushed
        to the visualizer; an inception score is plotted per epoch and model
        snapshots are saved every `self.snapshot_interval` epochs.

        Fixes vs. previous revision: the `pred_cap` branch referenced the
        undefined names `encoder`, `netG`, `netD` and `compute_cond_disc`;
        they are replaced with `text_encoder`, `image_generator`,
        `disc_image` and `compute_cond_discriminator_loss`, consistent with
        how those networks are used elsewhere in this method.

        Args:
            data_loader: yields (keys, images, _, captions, pred_cap).
            dataset: dataset object (kept for interface compatibility).
            stage: kept for interface compatibility; unused here.
        """
        image_encoder, image_generator, text_encoder, text_generator, disc_image, disc_latent = self.networks

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        # `noise` is re-sampled before every generator call; `fixed_noise`
        # is kept for comparable visual samples.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)  # NOTE(review): pre-0.4 PyTorch idiom

        #
        # make labels for real/fake
        #
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))  # try discriminator smoothing
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))

        # Domain labels for the latent discriminator: which encoder
        # (text=0, image=1) a latent code came from.
        txt_enc_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        img_enc_labels = Variable(torch.FloatTensor(batch_size).fill_(1))

        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
            txt_enc_labels = txt_enc_labels.cuda()
            img_enc_labels = img_enc_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        optims = self.define_optimizers(image_encoder, image_generator,
                                   text_encoder, text_generator,
                                   disc_image, disc_latent)
        optim_img_enc, optim_img_gen, optim_txt_enc, optim_txt_gen, optim_disc_img, optim_disc_latent = optims

        count = 0  # global iteration counter for plotting

        for epoch in range(self.max_epoch):

            start_t = time.time()

            # Decay image generator/discriminator learning rates by 0.75
            # every `lr_decay_step` epochs (text optimizers are untouched).
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.75
                for param_group in optim_img_gen.param_groups:
                    param_group['lr'] = generator_lr

                discriminator_lr *= 0.75
                for param_group in optim_disc_img.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                _, real_img_cpu, _, captions, pred_cap = data

                raw_inds, raw_lengths = self.process_captions(captions)

                inds, lengths = raw_inds.data, raw_lengths

                inds = Variable(inds)
                # Sort by caption length (descending) as required by the
                # text encoder's packed-sequence interface.
                lens_sort, sort_idx = lengths.sort(0, descending=True)

                # need to dataparallel the encoders?
                txt_encoder_output = text_encoder(inds[:, sort_idx], lens_sort.cpu().numpy(), None)
                encoder_out, encoder_hidden, real_txt_code, real_txt_mu, real_txt_logvar = txt_encoder_output

                real_imgs = Variable(real_img_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()

                #######################################################
                # (2) Generate fake images and their latent codes
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (real_txt_code, noise)
                fake_imgs = \
                    nn.parallel.data_parallel(image_generator, inputs, self.gpus)

                fake_img_out = nn.parallel.data_parallel(
                    image_encoder, (fake_imgs), self.gpus
                )

                fake_img_feats, fake_img_emb, fake_img_code, fake_img_mu, fake_img_logvar = fake_img_out
                fake_img_feats = fake_img_feats.transpose(0,1)

                #######################################################
                # (2b) Calculate auto encoding loss for text
                ######################################################
                loss_auto_txt, _ = compute_text_gen_loss(text_generator,
                                                      inds[:,sort_idx],
                                                      real_txt_code.unsqueeze(0),
                                                      encoder_out,
                                                      self.txt_dico)
                # Normalize by the total number of caption tokens.
                loss_auto_txt = loss_auto_txt / lengths.float().sum()

                #######################################################
                # (2c) Decode z from real imgs and calc auto-encoding loss
                ######################################################

                real_img_out = nn.parallel.data_parallel(
                    image_encoder, (real_imgs[sort_idx]), self.gpus
                )

                real_img_feats, real_img_emb, real_img_code, real_img_mu, real_img_logvar = real_img_out

                noise.data.normal_(0, 1)
                loss_auto_img, _ = compute_image_gen_loss(image_generator,
                                                       real_imgs[sort_idx],
                                                       real_img_code,
                                                       noise,
                                                       self.gpus)

                #######################################################
                # (2c) Decode z from fake imgs and calc cycle loss
                ######################################################

                loss_cycle_text, gen_captions = compute_text_gen_loss(text_generator,
                                                        inds[:,sort_idx],
                                                        fake_img_code.unsqueeze(0),
                                                        fake_img_feats,
                                                        self.txt_dico)

                loss_cycle_text = loss_cycle_text / lengths.float().sum()

                ###############################################################
                # (2d) Generate image from predicted cap, calc img cycle loss
                ###############################################################

                loss_cycle_img = 0
                if (len(pred_cap)):
                    pred_inds, pred_lens = pred_cap
                    pred_inds = Variable(pred_inds.transpose(0,1))
                    pred_inds = pred_inds.cuda() if cfg.CUDA else pred_inds

                    # FIX: was `encoder(...)` (undefined name).
                    pred_output = text_encoder(pred_inds[:, sort_idx], pred_lens.cpu().numpy(), None)
                    pred_txt_out, pred_txt_hidden, pred_txt_code, pred_txt_mu, pred_txt_logvar = pred_output

                    noise.data.normal_(0, 1)
                    inputs = (pred_txt_code, noise)
                    # FIX: was an unpacking call on the undefined `netG`;
                    # `image_generator` returns the fake images directly
                    # (same call shape as step (2) above).
                    fake_from_fake_img = \
                        nn.parallel.data_parallel(image_generator, inputs, self.gpus)

                    pred_img_out = nn.parallel.data_parallel(
                        image_encoder, (fake_from_fake_img), self.gpus
                    )

                    pred_img_feats, pred_img_emb, pred_img_code, pred_img_mu, pred_img_logvar = pred_img_out

                    # Target=1: predicted and real image features should align.
                    semantic_target = Variable(torch.ones(batch_size))
                    if cfg.CUDA:
                        semantic_target = semantic_target.cuda()

                    # NOTE(review): loss_cycle_img is computed but never added
                    # to errG below — confirm whether that is intentional.
                    loss_cycle_img = cosine_emb_loss(
                        pred_img_feats.contiguous().view(batch_size, -1), real_img_feats.contiguous().view(batch_size, -1), semantic_target
                    )

                ###########################
                # (3) Update D network
                ###########################
                optim_disc_img.zero_grad()
                optim_disc_latent.zero_grad()

                errD = 0

                errD_fake_imgs = compute_cond_discriminator_loss(disc_image, fake_imgs,
                                                   fake_labels, encoder_hidden[0], self.gpus)

                errD_im, errD_real, errD_fake = \
                    compute_uncond_discriminator_loss(disc_image, real_imgs, fake_imgs,
                                                      real_labels, fake_labels,
                                                      self.gpus)

                err_latent_disc = compute_latent_discriminator_loss(disc_latent,
                                                                    real_img_emb, encoder_hidden[0],
                                                                    img_enc_labels, txt_enc_labels,
                                                                    self.gpus)

                if (len(pred_cap)):
                    # FIX: was `compute_cond_disc(netD, ...)` (both undefined);
                    # use the same conditional-discriminator helper as above.
                    errD_fake_from_fake_imgs = compute_cond_discriminator_loss(disc_image, fake_from_fake_img,
                                                                 fake_labels, pred_txt_hidden[0], self.gpus)
                    errD += errD_fake_from_fake_imgs

                errD = errD + errD_im + errD_fake_imgs + err_latent_disc

                # check NaN (x != x is true only for NaN)
                if (errD != errD).data.any():
                    print("NaN detected (discriminator)")
                    pdb.set_trace()
                    exit()

                errD.backward()

                optim_disc_img.step()
                optim_disc_latent.step()

                ############################
                # (2) Update G network
                ###########################
                optim_img_enc.zero_grad()
                optim_img_gen.zero_grad()
                optim_txt_enc.zero_grad()
                optim_txt_gen.zero_grad()

                errG_total = 0

                err_g_uncond_loss = compute_uncond_generator_loss(disc_image, fake_imgs,
                                              real_labels, self.gpus)

                err_g_cond_disc_loss = compute_cond_generator_loss(disc_image, fake_imgs,
                                                                   real_labels, encoder_hidden[0], self.gpus)

                err_latent_gen = compute_latent_generator_loss(disc_latent,
                                                               real_img_emb, encoder_hidden[0],
                                                               img_enc_labels, txt_enc_labels,
                                                               self.gpus)

                errG = err_g_uncond_loss + err_g_cond_disc_loss + err_latent_gen + \
                        loss_cycle_text + \
                        loss_auto_img + \
                        loss_auto_txt

                if (len(pred_cap)):
                    # FIX: was `compute_cond_disc(netD, ...)` (both undefined);
                    # generator wants these images scored as real.
                    errG_fake_from_fake_imgs = compute_cond_discriminator_loss(disc_image, fake_from_fake_img,
                                                                 real_labels, pred_txt_hidden[0], self.gpus)
                    errG += errG_fake_from_fake_imgs

                # KL regularizers for all variational codes.
                img_kl_loss = KL_loss(real_img_mu, real_img_logvar)
                txt_kl_loss = KL_loss(real_txt_mu, real_txt_logvar)
                f_img_kl_loss = KL_loss(fake_img_mu, fake_img_logvar)

                kl_loss = img_kl_loss + txt_kl_loss + f_img_kl_loss

                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL

                # check NaN
                if (errG_total != errG_total).data.any():
                    print("NaN detected (generator)")
                    pdb.set_trace()
                    exit()

                errG_total.backward()

                optim_img_enc.step()
                optim_img_gen.step()
                optim_txt_enc.step()
                optim_txt_gen.step()

                count = count + 1
                if i % 100 == 0:
                    # Push loss curves, sample images, and decoded captions
                    # to the visualizer every 100 iterations.
                    self.vis.add_to_plot("D_loss", np.asarray([[
                                                    errD.data[0],
                                                    errD_im.data[0],
                                                    errD_fake_imgs.data[0],
                                                    err_latent_disc.data[0]
                                                    ]]),
                                                    np.asarray([[count] * 4]))
                    self.vis.add_to_plot("G_loss", np.asarray([[
                                                    errG.data[0],
                                                    err_g_uncond_loss.data[0],
                                                    err_g_cond_disc_loss.data[0],
                                                    err_latent_gen.data[0],
                                                    loss_cycle_text.data[0],
                                                    loss_auto_img.data[0],
                                                    loss_auto_txt.data[0]
                                                    ]]),
                                                    np.asarray([[count] * 7]))
                    self.vis.add_to_plot("KL_loss", np.asarray([[
                                                    kl_loss.data[0],
                                                    img_kl_loss.data[0],
                                                    txt_kl_loss.data[0],
                                                    f_img_kl_loss.data[0]
                                                    ]]),
                                         np.asarray([[count] * 4]))

                    self.vis.show_images("real_im", real_imgs[sort_idx].data.cpu().numpy())
                    self.vis.show_images("fake_im", fake_imgs.data.cpu().numpy())

                    # Decode generated token ids back to text, stopping at EOS
                    # and skipping the SOS marker.
                    sorted_captions = [captions[i] for i in sort_idx.cpu().tolist()]
                    gen_cap_text = []
                    for d_i, d in enumerate(gen_captions):
                        s = u""
                        for i in d:
                            if i == self.txt_dico.EOS_TOKEN:
                                break
                            if i != self.txt_dico.SOS_TOKEN:
                                s += self.txt_dico.id2word[i] + u" "
                        gen_cap_text.append(s)

                    self.vis.show_text("real_captions", sorted_captions)
                    self.vis.show_text("genr_captions", gen_cap_text)

                    r_precision = self.evaluator.r_precision_score(fake_img_code, real_txt_code)
                    self.vis.add_to_plot("r_precision", np.asarray([r_precision.data[0]]), np.asarray([count]))

            # Per-epoch evaluation on the last batch of the epoch.
            iscore_mu_real, _ = self.evaluator.inception_score(real_imgs[sort_idx])
            iscore_mu_fake, _ = self.evaluator.inception_score(fake_imgs)
            self.vis.add_to_plot("inception_score", np.asarray([[
                        iscore_mu_real,
                        iscore_mu_fake
                    ]]),
                    np.asarray([[epoch] * 2]))

            end_t = time.time()

            prefix = "Epoch %d; %s, %.1f sec" % (epoch, time.strftime('D%d %X'), (end_t-start_t))
            gen_str = "G_total: %.3f Gen loss: %.3f KL loss %.3f" % (
                                                                         errG_total.data[0],
                                                                         errG.data[0],
                                                                         kl_loss.data[0]
                                                                        )

            dis_str = "Img Disc: %.3f Latent Disc: %.3f" % (
                errD.data[0],
                err_latent_disc.data[0]
            )

            eval_str = "Incep real: %.3f Incep fake: %.3f R prec %.3f" % (
                iscore_mu_real,
                iscore_mu_fake,
                r_precision
            )

            print("%s %s, %s; %s" % (prefix, gen_str, dis_str, eval_str))

            if epoch % self.snapshot_interval == 0:
                save_model(image_encoder, image_generator,
                           text_encoder, text_generator,
                           disc_image, disc_latent,
                           epoch, self.model_dir)

        # Final snapshot after the last epoch.
        save_model(image_encoder, image_generator,
                   text_encoder, text_generator,
                   disc_image, disc_latent,
                   epoch, self.model_dir)

        self.summary_writer.close()
Example #9
0
    def train(self, data_loader, dataset, stage=1):

        netG, netD, encoder, decoder, image_encoder, enc_disc, clf_model = self.load_network_stageI(
        )

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(
            1))  # try discriminator smoothing
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        optim_fn, optim_params = get_optimizer("adam,lr=0.001")
        enc_params = filter(lambda p: p.requires_grad, encoder.parameters())
        enc_optimizer = optim_fn(enc_params, **optim_params)
        optim_fn, optim_params = get_optimizer("adam,lr=0.001")
        dec_params = filter(lambda p: p.requires_grad, decoder.parameters())
        dec_optimizer = optim_fn(dec_params, **optim_params)

        # image_enc_optimizer = \
        #     optim.Adam(image_encoder.parameters(),
        #                lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        image_enc_optimizer = \
            optim.SGD(image_encoder.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR)

        enc_disc_optimizer = \
            optim.Adam(enc_disc.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))

        count = 0

        criterionCycle = nn.SmoothL1Loss()
        #criterionCycle = torch.nn.BCELoss()
        semantic_criterion = nn.CosineEmbeddingLoss()

        for epoch in range(self.max_epoch):

            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.75
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.75
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                _, real_img_cpu, _, captions, pred_cap = data

                raw_inds, raw_lengths = self.process_captions(captions)

                # need to fix noise addition
                #inds, lengths = self.add_noise(raw_inds.data, raw_lengths)
                inds, lengths = raw_inds.data, raw_lengths

                inds = Variable(inds)
                lens_sort, sort_idx = lengths.sort(0, descending=True)

                # need to dataparallel the encoders?
                txt_encoder_output = encoder(inds[:, sort_idx],
                                             lens_sort.cpu().numpy(), None)
                encoder_out, encoder_hidden, real_txt_code, real_txt_mu, real_txt_logvar = txt_encoder_output

                real_imgs = Variable(real_img_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (real_txt_code, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                #######################################################
                # (2b) Decode z from txt and calc auto-encoding loss
                ######################################################
                loss_auto = 0
                auto_dec_inp = Variable(
                    torch.LongTensor([self.txt_dico.SOS_TOKEN] *
                                     self.batch_size))
                auto_dec_inp = auto_dec_inp.cuda(
                ) if cfg.CUDA else auto_dec_inp
                auto_dec_hidden = real_txt_code.unsqueeze(0)

                max_target_length = inds.size(0)

                for t in range(max_target_length):

                    auto_dec_out, auto_dec_hidden, auto_dec_attn = decoder(
                        auto_dec_inp, auto_dec_hidden, encoder_out)

                    loss_auto = loss_auto + F.cross_entropy(
                        auto_dec_out,
                        inds[:, sort_idx][t],
                        ignore_index=self.txt_dico.PAD_TOKEN)
                    auto_dec_inp = inds[:, sort_idx][t]

                loss_auto = loss_auto / lengths.float().sum()

                #######################################################
                # (2c) Decode z from real imgs and calc auto-encoding loss
                ######################################################

                real_img_out = nn.parallel.data_parallel(
                    image_encoder, (real_imgs[sort_idx]), self.gpus)

                real_img_feats, real_img_emb, real_img_code, real_img_mu, real_img_logvar = real_img_out

                noise.data.normal_(0, 1)
                inputs = (real_img_code, noise)
                _, fake_from_real_img, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                loss_img = criterionCycle(F.sigmoid(fake_from_real_img),
                                          F.sigmoid(real_imgs[sort_idx]))

                # loss_img = F.binary_cross_entropy_with_logits(fake_from_real_img.view(batch_size, -1),
                #                                               real_imgs.view(batch_size, -1))

                #######################################################
                # (2c) Decode z from fake imgs and calc cycle loss
                ######################################################

                fake_img_out = nn.parallel.data_parallel(
                    image_encoder, (real_imgs[sort_idx]), self.gpus)

                fake_img_feats, fake_img_emb, fake_img_code, fake_img_mu, fake_img_logvar = fake_img_out
                fake_img_feats = fake_img_feats.transpose(0, 1)

                loss_cd = 0
                cd_dec_inp = Variable(
                    torch.LongTensor([self.txt_dico.SOS_TOKEN] *
                                     self.batch_size))
                cd_dec_inp = cd_dec_inp.cuda() if cfg.CUDA else cd_dec_inp

                cd_dec_hidden = fake_img_code.unsqueeze(0)

                max_target_length = inds.size(0)

                for t in range(max_target_length):

                    cd_dec_out, cd_dec_hidden, cd_dec_attn = decoder(
                        cd_dec_inp, cd_dec_hidden, fake_img_feats)

                    loss_cd = loss_cd + F.cross_entropy(
                        cd_dec_out,
                        inds[:, sort_idx][t],
                        ignore_index=self.txt_dico.PAD_TOKEN)
                    cd_dec_inp = inds[:, sort_idx][t]

                loss_cd = loss_cd / lengths.float().sum()

                loss_dc = criterionCycle(fake_imgs, real_imgs[sort_idx])

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                enc_disc.zero_grad()

                errD = 0

                errD_im, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                     real_labels, fake_labels,
                                                     real_txt_mu, self.gpus)

                # updating discriminator for encoding
                txt_enc_labels = Variable(
                    torch.FloatTensor(batch_size).fill_(0))
                img_enc_labels = Variable(
                    torch.FloatTensor(batch_size).fill_(1))
                if cfg.CUDA:
                    txt_enc_labels = txt_enc_labels.cuda()
                    img_enc_labels = img_enc_labels.cuda()

                disc_real_txt_emb = encoder_hidden[0].detach()
                disc_real_img_emb = real_img_emb.detach()

                pred_txt = enc_disc(disc_real_txt_emb)
                pred_img = enc_disc(disc_real_img_emb)

                enc_disc_loss_txt = F.binary_cross_entropy_with_logits(
                    pred_txt.squeeze(), txt_enc_labels)
                enc_disc_loss_img = F.binary_cross_entropy_with_logits(
                    pred_img.squeeze(), img_enc_labels)

                errD = errD + errD_im + enc_disc_loss_txt + enc_disc_loss_img

                # check NaN
                if (errD != errD).data.any():
                    print("NaN detected (discriminator)")
                    pdb.set_trace()
                    exit()

                errD.backward()

                optimizerD.step()
                enc_disc_optimizer.step()

                ############################
                # (2) Update G network
                ###########################
                encoder.zero_grad()
                decoder.zero_grad()
                netG.zero_grad()
                image_encoder.zero_grad()

                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              real_txt_mu, self.gpus)

                img_kl_loss = KL_loss(real_img_mu, real_img_logvar)
                txt_kl_loss = KL_loss(real_txt_mu, real_txt_logvar)
                #f_img_kl_loss = KL_loss(fake_img_mu, fake_img_logvar)

                kl_loss = img_kl_loss + txt_kl_loss  #+ f_img_kl_loss

                #_, disc_hidden_g = encoder(inds[:, sort_idx], lens_sort.cpu().numpy(), None)
                #dg_mu, dg_logvar = nn.parallel.data_parallel(image_encoder, (real_imgs), self.gpus)
                #disc_img_g = torch.cat((dg_mu.unsqueeze(0), dg_logvar.unsqueeze(0)))

                pred_txt_g = enc_disc(encoder_hidden[0])
                pred_img_g = enc_disc(real_img_emb)

                enc_fake_loss_txt = F.binary_cross_entropy_with_logits(
                    pred_img_g.squeeze(), txt_enc_labels)
                enc_fake_loss_img = F.binary_cross_entropy_with_logits(
                    pred_txt_g.squeeze(), img_enc_labels)

                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + loss_cd + loss_dc + loss_img + loss_auto + enc_fake_loss_txt + enc_fake_loss_img

                # check NaN
                if (errG_total != errG_total).data.any():
                    print("NaN detected (generator)")
                    pdb.set_trace()
                    exit()

                errG_total.backward()

                optimizerG.step()
                image_enc_optimizer.step()
                enc_optimizer.step()
                dec_optimizer.step()

                count = count + 1
                if i % 100 == 0:
                    #                     summary_D = summary.scalar('D_loss', errD.data[0])
                    #                     summary_D_r = summary.scalar('D_loss_real', errD_real)
                    #                     #summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    #                     summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    #                     summary_G = summary.scalar('G_loss', errG.data[0])
                    #                     #summary_KL = summary.scalar('KL_loss', kl_loss.data[0])

                    #                     self.summary_writer.add_summary(summary_D, count)
                    #                     self.summary_writer.add_summary(summary_D_r, count)
                    #                     #self.summary_writer.add_summary(summary_D_w, count)
                    #                     self.summary_writer.add_summary(summary_D_f, count)
                    #                     self.summary_writer.add_summary(summary_G, count)
                    #                     #self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    inputs = (real_txt_code, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)

                    self.vis.images(normalize(
                        real_imgs[sort_idx].data.cpu().numpy()),
                                    win=self.vis_win1)
                    self.vis.images(normalize(fake_imgs.data.cpu().numpy()),
                                    win=self.vis_win2)
                    self.vis.text("\n*".join(captions), win=self.vis_txt1)
                    if (len(pred_cap)):
                        self.vis.images(normalize(
                            fake_from_fake_img.data.cpu().numpy()),
                                        win=self.vis_win3)

            end_t = time.time()

            prefix = "E%d/%s, %.1fs" % (epoch, time.strftime('D%d %X'),
                                        (end_t - start_t))
            gen_str = "G_all: %.3f Cy_T: %.3f AE_T: %.3f AE_I %.3f KL_T %.3f KL_I %.3f" % (
                errG_total.data[0], loss_cd.data[0], loss_auto.data[0],
                loss_img.data[0], txt_kl_loss.data[0], img_kl_loss.data[0])

            dis_str = "D_all: %.3f D_I: %.3f D_zT: %.3f D_zI: %.3f" % (
                errD.data[0], errD_im.data[0], enc_disc_loss_txt.data[0],
                enc_disc_loss_img.data[0])

            print("%s %s, %s" % (prefix, gen_str, dis_str))

            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, encoder, decoder, image_encoder, epoch,
                           self.model_dir)
        #
        save_model(netG, netD, encoder, decoder, image_encoder, self.max_epoch,
                   self.model_dir)
        #
        self.summary_writer.close()
Exemple #10
0
    def train(self, data_loader, stage=1):
        """Run the GAN training loop for the requested stage.

        Args:
            data_loader: iterable yielding (real_img, txt_embedding) batches.
            stage: 1 trains the stage-I networks, anything else stage-II.

        Side effects: logs scalars/images via self.summary_writer, writes
        image grids to self.image_dir and model snapshots to self.model_dir.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is only used for periodic visualization and must not
        # track gradients.  `volatile=True` was removed in PyTorch 0.4, and
        # this method already relies on >= 0.4 APIs (.item(), add_scalars),
        # so use requires_grad=False instead.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimize generator parameters that require gradients
        # (frozen sub-modules, if any, are excluded).
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                # KL term regularizes the conditioning-augmentation gaussian.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    self.summary_writer.add_scalars(main_tag="loss", tag_scalar_dict={
                        'D_loss':errD.cpu().item(),
                        'G_loss':errG_total.cpu().item()
                    }, global_step=count)
                    self.summary_writer.add_scalars(main_tag="D_loss", tag_scalar_dict={
                        "D_loss_real":errD_real,
                        "D_loss_wrong":errD_wrong,
                        "D_loss_fake":errD_fake
                    }, global_step=count)
                    self.summary_writer.add_scalars(main_tag="G_loss", tag_scalar_dict={
                        "G_loss":errG.cpu().item(),
                        "KL_loss":kl_loss.cpu().item()
                    }, global_step=count)

                    # Render samples from the fixed noise so image logs are
                    # comparable across steps.
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    # NOTE(review): `range=` was renamed `value_range=` in
                    # newer torchvision; keep `range=` for the pinned version.
                    self.summary_writer.add_image(tag="fake_image", 
                        img_tensor=vutils.make_grid(fake_imgs, normalize=True, range=(-1,1)),
                        global_step=count
                    )
                    self.summary_writer.add_image(tag="real_image", 
                        img_tensor=vutils.make_grid(real_img_cpu, normalize=True, range=(-1,1)),
                        global_step=count
                    )
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.cpu().item(), errG.cpu().item(), kl_loss.cpu().item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        # Final snapshot after the last epoch.
        save_model(netG, netD, self.max_epoch, self.model_dir)
Exemple #11
0
    def train(self, data_loader, stage=1):
        """Run the GAN training loop with an extra detection-consistency loss.

        Args:
            data_loader: iterable yielding (real_img, txt_embedding, caption)
                batches.
            stage: 1 trains the stage-I networks, anything else stage-II.

        Side effects: logs scalars via self.summary_writer, writes image
        grids to self.image_dir and model snapshots to self.model_dir.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM  # 100
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is only used for visualization; `volatile=True` was
        # removed in PyTorch 0.4 (this method already uses .item(), a >= 0.4
        # API), so mark it requires_grad=False instead.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Only optimize generator parameters that require gradients.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        detectron = Detectron()
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
            for i, data in enumerate(data_loader):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding, caption = data
                # caption arrives batched along axis 1; move batch to axis 0.
                caption = np.moveaxis(np.array(caption), 1, 0)

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                kl_loss = KL_loss(mu, logvar)

                # Run the object detector on the generated images.
                fake_img = fake_imgs.cpu().detach().numpy()
                det_obj_list = detectron.get_labels(fake_img)

                # Gate .cuda() on cfg.CUDA like every other tensor in this
                # method, so CPU-only runs do not crash here.
                fake_l = Variable(get_ohe(det_obj_list))
                real_l = Variable(get_ohe(caption))
                if cfg.CUDA:
                    fake_l = fake_l.cuda()
                    real_l = real_l.cuda()

                # NOTE(review): both fake_l and real_l are built from
                # detached data (the detector sees a numpy copy of fake_imgs),
                # so det_loss carries NO gradient back to netG — it only
                # shifts the logged loss value.  Making it trainable would
                # require a differentiable detection head; confirm intent.
                det_loss = nn.SmoothL1Loss()(fake_l, real_l)
                errG_total = det_loss + errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    summary_DET = summary.scalar('det_loss', det_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    self.summary_writer.add_summary(summary_DET, count)

                    # Render samples from the fixed noise so image logs are
                    # comparable across steps.
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        # Final snapshot after the last epoch.
        save_model(netG, netD, self.max_epoch, self.model_dir)
        # Flush pending summaries.
        self.summary_writer.close()
    def train(self, data_loader, stage=1):
        """Run the GAN training loop with VGG perceptual and texture losses.

        Args:
            data_loader: iterable yielding (real_img, txt_embedding) batches.
            stage: 1 trains the stage-I networks, anything else stage-II.

        Side effects: logs scalars via self.summary_writer, writes image
        grids to self.image_dir and model snapshots to self.model_dir.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # fixed_noise is only used for visualization.  NOTE(review): this
        # method uses the pre-0.4 `.data[0]` API throughout, so `volatile`
        # is kept for version consistency; on PyTorch >= 0.4 it should
        # become requires_grad=False / torch.no_grad().
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        # Only optimize generator parameters that require gradients.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        if cfg.TRAIN.ADAM:
            optimizerD = \
                optim.Adam(netD.parameters(),
                           lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
            optimizerG = optim.Adam(netG_para,
                                    lr=cfg.TRAIN.GENERATOR_LR,
                                    betas=(0.5, 0.999))
        else:
            optimizerD = \
                optim.RMSprop(netD.parameters(),
                           lr=cfg.TRAIN.DISCRIMINATOR_LR)
            optimizerG = \
                optim.RMSprop(netG_para,
                                    lr=cfg.TRAIN.GENERATOR_LR)

        # Truncated pretrained VGG-19 used as a fixed feature extractor for
        # the perceptual (activation) and texture (Gram) losses.
        cnn = models.vgg19(pretrained=True).features
        cnn = nn.Sequential(*list(cnn.children())[0:28])
        gram = GramMatrix()
        if cfg.CUDA:
            cnn.cuda()
            gram.cuda()
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                if cfg.CUDA:
                    _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                else:
                    _, fake_imgs, mu, logvar = netG(txt_embedding, noise)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus, cfg.CUDA)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus, cfg.CUDA)
                kl_loss = KL_loss(mu, logvar)
                pixel_loss = PIXEL_loss(real_imgs, fake_imgs)
                # FIX: the CUDA branch previously called cnn on
                # fake_imgs.detach(), which blocked all gradient flow from
                # the activation/texture losses back to netG, while the CPU
                # branch did not detach.  Both branches now keep fake_imgs
                # in the graph so those losses actually train the generator.
                if cfg.CUDA:
                    fake_features = nn.parallel.data_parallel(
                        cnn, fake_imgs, self.gpus)
                    real_features = nn.parallel.data_parallel(
                        cnn, real_imgs, self.gpus)
                else:
                    fake_features = cnn(fake_imgs)
                    real_features = cnn(real_imgs)
                active_loss = ACT_loss(fake_features, real_features)
                text_loss = TEXT_loss(gram, fake_features, real_features,
                                      cfg.TRAIN.COEFF.TEXT)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + \
                                pixel_loss * cfg.TRAIN.COEFF.PIX + \
                                active_loss * cfg.TRAIN.COEFF.ACT +\
                                text_loss
                errG_total.backward()
                optimizerG.step()
                count = count + 1
                if i % 100 == 0:

                    summary_D = summary.scalar('D_loss', errD.data[0])
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    summary_Pix = summary.scalar('Pixel_loss',
                                                 pixel_loss.data[0])
                    summary_Act = summary.scalar('Act_loss',
                                                 active_loss.data[0])
                    summary_Text = summary.scalar('Text_loss',
                                                  text_loss.data[0])

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    self.summary_writer.add_summary(summary_Pix, count)
                    self.summary_writer.add_summary(summary_Act, count)
                    self.summary_writer.add_summary(summary_Text, count)

                    # Render samples from the fixed noise so image logs are
                    # comparable across steps.
                    inputs = (txt_embedding, fixed_noise)
                    if cfg.CUDA:
                        lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                    else:
                        lr_fake, fake, _, _ = netG(txt_embedding, fixed_noise)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_Pixel: %.4f
                                     Loss_Activ: %.4f Loss_Text: %.4f
                                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                                     Total Time: %.2fsec
                                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.data[0],
                 errG.data[0], kl_loss.data[0], pixel_loss.data[0],
                 active_loss.data[0], text_loss.data[0], errD_real, errD_wrong,
                 errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        # Final snapshot after the last epoch.
        save_model(netG, netD, self.max_epoch, self.model_dir)
        # Flush pending summaries.
        self.summary_writer.close()
Exemple #13
0
    def train(self, data_loader, stage=1):
        """Train the StackGAN networks for ``self.max_epoch`` epochs.

        Alternates one discriminator update and one generator update per
        batch, halves both learning rates every ``cfg.TRAIN.LR_DECAY_EPOCH``
        epochs, logs scalar summaries and sample images every 100 batches,
        and snapshots the models every ``self.snapshot_interval`` epochs.

        Args:
            data_loader: iterable yielding ``(real_images, text_embedding)``
                batches.
            stage: 1 or 2 — selects which stage's networks to load.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise is sampled once and reused for every saved image grid,
        # so samples are visually comparable across epochs.  A freshly
        # created tensor never requires grad, so no no_grad() wrapper is
        # needed here.
        fixed_noise = torch.FloatTensor(batch_size, nz).normal_(0, 1)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        # Optimize only generator parameters that require gradients
        # (frozen sub-modules are excluded).
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Step-decay both learning rates by half.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            num_batches = len(data_loader)
            print('Number of batches: ' + str(num_batches))
            for i, data in enumerate(data_loader, 0):
                print('Epoch number: ' + str(epoch) + '\tBatches: ' + str(i) + '/' + str(num_batches), end='\r')
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()
                #######################################################
                # (2) Generate fake images
                ######################################################
                # Resample the noise in place each batch.
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                # KL term regularizes the conditioning-augmentation
                # distribution (mu, logvar) toward a standard normal.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    # .item() extracts a Python float from the 0-dim loss
                    # tensors; passing raw tensors (the old `.data`) to the
                    # summary writer is deprecated and inconsistent with the
                    # rest of this file.
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # Save sample images from the fixed noise vector.  This
                    # forward pass is evaluation-only, so skip building the
                    # autograd graph.
                    with torch.no_grad():
                        inputs = (txt_embedding, fixed_noise)
                        lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            print('################EPOCH COMPLETED###########')
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)

        # Final snapshot after the last epoch.
        save_model(netG, netD, self.max_epoch, self.model_dir)
        self.summary_writer.close()
Example #14
0
    def train(self, data_loader, stage=1, max_objects=3):
        """Train the object-aware StackGAN networks.

        Compared to the plain trainer, each batch additionally carries
        per-object bounding boxes and class labels; affine transformation
        matrices derived from the boxes are fed to both networks.  Supports
        resuming from a checkpoint (``cfg.NET_G``) that stores both
        optimizer states and the last finished epoch.

        Args:
            data_loader: iterable yielding
                ``(real_images, bbox, label, text_embedding)`` batches.
            stage: 1 or 2 — selects which stage's networks to load.
            max_objects: maximum number of objects per image; bbox/label
                tensors are reshaped against this count.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # with torch.no_grad():
        # Fixed noise is sampled once so saved sample grids are comparable
        # across epochs; it never needs gradients.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Optimize only generator parameters that require gradients.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        ####
        # Resume optimizer state and epoch counter from a checkpoint, if
        # configured.  startpoint stays -1 (train from scratch) otherwise.
        startpoint = -1
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            optimizerD.load_state_dict(state_dict["optimD"])
            optimizerG.load_state_dict(state_dict["optimG"])
            startpoint = state_dict["epoch"]
            print(startpoint)
            print('Load Optim and optimizers as : ', cfg.NET_G)
        ####

        count = 0
        drive_count = 0
        for epoch in range(startpoint + 1, self.max_epoch):
            print('epoch : ', epoch, ' drive_count : ', drive_count)
            epoch_start_time = time.time()
            print(epoch)
            start_t = time.time()
            start_t500 = time.time()
            # Step-decay both learning rates by half.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            time_to_i = time.time()
            for i, data in enumerate(data_loader, 0):
                # if i >= 3360 :
                #     print ('Last Batches : ' , i)
                # if i < 10 :
                #     print ('first Batches : ' , i)
                # if i == 0 :
                #     print ('Startig! Batch ',i,'from total of 2070' )
                # if i % 10 == 0 and i!=0:
                #     end_t500 = time.time()
                #     print ('Batch Number : ' , i ,' |||||  Toatal Time : ' , (end_t500 - start_t500))
                #     start_t500 = time.time()
                ######################################################
                # (1) Prepare training data
                # if i < 10 :
                #     print (" (1) Prepare training data for batch : " , i)
                ######################################################
                #print ("Prepare training data for batch : " , i)
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    # Stage 2 carries two bbox sets (one per resolution).
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                # Build per-object affine transformation matrices (and their
                # inverses) from the bounding boxes, shaped
                # (batch, max_objects, 2, 3) for grid sampling.
                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(bbox)
                    transf_matrices = transf_matrices.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting
                # (padding slots use class index 80, giving 81 classes total).
                _labels[_labels < 0] = 80
                if cfg.CUDA:
                    label_one_hot = torch.cuda.FloatTensor(
                        noise.shape[0], max_objects, 81).fill_(0)
                else:
                    label_one_hot = torch.FloatTensor(noise.shape[0],
                                                      max_objects, 81).fill_(0)
                # NOTE(review): scatter_ along dim 2 requires _labels to be
                # 3-D, i.e. shaped (batch, max_objects, 1) — presumably the
                # loader provides it that way; confirm against the dataset.
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

                #######################################################
                # # (2) Generate fake images
                # if i < 10 :
                #     print ("(2)Generate fake images")
                ######################################################

                # Resample the noise in place each batch.
                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                if cfg.CUDA:
                    _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)
                else:
                    # NOTE(review): leftover debug print on the CPU path.
                    print('Hiiiiiiiiiiii')
                    _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise,
                                                       transf_matrices_inv,
                                                       label_one_hot)
                # _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise, transf_matrices_inv, label_one_hot)

                ############################
                # # (3) Update D network
                # if i < 10 :
                #     print("(3) Update D network")
                ###########################
                netD.zero_grad()

                if cfg.STAGE == 1:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices, transf_matrices_inv,
                                                   mu, self.gpus)
                elif cfg.STAGE == 2:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices_s2, transf_matrices_inv_s2,
                                                   mu, self.gpus)
                # retain_graph=True: the generator update below backprops
                # through the same graph that produced fake_imgs, so it must
                # not be freed by this backward pass.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # # (4) Update G network
                # if i < 10 :
                #     print ("(4) Update G network")
                ###########################
                netG.zero_grad()
                # if i < 10 :
                #     print ("netG.zero_grad")
                if cfg.STAGE == 1:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices,
                                                  transf_matrices_inv, mu,
                                                  self.gpus)
                elif cfg.STAGE == 2:
                    # if i < 10 :
                    #     print ("cgf.STAGE = " , cfg.STAGE)
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices_s2,
                                                  transf_matrices_inv_s2, mu,
                                                  self.gpus)
                    # if i < 10 :
                    #     print("errG : ",errG)
                # KL term regularizes the conditioning-augmentation
                # distribution (mu, logvar) toward a standard normal.
                kl_loss = KL_loss(mu, logvar)
                # if i < 10 :
                #     print ("kl_loss = " , kl_loss)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                # if i < 10 :
                #     print (" errG_total = " , errG_total )
                errG_total.backward()
                # if i < 10 :
                #     print ("errG_total.backward() ")
                optimizerG.step()
                # if i < 10 :
                #     print ("optimizerG.step() " )

                #print (" i % 500 == 0 :  " , i % 500 == 0 )
                end_t = time.time()
                #print ("batch time : " , (end_t - start_t))
                # Every 500 batches: build fresh scalar summaries, log them,
                # and save sample images.
                if i % 500 == 0:
                    #print (" i % 500 == 0" , i % 500 == 0 )
                    count += 1
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    print('epoch     :  ', epoch)
                    print('count     :  ', count)
                    print('  i       :  ', i)
                    print('Time to i : ', time.time() - time_to_i)
                    time_to_i = time.time()
                    print('D_loss : ', errD.item())
                    print('D_loss_real : ', errD_real)
                    print('D_loss_wrong : ', errD_wrong)
                    print('D_loss_fake : ', errD_fake)
                    print('G_loss : ', errG.item())
                    print('KL_loss : ', kl_loss.item())
                    print('generator_lr : ', generator_lr)
                    print('discriminator_lr : ', discriminator_lr)
                    print('lr_decay_step : ', lr_decay_step)

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    with torch.no_grad():
                        if cfg.STAGE == 1:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, label_one_hot)
                        elif cfg.STAGE == 2:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, transf_matrices_s2,
                                      transf_matrices_inv_s2, label_one_hot)

                        if cfg.CUDA:
                            lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                                netG, inputs, self.gpus)
                        else:
                            lr_fake, fake, _, _, _ = netG(
                                txt_embedding, noise, transf_matrices_inv,
                                label_one_hot)

                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
                # NOTE(review): this branch re-logs whichever summary_* objects
                # were last built in the `i % 500` branch above — at i = 100,
                # 200, ... those values are stale (from the previous multiple
                # of 500). Confirm whether that is intentional.
                if i % 100 == 0:
                    drive_count += 1
                    self.drive_summary_writer.add_summary(
                        summary_D, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_r, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_w, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_f, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_G, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_KL, drive_count)

            #print (" with torch.no_grad(): "  )
            # End-of-epoch sample images, generated from the last batch's
            # conditioning inputs without building an autograd graph.
            with torch.no_grad():
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    #print (" cfg.STAGE == 2: " , cfg.STAGE == 2 )
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                    #print (" inputs " , inputs )
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                #print (" lr_fake, fake " , lr_fake, fake )
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                #print (" save_img_results(real_img_cpu, fake, epoch, self.image_dir) " , )

                #print (" lr_fake is not None: " , lr_fake is not None )
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
                    #print (" save_img_results(None, lr_fake, epoch, self.image_dir) " )
                    #end_t = time.time()
                    #print ("batch time : " , (end_t - start_t))
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)

            print("keyTime |||||||||||||||||||||||||||||||")
            print("epoch_time : ", time.time() - epoch_start_time)
            print("KeyTime |||||||||||||||||||||||||||||||")

        #
        # Final snapshot after the last epoch (includes optimizer states for
        # resuming).
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()