Example #1
    def train_Gnet(self, count):
        self.netG.zero_grad()
        for myit in range(len(self.netsD)):
            self.netsD[myit].zero_grad()

        errG_total = 0
        flag = count % 100
        batch_size = self.real_fimgs[0].size(0)
        criterion_one, criterion_class, c_code, p_code = (
            self.criterion_one, self.criterion_class, self.c_code, self.p_code)

        for i in range(self.num_Ds):

            outputs = self.netsD[i](self.fake_imgs[i])

            # real/fake loss for background (0) and child (2) stage
            if i == 0 or i == 2:
                real_labels = torch.ones_like(outputs[1])
                errG = criterion_one(outputs[1], real_labels)
                if i == 0:
                    errG = errG * cfg.TRAIN.BG_LOSS_WT
                    # Background/Foreground classification loss for
                    # the fake background image (on patch level)
                    errG_classi = criterion_one(outputs[0], real_labels)
                    errG = errG + errG_classi
                errG_total = errG_total + errG

            if i == 1:  # Mutual information loss for the parent stage (1)
                pred_p = self.netsD[i](self.fg_mk[i-1])
                errG_info = criterion_class(
                    pred_p[0], torch.nonzero(p_code.long())[:, 1])
            elif i == 2:  # Mutual information loss for the child stage (2)
                pred_c = self.netsD[i](self.fg_mk[i-1])
                errG_info = criterion_class(
                    pred_c[0], torch.nonzero(c_code.long())[:, 1])

            if i > 0:
                errG_total = errG_total + errG_info

            
            if flag == 0:
                if i > 0:
                    summary_D_class = summary.scalar(
                        'Information_loss_%d' % i, errG_info.item())
                    self.summary_writer.add_summary(summary_D_class, count)

                if i == 0 or i == 2:
                    summary_D = summary.scalar('G_loss%d' % i, errG.item())
                    self.summary_writer.add_summary(summary_D, count)
        errG_total.backward()
        for myit in range(len(self.netsD)):
            self.optimizerG[myit].step()
        return errG_total
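
All of the examples on this page log scalars through the old standalone tensorboard package (summary.scalar plus FileWriter.add_summary). As a minimal sketch, assuming you instead use the torch.utils.tensorboard API bundled with current PyTorch, the equivalent logging would look like this (tag, value, and step are illustrative):

from torch.utils.tensorboard import SummaryWriter

# Minimal sketch: SummaryWriter plays the role of FileWriter; 0.5 stands in
# for a value such as errG_total.item(), and 100 for the step counter.
writer = SummaryWriter(log_dir="./logs")
writer.add_scalar("G_loss_total", 0.5, global_step=100)
writer.close()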
Example #2
    def train_Gnet(self, count):
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                    outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, global_step=count)

                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, global_step=count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)

                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total
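
KL_loss is called above but not defined in these snippets. A minimal sketch, assuming the usual conditioning-augmentation/VAE formulation (Gaussian posterior regularized toward a standard normal prior):

import torch

def KL_loss(mu, logvar):
    # Sketch (assumption): KL(N(mu, sigma^2) || N(0, 1)),
    # averaged over batch and latent dimensions.
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return kld.mean()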
Example #3
def main(exercise: str = "Aufgabe-2"):

    with open("./config.{}.json".format(exercise), mode="r") as f:
        args = json.load(f,
                         object_hook=lambda d: namedtuple("X", d.keys())
                         (*d.values()))

    environment = gym.make(args.env)

    input_shape = environment.observation_space.shape
    num_actions = environment.action_space.n

    output_directory = "./tmp/{}/{}".format(exercise, datetime.datetime.now())

    writer = FileWriter(output_directory)

    agent = make_agent(args, input_shape, num_actions, output_directory)

    rewards = []
    for episode in range(args.episodes):
        episode_rewards = run_episode(
            environment,
            agent,
            render=episode % args.render_episode_interval == 0,
            max_length=args.max_episode_length,
        )
        rewards.append(episode_rewards)
        if episode % args.training_interval == 0:
            for _ in range(args.training_interval):
                loss = agent.train()

            if loss and episode % (args.training_interval * 10) == 0:
                mean_rewards = np.mean(rewards)
                std_rewards = np.std(rewards)
                writer.add_summary(summary=summary.scalar("dqn/loss", loss),
                                   global_step=episode)
                writer.add_summary(
                    summary=summary.scalar("rewards/mean", mean_rewards),
                    global_step=episode,
                )
                writer.add_summary(
                    summary=summary.scalar("rewards/standard deviation",
                                           std_rewards),
                    global_step=episode,
                )
                writer.add_summary(
                    summary=summary.scalar("dqn/epsilon",
                                           agent.exploration_strategy.epsilon),
                    global_step=episode,
                )
                print("Episode {}\tMean rewards {:f}\tLoss {:f}\tEpsilon {:f}".
                      format(episode, mean_rewards, loss,
                             agent.exploration_strategy.epsilon))
                rewards.clear()
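
run_episode and make_agent are defined elsewhere. A plausible sketch of run_episode against the classic Gym step API (the agent's act/observe hooks are assumptions):

def run_episode(environment, agent, render=False, max_length=1000):
    # Sketch: roll out one episode and return its summed reward.
    # Assumes the pre-0.26 Gym API, where step() returns
    # (observation, reward, done, info).
    observation = environment.reset()
    total_reward = 0.0
    for _ in range(max_length):
        if render:
            environment.render()
        action = agent.act(observation)
        next_observation, reward, done, _ = environment.step(action)
        agent.observe(observation, action, reward, next_observation, done)
        observation = next_observation
        total_reward += reward
        if done:
            break
    return total_reward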
Example #4
    def train_Gnet(self, count):
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]

        for i in range(self.num_Ds):
            netD = self.netsD[i]
            outputs = netD(self.fake_imgs[i])
            errG = criterion(outputs[0], real_labels)
            # errG = self.stage_coeff[i] * errG
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_G, count)

        # Compute color preserve losses
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1

            if flag == 0:
                if self.num_Ds > 1:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, count)
                if self.num_Ds > 2:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        errG_total.backward()
        self.optimizerG.step()
        return errG_total
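
compute_mean_covariance, which drives the color-consistency terms in Examples #2 and #4, is also external. A sketch of what it plausibly computes, the per-image channel mean and channel-by-channel covariance over pixels:

import torch

def compute_mean_covariance(img):
    # Sketch (assumption): img is B x C x H x W. Returns the channel mean
    # (B x C x 1 x 1) and the per-image channel covariance (B x C x C).
    batch_size, channels, height, width = img.size()
    num_pixels = height * width
    mu = img.mean(dim=(2, 3), keepdim=True)
    centered = (img - mu).view(batch_size, channels, num_pixels)
    covariance = torch.bmm(centered, centered.transpose(1, 2)) / num_pixels
    return mu, covariance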
Example #5
def write_loss(iterations, max_iterations, trainer, train_writer,
               elapsed_time):
    print("Iteration: %08d/%08d %.2fs" %
          (iterations + 1, max_iterations, elapsed_time))
    # Log every numeric, non-dunder attribute whose name contains
    # 'loss' or 'acc'.
    members = [attr for attr in dir(trainer)
               if not callable(getattr(trainer, attr))
               and not attr.startswith("__")
               and ('loss' in attr or 'acc' in attr)]
    for m in members:
        train_writer.add_summary(summary.scalar(m, getattr(trainer, m)),
                                 iterations + 1)
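
As a usage sketch, write_loss works against any object exposing numeric loss/accuracy attributes under this naming convention (the trainer stub below is hypothetical; the old-style imports from Example #3, from tensorboard import summary and FileWriter, are assumed):

class DummyTrainer:
    # Hypothetical stub matching the attribute convention write_loss scans for.
    def __init__(self):
        self.gen_loss = 0.42
        self.dis_loss = 1.37
        self.train_acc = 0.91

train_writer = FileWriter("./logs")
write_loss(iterations=0, max_iterations=1000, trainer=DummyTrainer(),
           train_writer=train_writer, elapsed_time=0.5)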
Example #6
    def train_Gnet(self, count):
        errG_total = 0
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion_1, mu = self.criterion_1, self.mu_theta1

        params = [
            self.shape1, self.scale1, self.Phi1, self.theta1, self.txtbow
        ]
        criterion, optEnG, netG = self.criterion_2, self.optimizerEnG, self.netG
        loss, theta1_KL, Likelihood, p1, p2, p3, shape1, scale1 = criterion(
            params)

        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion_1(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                             criterion_1(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        errG_total += loss
        optEnG.zero_grad()

        errG_total.backward()

        optEnG.step()
        return (float(errG_total), theta1_KL, Likelihood, loss, p1, p2, p3,
                shape1, scale1)
Example #7
    def train_Gnet(self, idx, count):
        optG = self.optimizersG[idx]
        optG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion, c_code = self.criterion, self.c_code[idx]
        real_labels = self.real_labels[:batch_size]
        for i in range(len(self.netsG)):
            outputs = self.netsD[idx * 3 + i](self.fake_imgs[idx * 3 + i],
                                              c_code)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        errG_total.backward()
        optG.step()
        return float(errG_total)
Example #8
    def train_Dnet(self, idx, count):
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        #
        netD.zero_grad()
        #
        real_logits = netD(real_imgs)
        fake_logits = netD(fake_imgs.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        #
        errD = errD_real + errD_fake
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD
Example #9
    def train_MDnet(self, count):
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        real_imgs = self.real_imgs[-1]
        wrong_imgs = self.wrong_imgs[-1]
        fake_imgs = self.fake_imgs[-1]
        similar_imgs = self.similar_imgs
        #
        netMD = self.netMD
        optMD = self.optimizerMD
        netMD.zero_grad()

        same_labels = self.same_labels[:batch_size]
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        wrong_labels = self.wrong_labels[:batch_size]

        real_feat = self.image_cnn(real_imgs.detach())
        real_feat = self.image_encoder(real_feat.detach())
        similar_feat = self.image_cnn(similar_imgs.detach())
        similar_feat = self.image_encoder(similar_feat.detach())
        fake_feat = self.image_cnn(fake_imgs.detach())
        fake_feat = self.image_encoder(fake_feat.detach())
        wrong_feat = self.image_cnn(wrong_imgs.detach())
        wrong_feat = self.image_encoder(wrong_feat.detach())

        same_logits = netMD(real_feat, real_feat)
        real_logits2 = netMD(real_feat, similar_feat)
        fake_logits2 = netMD(real_feat, fake_feat.detach())
        wrong_logits2 = netMD(real_feat, wrong_feat)

        errMD_si = cfg.TRAIN.COEFF.MD_LOSS * nn.CrossEntropyLoss()(
            real_logits2, real_labels.long())
        errMD_sa = cfg.TRAIN.COEFF.MD_LOSS * nn.CrossEntropyLoss()(
            same_logits, same_labels.long())
        errMD_fa = cfg.TRAIN.COEFF.MD_LOSS * nn.CrossEntropyLoss()(
            fake_logits2, fake_labels.long())
        errMD_wr = cfg.TRAIN.COEFF.MD_LOSS * nn.CrossEntropyLoss()(
            wrong_logits2, wrong_labels.long())
        if cfg.DATASET_NAME == 'birds' or cfg.DATASET_NAME == 'flowers':
            errMD = errMD_si + errMD_sa + errMD_fa + errMD_wr
        else:
            errMD = errMD_si + errMD_fa + errMD_wr

        # backward
        errMD.backward()
        optMD.step()
        # log
        if flag == 0:
            summary_MD = summary.scalar('MD_loss', errMD.item())
            self.summary_writer.add_summary(summary_MD, count)
        return errMD
Example #10
    def train_Dnet(self, idx, count):
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        #
        netD.zero_grad()
        # Forward
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # for real
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            #
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            #
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD
Example #11
    def train_Enet(self, count):
        errEn_total = 0
        flag = count % 100
        params = [
            self.shape_1, self.scale_1, self.shape_2, self.scale_2,
            self.shape_3, self.scale_3, self.Phi[0], self.theta_1, self.Phi[1],
            self.theta_2, self.Phi[2], self.theta_3, self.txtbow
        ]
        criterion, optEn, netEn = self.vae, self.optimizerEn, self.netEn
        loss, theta3_KL, theta2_KL, theta1_KL, Likelihood, Lowerbound = \
            criterion(params)
        errEn_total = errEn_total + loss
        netEn.zero_grad()
        # backward
        errEn_total.backward()
        # update parameters
        optEn.step()
        if flag == 0:
            summary_LS = summary.scalar('En_loss', loss.item())
            summary_LB = summary.scalar('En_lowerbound', Lowerbound.item())
            summary_LL = summary.scalar('En_likelihood', Likelihood.item())
            summary_KL1 = summary.scalar('En_kl1', theta1_KL.item())
            summary_KL2 = summary.scalar('En_kl2', theta2_KL.item())
            summary_KL3 = summary.scalar('En_kl3', theta3_KL.item())
            self.summary_writer.add_summary(summary_LS, count)
            self.summary_writer.add_summary(summary_LB, count)
            self.summary_writer.add_summary(summary_LL, count)
            self.summary_writer.add_summary(summary_KL1, count)
            self.summary_writer.add_summary(summary_KL2, count)
            self.summary_writer.add_summary(summary_KL3, count)
        return theta1_KL, theta2_KL, theta3_KL, Likelihood, Lowerbound, loss
Example #12
    def train(self, data_loader, stage=1):
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        with torch.no_grad():
            fixed_noise = Variable(
                torch.FloatTensor(batch_size, nz).normal_(0, 1))
        real_labels = torch.FloatTensor([1])
        fake_labels = real_labels * -1
        wrong_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        # real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        # fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels = real_labels.cuda()
            fake_labels = fake_labels.cuda()
            wrong_labels = wrong_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = optim.RMSprop(netD.parameters(), lr=discriminator_lr)
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.RMSprop(netG_para, lr=generator_lr)
        # optimizerD = \
        #     optim.Adam(netD.parameters(),
        #                lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # netG_para = []
        # for p in netG.parameters():
        #     if p.requires_grad:
        #         netG_para.append(p)
        # optimizerG = optim.Adam(netG_para,
        #                         lr=cfg.TRAIN.GENERATOR_LR,
        #                         betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                for p in netD.parameters():
                    p.data.clamp_(-0.01, 0.01)
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               wrong_labels, mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward(retain_graph=True)
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.item(), errG.item(), kl_loss.item(),
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
Example #13
    def train(self):
        self.netG, self.netsD, self.netIMG, self.inception_model,\
        self.num_Ds, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizersD, self.optimizerEnG = \
            define_optimizers(self.netG, self.netsD, self.netIMG)

        self.criterion_1 = nn.BCELoss()
        self.criterion_2 = myLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        # Prepare PHI
        real_min = np.float64(2.2e-308)
        Phi1 = 0.2 + 0.8 * np.float64(np.random.rand(1000, 256))
        Phi1 = Phi1 / np.maximum(real_min, Phi1.sum(0))
        if cfg.CUDA:
            self.Phi1 = Variable(torch.from_numpy(Phi1).float()).cuda()
            self.criterion_1.cuda()
            self.criterion_2.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches

        batch_length = self.num_batches
        self.NDot = 0
        self.ForgetRate = np.power(
            (0 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                             cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        self.eta = 0.1
        epsit = np.power(
            (20 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                              cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        self.epsit = 1 * epsit / epsit[0]

        num_total_samples = batch_length * self.batch_size

        start_t = time.time()

        for epoch in range(start_epoch, self.max_epoch):
            LL = 0
            KL = 0
            LS = 0
            p1 = 0
            p2 = 0
            p3 = 0
            DL = 0
            GL = 0
            shape = []
            scale = []
            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.img_tcpu, self.txtbow, self.real_tgpu, self.wrong_tgpu = self.prepare_data(
                    data)

                #######################################################
                # (1) Get conv hidden units
                ######################################################
                _, self.flat = self.inception_model(self.real_tgpu[-1])

                #######################################################
                # (2) Get shape, scale and sample of theta
                ######################################################
                self.theta1, self.shape1, self.scale1 = self.netIMG(self.flat)

                shape1 = self.shape1.detach().cpu().numpy()
                scale1 = self.scale1.detach().cpu().numpy()
                mu_theta1 = scale1 * ss.gamma(1 + 1 / shape1)
                self.mu_theta1 = torch.tensor(mu_theta1, dtype=torch.float32)
                #######################################################
                # (3) Generate fake images
                ######################################################
                self.fake_imgs, _ = self.netG(self.mu_theta1.detach())

                #######################################################
                # (4) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (5) Update G network (or En network): maximize log(D(G(z)))
                ######################################################
                errG_total, self.KL1, self.LL, self.LS, self.p1, self.p2, self.p3, shape1, scale1 = self.train_Gnet(
                    count)
                LL += self.LL
                KL += self.KL1
                LS += self.LS
                p1 += self.p1
                p2 += self.p2
                p3 += self.p3
                shape.append(shape1)
                scale.append(scale1)

                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                DL += errD_total
                GL += errG_total

                #######################################################
                # (6) Update Phi
                #######################################################
                input_txt = np.array(np.transpose(self.txtbow.cpu().numpy()),
                                     order='C').astype('double')
                Phi1 = np.array(self.Phi1.cpu().numpy(),
                                order='C').astype('double')
                Theta1 = np.array(np.transpose(self.theta1.cpu().numpy()),
                                  order='C').astype('double')
                phi1, self.NDot = self.updatePhi(input_txt, Phi1, Theta1,
                                                 int(batch_length), count,
                                                 self.NDot)
                self.Phi1 = torch.tensor(phi1, dtype=torch.float32).cuda()

                # for inception score
                pred, _ = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', float(errD_total))
                    summary_G = summary.scalar('G_loss', float(errG_total))
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netIMG, self.netG, avg_param_G, self.netsD,
                               epoch, count, self.model_dir)
                    # Save images
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)

                    self.fake_imgs, _ = \
                        self.netG(self.theta1)

                    save_img_results(self.img_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []
                count += 1

            end_t = time.time()
            LS = LS / num_total_samples
            LL = LL / num_total_samples
            KL = KL / num_total_samples
            DL = DL / num_total_samples
            GL = GL / num_total_samples
            print(
                'Epoch: %d/%d,   Time elapsed: %.4fs\n'
                '* Batch Train Loss: %.6f          (LL: %.6f, KL: %.6f, Loss_D:'
                '%.2f Loss_G: %.2f)\n' %
                (epoch, self.max_epoch, end_t - start_t, LS, LL, KL, DL, GL))
            start_t = time.time()
            if epoch % 50 == 0:
                save_model(self.netIMG, self.netG, avg_param_G, self.netsD,
                           epoch, count, self.model_dir)
        # save the model at the last updating
        save_model(self.netIMG, self.netG, avg_param_G, self.netsD, epoch,
                   count, self.model_dir)
        self.summary_writer.close()
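
compute_inception_score is likewise not shown. Assuming predictions is an N x num_classes array of softmax outputs, the standard Inception Score, the exponential of the mean KL divergence between p(y|x) and the marginal p(y), computed over splits, can be sketched as:

import numpy as np

def compute_inception_score(predictions, num_splits=10):
    # Sketch (assumption): mean and std of exp(E_x KL(p(y|x) || p(y)))
    # over num_splits equal chunks of the prediction matrix.
    scores = []
    for part in np.array_split(predictions, num_splits):
        py = part.mean(axis=0, keepdims=True)
        kl = part * (np.log(part + 1e-12) - np.log(py + 1e-12))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))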
Example #14
    def train_Dnet(self, idx, count):
        # The discriminator is trained only in the background (0) and child
        # (2) stages, not in the parent stage.
        if idx == 0 or idx == 2:
            flag = count % 100
            batch_size = self.real_fimgs[0].size(0)
            criterion, criterion_one = self.criterion, self.criterion_one

            netD, optD = self.netsD[idx], self.optimizersD[idx]
            if idx == 0:
                real_imgs = self.real_fimgs[0]

            elif idx == 2:
                real_imgs = self.real_cimgs[0]

            fake_imgs = self.fake_imgs[idx]
            netD.zero_grad()
            real_logits = netD(real_imgs)

            if idx == 2:
                fake_labels = torch.zeros_like(real_logits[1])
                real_labels = torch.ones_like(real_logits[1])
            elif idx == 0:

                fake_labels = torch.zeros_like(real_logits[1])
                ext, output = real_logits
                weights_real = torch.ones_like(output)
                real_labels = torch.ones_like(output)

                for i in range(batch_size):
                    x1 = self.warped_bbox[0][i]
                    x2 = self.warped_bbox[2][i]
                    y1 = self.warped_bbox[1][i]
                    y2 = self.warped_bbox[3][i]

                    a1 = max(
                        torch.tensor(0).float().cuda(),
                        torch.ceil((x1 - self.recp_field) / self.patch_stride))
                    a2 = min(
                        torch.tensor(self.n_out - 1).float().cuda(),
                        torch.floor((self.n_out - 1) -
                                    ((126 - self.recp_field) - x2) /
                                    self.patch_stride)) + 1
                    b1 = max(
                        torch.tensor(0).float().cuda(),
                        torch.ceil((y1 - self.recp_field) / self.patch_stride))
                    b2 = min(
                        torch.tensor(self.n_out - 1).float().cuda(),
                        torch.floor((self.n_out - 1) -
                                    ((126 - self.recp_field) - y2) /
                                    self.patch_stride)) + 1

                    if (x1 != x2 and y1 != y2):
                        weights_real[
                            i, :,
                            a1.type(torch.int):a2.type(torch.int),
                            b1.type(torch.int):b2.type(torch.int)] = 0.0

                norm_fact_real = weights_real.sum()
                norm_fact_fake = weights_real.shape[0] * weights_real.shape[
                    1] * weights_real.shape[2] * weights_real.shape[3]
                real_logits = ext, output

            fake_logits = netD(fake_imgs.detach())

            if idx == 0:  # Background stage

                errD_real_uncond = criterion(
                    real_logits[1], real_labels
                )  # Real/Fake loss for 'real background' (on patch level)
                errD_real_uncond = torch.mul(
                    errD_real_uncond, weights_real
                )  # Mask output units whose receptive fields lie within the bounding box
                errD_real_uncond = errD_real_uncond.mean()

                errD_real_uncond_classi = criterion(
                    real_logits[0],
                    weights_real)  # Background/foreground classification loss
                errD_real_uncond_classi = errD_real_uncond_classi.mean()

                errD_fake_uncond = criterion(
                    fake_logits[1], fake_labels
                )  # Real/Fake loss for 'fake background' (on patch level)
                errD_fake_uncond = errD_fake_uncond.mean()

                # Normalize the real/fake background loss, accounting for the
                # number of masked units in the output.
                if norm_fact_real > 0:
                    errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /
                                                    (norm_fact_real * 1.0))
                else:
                    errD_real = errD_real_uncond

                errD_fake = errD_fake_uncond
                errD = ((errD_real + errD_fake) *
                        cfg.TRAIN.BG_LOSS_WT) + errD_real_uncond_classi

            if idx == 2:

                errD_real = criterion_one(
                    real_logits[1],
                    real_labels)  # Real/Fake loss for the real image
                errD_fake = criterion_one(
                    fake_logits[1],
                    fake_labels)  # Real/Fake loss for the fake image
                errD = errD_real + errD_fake

            errD.backward()
            optD.step()

            if flag == 0:
                summary_D = summary.scalar('D_loss%d' % idx, errD.item())
                self.summary_writer.add_summary(summary_D, count)
                summary_D_real = summary.scalar('D_loss_real_%d' % idx,
                                                errD_real.item())
                self.summary_writer.add_summary(summary_D_real, count)
                summary_D_fake = summary.scalar('D_loss_fake_%d' % idx,
                                                errD_fake.item())
                self.summary_writer.add_summary(summary_D_fake, count)

            return errD
Example #15
    def train(self, data_loader, stage=1):
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        with torch.no_grad():
            fixed_noise = \
                Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        log.info('Training stage-{}'.format(stage))
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(
                    tqdm(iter(data_loader),
                         leave=False,
                         total=len(data_loader)), 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_voxels_cpu = data['voxel_tensor']
                txt_embeddings = data['raw_embedding']
                real_voxels = Variable(real_voxels_cpu)
                txt_embeddings = Variable(txt_embeddings)
                if cfg.CUDA:
                    real_voxels = real_voxels.cuda()
                    txt_embeddings = txt_embeddings.cuda()

                #######################################################
                # (2) Generate fake voxels
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embeddings, noise)
                _, fake_voxels, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_voxels, fake_voxels,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_voxels, real_labels,
                                              mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the voxels result for each epoch
                    inputs = (txt_embeddings, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_voxels_results(real_voxels_cpu, fake, epoch,
                                        self.voxels_dir)
                    if lr_fake is not None:
                        save_voxels_results(None, lr_fake, epoch,
                                            self.voxels_dir)
            end_t = time.time()
            log.info(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)

        save_model(netG, netD, self.max_epoch, self.model_dir)

        self.summary_writer.close()
Example #16
    def train(self):
        self.netEn, self.netsG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = []
        for i in range(len(self.netsG)):
            avg_param_G.append(copy_G_params(self.netsG[i]))

        self.optimizerEn, self.optimizersG, self.optimizersD = \
            define_optimizers(self.netEn, self.netsG, self.netsD)

        self.criterion = nn.BCELoss()
        self.vae = myLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches

        batch_length = self.num_batches

        self.Phi = []
        self.eta = []
        K = [256, 128, 64]
        real_min = np.float64(2.2e-308)
        eta = np.ones(3) * 0.1
        for i in range(3):
            self.eta.append(eta[i])
            if i == 0:
                self.Phi.append(0.2 +
                                0.8 * np.float64(np.random.rand(1000, K[i])))
            else:
                self.Phi.append(0.2 + 0.8 *
                                np.float64(np.random.rand(K[i - 1], K[i])))
            self.Phi[i] = self.Phi[i] / np.maximum(real_min,
                                                   self.Phi[i].sum(0))

        self.NDot = [0] * 3
        self.Xt_to_t1 = [0] * 3
        self.WSZS = [0] * 3
        self.EWSZS = [0] * 3

        self.ForgetRate = np.power(
            (0 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                             cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        epsit = np.power(
            (20 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                              cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        self.epsit = 1 * epsit / epsit[0]

        num_total_samples = batch_length * self.batch_size

        if cfg.CUDA:
            for i in range(len(self.Phi)):
                self.Phi[i] = Variable(torch.from_numpy(
                    self.Phi[i]).float()).cuda()
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            LL = 0
            KL1 = 0
            KL2 = 0
            KL3 = 0
            LS = 0
            DL = 0
            GL = 0
            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.img_tcpu, self.txtbow, self.real_tgpu, self.wrong_tgpu = self.prepare_data(
                    data)

                #######################################################
                # (1) Get conv hidden units
                ######################################################
                _, self.flat = self.inception_model(self.real_tgpu[-1])

                self.theta_1, self.shape_1, self.scale_1, self.theta_2,\
                self.shape_2, self.scale_2, self.theta_3, self.shape_3,\
                self.scale_3 = self.netEn(self.flat)

                self.txt_embedding = []
                self.txt_embedding.append(self.theta_3.detach())
                self.txt_embedding.append(self.theta_2.detach())
                self.txt_embedding.append(self.theta_1.detach())
                #######################################################
                # (2) Generate fake images
                ######################################################
                tmp = []
                self.c_code = []
                x_embedding = None
                for it in range(len(self.netsG)):
                    fake_imgs, c_code, x_embedding = \
                        self.netsG[it](self.txt_embedding[it], x_embedding)
                    tmp.append(fake_imgs)
                    self.c_code.append(c_code)

                self.fake_imgs = []
                for it in range(len(tmp)):
                    for jt in range(len(tmp[it])):
                        self.fake_imgs.append(tmp[it][jt])

                #######################################################
                # (3) Update En network
                ######################################################
                self.KL1, self.KL2, self.KL3, self.LL, self.LB, self.LS = self.train_Enet(
                    count)
                LL += self.LL
                KL1 += self.KL1
                KL2 += self.KL2
                KL3 += self.KL3
                LS += self.LS

                if count % 100 == 0:
                    print(self.LS)
                    print(self.KL1)
                    print(self.KL2)
                    print(self.KL3)

                #######################################################
                # (4) Update Phi
                #######################################################
                input_txt = np.array(np.transpose(self.txtbow.cpu().numpy()),
                                     order='C').astype('double')
                Phi = []
                theta = []
                self.theta = [self.theta_1, self.theta_2, self.theta_3]
                for i in range(len(self.Phi)):
                    Phi.append(
                        np.array(self.Phi[i].cpu().numpy(),
                                 order='C').astype('double'))
                    theta.append(
                        np.array(np.transpose(
                            self.theta[i].detach().cpu().numpy()),
                                 order='C').astype('double'))
                phi = self.updatePhi(input_txt, Phi, theta, int(batch_length),
                                     count)
                for i in range(len(phi)):
                    self.Phi[i] = torch.tensor(phi[i],
                                               dtype=torch.float32).cuda()

                #######################################################
                # (5) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD
                DL += errD_total

                #######################################################
                # (6) Update G network: maximize log(D(G(z)))
                ######################################################
                errG_total = 0
                for i in range(len(self.netsG)):
                    errG = self.train_Gnet(i, count)
                    errG_total += errG
                    for p, avg_p in zip(self.netsG[i].parameters(),
                                        avg_param_G[i]):
                        avg_p.mul_(0.999).add_(p.data, alpha=0.001)
                GL += errG_total

                # for inception score
                if cfg.INCEPTION:
                    pred, _ = self.inception_model(self.fake_imgs[-1].detach())
                    predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', float(errD_total))
                    summary_G = summary.scalar('G_loss', float(errG_total))
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netEn, self.netsG, avg_param_G, self.netsD,
                               epoch, self.model_dir)
                    # Save images
                    backup_para = []
                    for i in range(len(self.netsG)):
                        backup_para.append(copy_G_params(self.netsG[i]))
                        load_params(self.netsG[i], avg_param_G[i])

                    x_embedding = None
                    self.fake_imgs = []
                    for it in range(len(self.netsG)):
                        fake_imgs, _, x_embedding = self.netsG[it](
                            self.txt_embedding[it], x_embedding)
                        self.fake_imgs.append(fake_imgs[-1])
                    save_img_results(self.img_tcpu, self.fake_imgs,
                                     len(self.netsG), count, self.image_dir)

                    for i in range(len(self.netsG)):
                        load_params(self.netsG[i], backup_para[i])

                    if cfg.INCEPTION:
                        # Compute inception score
                        if len(predictions) > 500:
                            predictions = np.concatenate(predictions, 0)
                            mean, std = compute_inception_score(
                                predictions, 10)
                            m_incep = summary.scalar('Inception_mean', mean)
                            self.summary_writer.add_summary(m_incep, count)

                            mean_nlpp, std_nlpp = \
                                negative_log_posterior_probability(predictions, 10)
                            m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                            self.summary_writer.add_summary(m_nlpp, count)

                            predictions = []

                count = count + 1

            end_t = time.time()
            LS = LS / num_total_samples
            LL = LL / num_total_samples
            KL1 = KL1 / num_total_samples
            KL2 = KL2 / num_total_samples
            KL3 = KL3 / num_total_samples
            DL = DL / num_total_samples
            GL = GL / num_total_samples
            print(
                'Epoch: %d/%d,   Time elapsed: %.4fs\n'
                '* Batch Train Loss: %.6f          (LL: %.6f, KL1: %.6f, KL2: %.6f,'
                'KL3: %.6f, Loss_D: %.2f Loss_G: %.2f)\n' %
                (epoch, self.max_epoch, end_t - start_t, LS, LL, KL1, KL2, KL3,
                 DL, GL))
        save_model(self.netEn, self.netsG, avg_param_G, self.netsD, epoch,
                   self.model_dir)
        self.summary_writer.close()
Example #17
    def train_Gnet(self, count):
        self.netG.zero_grad()
        for myit in range(4):
            self.netsD[myit].zero_grad()

        errG_total = 0
        flag = count % 100
        batch_size = self.real_fimgs.size(0)
        criterion_one, criterion_class, c_code, p_code = (
            self.criterion_one, self.criterion_class, self.c_code, self.p_code)

        for i in range(3):
            if i == 0 or i == 2:  # real/fake loss for background (0) and child (2) stage
                if i == 0:
                    outputs = self.netsD[0](self.fake_imgs[0], self.alpha,
                                            self.aux_masks)
                    real_labels = torch.ones_like(outputs[1])
                    errG0 = criterion_one(outputs[1], real_labels)
                    errG0 = errG0 * cfg.TRAIN.BG_LOSS_WT_GLB

                    outputs = self.netsD[3](self.fake_imgs[0], self.alpha)
                    real_labels = torch.ones_like(outputs[1])
                    errG1 = criterion_one(outputs[1], real_labels)
                    errG1 = errG1 * cfg.TRAIN.BG_LOSS_WT_LCL

                    errG_classi = criterion_one(
                        outputs[0], real_labels
                    )  # Background/Foreground classification loss for the fake background image (on patch level)
                    errG_classi = errG_classi * cfg.TRAIN.BG_CLASSI_WT

                    errG = errG0 + errG1 + errG_classi
                    errG_total = errG_total + errG

                else:  # i = 2
                    outputs = self.netsD[2](self.fake_imgs[2], self.alpha)
                    real_labels = torch.ones_like(outputs[1])
                    errG = criterion_one(outputs[1], real_labels)
                    errG_total = errG_total + errG

            if i == 1:  # Mutual information loss for the parent stage (1)
                pred_p = self.netsD[i](self.fg_mk[i - 1], self.alpha)
                errG_info = criterion_class(pred_p[0],
                                            torch.nonzero(p_code.long())[:, 1])
            elif i == 2:  # Mutual information loss for the child stage (2)
                pred_c = self.netsD[i](self.fg_mk[i - 1], self.alpha)
                errG_info = criterion_class(pred_c[0],
                                            torch.nonzero(c_code.long())[:, 1])

            if i > 0:
                errG_total = errG_total + errG_info

            if flag == 0:
                if i > 0:
                    summary_D_class = summary.scalar('Information_loss_%d' % i,
                                                     errG_info.item())
                    self.summary_writer.add_summary(summary_D_class, count)

                if i == 0 or i == 2:
                    summary_D = summary.scalar('G_loss%d' % i, errG.item())
                    self.summary_writer.add_summary(summary_D, count)

        errG_total.backward()
        for myit in range(3):  # one optimizer per generator stage
            self.optimizerG[myit].step()
        return errG_total
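The mutual-information losses above turn a one-hot code into integer class targets with torch.nonzero(code)[:, 1] and score the discriminator's class head with CrossEntropyLoss. A self-contained sketch of that indexing trick:

import torch
import torch.nn as nn

p_code = torch.eye(4)[torch.tensor([2, 0, 3])]      # (B=3, 4) one-hot codes
targets = torch.nonzero(p_code.long())[:, 1]        # tensor([2, 0, 3])
logits = torch.randn(3, 4)                          # class-prediction head output
errG_info = nn.CrossEntropyLoss()(logits, targets)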
Example #18
def run(args):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_flag = tf.placeholder(tf.bool)
        keep_prob = tf.placeholder(tf.float32)
        feature_1 = tf.placeholder(tf.float32, [None, args.input_dimension], 'feature_1')
        feature_2 = tf.placeholder(tf.float32, [None, args.input_dimension], 'feature_2')
        feature_3 = tf.placeholder(tf.float32, [None, args.input_dimension], 'feature_3')
        real_labels = tf.placeholder(tf.int32, [None], 'real_labels')

        ### generated predicate features ###
        bottle_z = ST_encoder(dim_G, feature_1, keep_prob, reuse=False, training=train_flag)
        reconstruction = ST_decoder(dim_D, 1000, bottle_z, feature_2, keep_prob, reuse=False, training=train_flag)
        errL1 = tf.reduce_mean(tf.losses.absolute_difference(
            reconstruction, feature_3, reduction=tf.losses.Reduction.NONE))
        if args.ac_weight > 0:
            ac_loss = aux_classifier(reconstruction, real_labels, args.num_predicates, keep_prob, reuse=False, training=train_flag)
        else:
            ac_loss = tf.zeros(1, dtype=tf.float32)
        #errL1 = tf.reduce_mean(tf.abs(reconstruction - feature_3))

        errD_fake = netD(256, reconstruction, n_layers=0, reuse=False)
        errD_real = netD(256, feature_3, n_layers=0, reuse=True)

        # cost functions
        errD = tf.reduce_mean(errD_fake) - tf.reduce_mean(errD_real)
        errG = -tf.reduce_mean(errD_fake)
        if args.ac_weight > 0:
            errG_total = errG + errL1 * args.L1_weight + args.ac_weight * ac_loss
        else:
            errG_total = errG + errL1 * args.L1_weight

        # gradient penalty
        epsilon = tf.random_uniform([], 0.0, 1.0)
        x_hat = feature_3 * (1 - epsilon) + epsilon * reconstruction
        d_hat = netD(256, x_hat, n_layers=0, reuse=True)
        gradients = tf.gradients(d_hat, x_hat)[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = 10 * tf.reduce_mean((slopes - 1.0) ** 2)
        errD_total = errD + gradient_penalty

        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'Discriminator' in var.name]
        g_vars = [var for var in t_vars if 'Generator' in var.name]
        learning_rate = get_learning_rate(data_num, global_step)
        G_train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, beta2=0.9).minimize(
            errG_total, global_step, var_list=g_vars)
        D_train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, beta2=0.9).minimize(
            errD_total, global_step, var_list=d_vars)

        ops = {'D_train_op': D_train_op,
               'G_train_op': G_train_op,
               'feature_1': feature_1,
               'feature_2': feature_2,
               'feature_3': feature_3,
               'keep_prob': keep_prob,
               'real_labels': real_labels,
               'train_flag': train_flag,
               'errD': errD,
               'errG': errG,
               'errL1': errL1,
               'ac_loss': ac_loss,
               'reconstruction': reconstruction}

        saver = tf.train.Saver(max_to_keep=None)
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        ### let GPU memory grow as needed ###
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        sess.run(init)
        summary_writer = FileWriter(log_path, graph=tf.get_default_graph())

        tf.add_to_collection('G_train_op', G_train_op)
        tf.add_to_collection('D_train_op', D_train_op)

        start_epoch = 0
        if args.training:
            # restore previous model if there is one
            ckpt = tf.train.get_checkpoint_state(model_pth)
            if ckpt and ckpt.model_checkpoint_path:
                print("Restoring previous model...")
                try:
                    start_epoch = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) + 1
                    print(start_epoch)
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    print("Model restored")
                except Exception:
                    print("Could not restore model")

            ########################################### training portion
            for epoch in range(start_epoch, args.max_epoch):
                start = time.time()
                train_loss_d, train_loss_g, train_loss_L1Loss, train_loss_acLoss = \
                    train_one_epoch(sess, input_data, ops, args)
                print('epoch:', epoch, 'D loss:', train_loss_d.avg,
                      'G_loss:', train_loss_g.avg, 'L1:', train_loss_L1Loss.avg,
                      'AC:', train_loss_acLoss.avg, 'time:', time.time() - start)

                summary_D = summary.scalar('D_loss', train_loss_d.avg)
                summary_writer.add_summary(summary_D, epoch)
                summary_G = summary.scalar('G_loss', train_loss_g.avg)
                summary_writer.add_summary(summary_G, epoch)
                summary_G_L1 = summary.scalar('G_L1', train_loss_L1Loss.avg)
                summary_writer.add_summary(summary_G_L1, epoch)
                summary_AC = summary.scalar('G_AC', train_loss_acLoss.avg)
                summary_writer.add_summary(summary_AC, epoch)
                if (epoch + 1) % 10 == 0:
                    print('save model')
                    if not os.path.exists(model_pth):
                        os.makedirs(model_pth)
                    saver.save(sess, model_pth + 'checkpoint-' + str(epoch))
                    saver.export_meta_graph(model_pth + 'checkpoint-' + str(epoch) + '.meta')
        else:
            print('evaluation')
            ckpt = tf.train.get_checkpoint_state(model_pth)
            try:
                epoch = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored")
            except Exception:
                print("Could not restore model")
                exit(0)

            ### generate whole data ###
            if args.test_setting == 'wholedata':
                print('generate whole data:', epoch)
                generate_wholedata(sess, input_data, ops, epoch)
            ### generate lowshot vrd data ###
            elif args.test_setting == 'lowshot':
                print('generate lowshot data:', epoch)
                generate_lowshot(sess, input_data, ops, args, epoch)
    input_data.close()
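These TF1 snippets pass a plain Python float to summary.scalar and hand the result to FileWriter.add_summary, so the helper presumably returns a Summary protobuf. An equivalent, graph-free way to log a scalar with the stock TF1 API (tf.compat.v1 under TF2):

import tensorflow as tf

def scalar_summary(tag, value):
    # Wrap a Python float in a Summary proto; no graph op required.
    return tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])

# writer = tf.summary.FileWriter(log_path)
# writer.add_summary(scalar_summary('D_loss', 0.37), global_step=epoch)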
Example #19
def log_history(hist, epoch):
    # hist is a Keras History object; `writer` and `summary` are assumed
    # to be module-level (a FileWriter and the tensorboard summary module).
    for ep in hist.epoch:
        for val in hist.history:
            writer.add_summary(
                summary=summary.scalar("nn/" + val, hist.history[val][ep]),
                global_step=epoch,
            )
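A hedged usage sketch for log_history with a throwaway Keras model; the model, data, and epoch value here are illustrative only:

import numpy as np
from tensorflow import keras

model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
hist = model.fit(np.random.rand(32, 4), np.random.rand(32, 1),
                 epochs=2, verbose=0)
# log_history(hist, epoch=0)  # assumes `writer` exists at module level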
Example #20
                        labels = test_feature_all['pre_label'][
                            start_ind:end_ind]
                        gt_concat_fea = np.concatenate(
                            (test_feature_all['sub_fea'][start_ind:end_ind],
                             test_feature_all['obj_fea'][start_ind:end_ind]),
                            axis=1)
                        rd_loss_temp, acc_temp, acc_each = vnet.val_predicate_fea_concate(
                            sess, gt_concat_fea, labels)
                        rd_loss_val = rd_loss_val + rd_loss_temp
                        acc_val += sum(acc_each)
                    print(
                        "whole-val: {0} rd_loss: {1}, acc: {2}, best_acc: {3}".
                        format(step, rd_loss_val / N_val, acc_val / N_val,
                               acc_val_all))

                    val_loss = summary.scalar('val_loss', rd_loss_val / N_val)
                    summary_writer.add_summary(val_loss, step)
                    val_accuracy = summary.scalar('val_acc', acc_val / N_val)
                    summary_writer.add_summary(val_accuracy, step)

                    if (acc_val / N_val) > acc_val_all:
                        save_path = model_path + '/' + args.data + '_vgg_' + format(
                            int(step), '04')
                        saver.save(sess, save_path)
                        saver.export_meta_graph(save_path + '.meta')
                        acc_val_all = acc_val / N_val
                        best_acc_val = step

                ### evaluation on the lowshot dataset ###
                elif args.mode == "lowshot":
                    rd_loss_val_lowshot = 0.0
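The validation block above keeps only checkpoints that improve on the best accuracy seen so far. A minimal sketch of that pattern (the names here are illustrative, not the snippet's own):

best_acc = 0.0

def maybe_save(sess, saver, model_path, acc, step):
    # Save a checkpoint only when validation accuracy improves.
    global best_acc
    if acc > best_acc:
        best_acc = acc
        save_path = '%s/best_%04d' % (model_path, int(step))
        saver.save(sess, save_path)
        saver.export_meta_graph(save_path + '.meta')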
Example #21
    def train(self):
        if cfg.TRAIN.COEFF.MD_LOSS > 0:
            self.netG, self.netsD, self.netMD, self.num_Ds, self.inception_model, start_count = load_network(
                self.gpus, self.num_batches)
        else:
            self.netG, self.netsD, self.num_Ds, self.inception_model, start_count = load_network(
                self.gpus, self.num_batches)

        avg_param_G = copy_G_params(self.netG)

        if (cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0
                or cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS > 0
                or cfg.TRAIN.COEFF.MD_LOSS > 0):
            self.image_cnn = Inception_v3()
            self.image_encoder = LINEAR_ENCODER()
            if not isinstance(self.image_cnn, torch.nn.DataParallel):
                self.image_cnn = nn.DataParallel(self.image_cnn)
            if not isinstance(self.image_encoder, torch.nn.DataParallel):
                self.image_encoder = nn.DataParallel(self.image_encoder)
            if cfg.DATASET_NAME == 'birds':
                self.image_encoder.load_state_dict(
                    torch.load(
                        "outputs/pre_train/birds/models/best_image_model.pth"))
            if cfg.DATASET_NAME == 'flowers':
                self.image_encoder.load_state_dict(
                    torch.load(
                        "outputs/pre_train/flowers/models/best_image_model.pth"
                    ))

            if cfg.CUDA:
                self.image_cnn = self.image_cnn.cuda()
                self.image_encoder = self.image_encoder.cuda()
            self.image_cnn.eval()
            self.image_encoder.eval()
            for p in self.image_cnn.parameters():
                p.requires_grad = False
            for p in self.image_encoder.parameters():
                p.requires_grad = False

        self.optimizerG, self.optimizersD = define_optimizers(
            self.netG, self.netsD)
        if cfg.TRAIN.COEFF.MD_LOSS > 0:
            self.optimizerMD = optim.Adam(self.netMD.parameters(),
                                          lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                          betas=(0.5, 0.999))

        self.criterion = nn.BCELoss()
        self.real_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))
        self.same_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))
        self.wrong_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(2))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = Variable(
            torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.same_labels = self.same_labels.cuda()
            self.wrong_labels = self.wrong_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                (self.imgs_tcpu, self.real_imgs, self.wrong_imgs,
                 self.similar_imgs, self.txt_embedding,
                 self.class_ids) = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = self.netG(
                    noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD
                # update the MD network
                errMD = self.train_MDnet(count)
                errD_total += errMD
                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of G's weights (decay 0.999);
                # newer PyTorch spells this add_(p.data, alpha=0.001).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                # pred = self.inception_model(self.fake_imgs[-1].detach())
                # predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1
            if epoch % cfg.TRAIN.SAVE_EPOCH == 0:
                if cfg.TRAIN.COEFF.MD_LOSS > 0:
                    DIS_NET = [self.netsD, self.netMD]
                else:
                    DIS_NET = self.netsD
                save_model(self.netG, avg_param_G, DIS_NET, epoch,
                           self.model_dir)
            if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0:
                # Save images
                backup_para = copy_G_params(self.netG)
                load_params(self.netG, avg_param_G)
                #
                self.fake_imgs, _, _ = self.netG(fixed_noise,
                                                 self.txt_embedding)
                save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                 epoch, self.image_dir, self.summary_writer)
                #
                load_params(self.netG, backup_para)

                #############################
                # During training the parameters of G are updated directly;
                # why does the generating stage use the averaged copy of
                # G's parameters instead?
                #############################
                """
                # Compute inception score
                if len(predictions) > 500:
                    predictions = np.concatenate(predictions, 0)
                    mean, std = compute_inception_score(predictions, 10)
                    # print('mean:', mean, 'std', std)
                    m_incep = summary.scalar('Inception_mean', mean)
                    self.summary_writer.add_summary(m_incep, count)
                    #
                    mean_nlpp, std_nlpp = negative_log_posterior_probability(predictions, 10)
                    m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                    self.summary_writer.add_summary(m_nlpp, count)
                    #
                    predictions = []
                """

            end_t = time.time()
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      '''

                  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), kl_loss.item(), end_t - start_t))

        if cfg.TRAIN.COEFF.MD_LOSS > 0:
            DIS_NET = [self.netsD, self.netMD]
        else:
            DIS_NET = self.netsD
        save_model(self.netG, avg_param_G, DIS_NET, epoch, self.model_dir)
        self.summary_writer.close()
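The update avg_p.mul_(0.999).add_(0.001, p.data) above maintains an exponential moving average of the generator's weights; the averaged copy is what gets saved and used for sampling. The same update written with the current PyTorch signature:

def update_avg_params(netG, avg_param_G, decay=0.999):
    # avg <- decay * avg + (1 - decay) * current
    for p, avg_p in zip(netG.parameters(), avg_param_G):
        avg_p.mul_(decay).add_(p.data, alpha=1 - decay)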
Example #22
    def train_Gnet(self, count):
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]

        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                    outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if (cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0
                    or cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS > 0
                    or cfg.TRAIN.COEFF.MD_LOSS > 0):
                fake_feat = self.image_cnn(self.fake_imgs[i])
                fake_feat = self.image_encoder(fake_feat)
            if cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0 or cfg.TRAIN.COEFF.MD_LOSS > 0:
                real_feat = self.image_cnn(self.real_imgs[i])
                real_feat = self.image_encoder(real_feat)
            if cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0:
                loss1, loss2 = batch_loss(real_feat, fake_feat, self.class_ids)
                errG_CC = loss1 + loss2
                errG_total = errG_total + errG_CC * cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS
            if cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS > 0:
                loss1, loss2 = batch_loss(self.txt_embedding, fake_feat,
                                          self.class_ids)
                errG_SC = loss1 + loss2
                errG_total = errG_total + errG_SC * cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS
            if cfg.TRAIN.COEFF.MD_LOSS > 0 and i == (self.num_Ds - 1):
                outputs2 = self.netMD(real_feat, fake_feat)
                errMG = nn.CrossEntropyLoss()(outputs2, real_labels.long())
                errG_total = errG_total + errMG * cfg.TRAIN.COEFF.MD_LOSS
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, global_step=count)

                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, global_step=count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)

                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total
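compute_mean_covariance is not shown in this excerpt; in StackGAN++ it returns the per-image channel means and the channel covariance over spatial positions. A sketch of that computation under this assumption:

import torch

def compute_mean_covariance_sketch(img):
    # img: (B, C, H, W)
    b, c, h, w = img.size()
    mu = img.mean(dim=(2, 3), keepdim=True)               # (B, C, 1, 1)
    centered = (img - mu).view(b, c, h * w)               # (B, C, H*W)
    cov = torch.bmm(centered, centered.transpose(1, 2)) / (h * w)  # (B, C, C)
    return mu.view(b, c), cov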
Example #23
    def train(self, data_loader, stage=1):
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM  # 100
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        detectron = Detectron()
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr
            for i, data in enumerate(data_loader):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding, caption = data
                caption = np.moveaxis(np.array(caption), 1, 0)

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################

                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                kl_loss = KL_loss(mu, logvar)

                fake_img = fake_imgs.cpu().detach().numpy()

                det_obj_list = detectron.get_labels(fake_img)

                fake_l = Variable(get_ohe(det_obj_list)).cuda()
                real_l = Variable(get_ohe(caption)).cuda()

                det_loss = nn.SmoothL1Loss()(fake_l, real_l)
                errG_total = det_loss + errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    summary_DET = summary.scalar('det_loss', det_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    self.summary_writer.add_summary(summary_DET, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
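get_ohe is not defined in this excerpt; from its use with SmoothL1Loss it plausibly turns lists of detected or captioned labels into multi-hot vectors. A hypothetical stand-in (num_classes and integer label indices are assumptions):

import torch

def get_ohe(label_lists, num_classes=80):
    # One row per sample; 1.0 at every class index present in the sample.
    out = torch.zeros(len(label_lists), num_classes)
    for row, labels in enumerate(label_lists):
        for lab in labels:
            out[row, int(lab)] = 1.0
    return out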
Example #24
    def train(self, data_loader):
        netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, feat_cpu = data
                real_img_cpu = real_img_cpu.squeeze(1).permute(0, 3, 1, 2)
                feat_cpu = feat_cpu.squeeze(1).permute(0, 3, 1, 2)
                real_imgs = Variable(real_img_cpu)
                feats = Variable(feat_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    feats = feats.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)  # resampled each step, though netG below consumes only feats
                inputs = feats
                fake_imgs = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                #errD, errD_real, errD_wrong, errD_fake = \
                errD, errD_real, errD_fake, errD_wrong = \
                    compute_discriminator_loss(netD, real_imgs, feats, fake_imgs,
                                               real_labels, fake_labels,
                                               self.gpus)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, feats,
                                              real_labels, self.gpus)
                errG.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)

                    # save the image result for each epoch
                    inputs = feats
                    fake = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    #save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    save_img_results2(real_img_cpu, fake, feat_cpu, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), 0.0,
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
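The squeeze(1).permute(0, 3, 1, 2) calls above convert the loader's NHWC tensors (with a singleton dim) into the NCHW layout PyTorch conv layers expect. A quick check of the shape arithmetic:

import torch

x = torch.randn(8, 1, 64, 64, 3)           # (B, 1, H, W, C) from the loader
x = x.squeeze(1).permute(0, 3, 1, 2)       # -> (B, C, H, W)
assert x.shape == (8, 3, 64, 64)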
Example #25
    def train_Dnet(self, idx, count):
        flag = count % 100
        batch_size = self.real_fimgs.size(0)
        criterion, criterion_one = self.criterion, self.criterion_one

        if idx == 0:
            real_imgs = self.real_fimgs
            fake_imgs = self.fake_imgs[0]
            optD = self.optimizersD[0]

            netD = self.netsD[0]
            netD.zero_grad()
            real_logits = netD(real_imgs, self.alpha, self.masks.detach())
            fake_logits = netD(fake_imgs.detach(), self.alpha, self.aux_masks)
            real_labels = torch.ones_like(real_logits[1])
            fake_labels = torch.zeros_like(real_logits[1])

            errD_real = criterion_one(
                real_logits[1],
                real_labels)  # Real/Fake loss for the real image
            errD_fake = criterion_one(
                fake_logits[1],
                fake_labels)  # Real/Fake loss for the fake image
            errD0 = (errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT_GLB

            netD = self.netsD[3]
            netD.zero_grad()

            _fg = self.masks == 0
            rev_masks = torch.zeros_like(self.masks)
            rev_masks.masked_fill_(_fg, 1.0)
            real_logits = netD(real_imgs, self.alpha, rev_masks)

            fake_labels = torch.zeros_like(real_logits[1])
            ext, output, fnl_masks = real_logits
            weights_real = torch.ones_like(output)
            real_labels = torch.ones_like(output)

            invalid_patch = fnl_masks != 0.0
            weights_real.masked_fill_(invalid_patch, 0.0)

            norm_fact_real = weights_real.sum()
            norm_fact_fake = weights_real.numel()  # total number of patch outputs
            real_logits = ext, output

            fake_logits = netD(fake_imgs.detach(), self.alpha)

            errD_real_uncond = criterion(
                real_logits[1], real_labels
            )  # Real/Fake loss for 'real background' (on patch level)
            errD_real_uncond = torch.mul(
                errD_real_uncond, weights_real
            )  # Mask output units whose receptive fields lie within the bounding box
            errD_real_uncond = errD_real_uncond.mean()

            errD_fake_uncond = criterion(
                fake_logits[1], fake_labels
            )  # Real/Fake loss for 'fake background' (on patch level)
            errD_fake_uncond = errD_fake_uncond.mean()

            if norm_fact_real > 0:  # Renormalize the background real/fake loss to account for the masked-out units
                errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /
                                                (norm_fact_real * 1.0))
            else:
                errD_real = errD_real_uncond

            errD_fake = errD_fake_uncond
            errD1 = (errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT_LCL

            # Background/foreground classification loss
            errD_real_uncond_classi = criterion(real_logits[0], weights_real)
            errD_real_uncond_classi = errD_real_uncond_classi.mean()
            errD_classi = errD_real_uncond_classi * cfg.TRAIN.BG_CLASSI_WT


            errD = errD0 + errD1 + errD_classi

        elif idx == 2:  # Discriminator is only trained in background and child stage. (NOT in parent stage)
            netD, optD = self.netsD[2], self.optimizersD[2]
            real_imgs = self.real_cimgs
            fake_imgs = self.fake_imgs[2]
            netD.zero_grad()
            real_logits = netD(real_imgs, self.alpha)
            fake_logits = netD(fake_imgs.detach(), self.alpha)
            real_labels = torch.ones_like(real_logits[1])
            fake_labels = torch.zeros_like(real_logits[1])

            errD_real = criterion_one(
                real_logits[1],
                real_labels)  # Real/Fake loss for the real image
            errD_fake = criterion_one(
                fake_logits[1],
                fake_labels)  # Real/Fake loss for the fake image
            errD = errD_real + errD_fake

        errD.backward()
        optD.step()

        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
            summary_D_real = summary.scalar('D_loss_real_%d' % idx,
                                            errD_real.item())
            self.summary_writer.add_summary(summary_D_real, count)
            summary_D_fake = summary.scalar('D_loss_fake_%d' % idx,
                                            errD_fake.item())
            self.summary_writer.add_summary(summary_D_fake, count)

        return errD
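The local background discriminator above computes a per-patch loss, zeroes out patches whose receptive fields fall inside the object box, and rescales so the mean is effectively taken only over valid patches. A minimal sketch of that masking-and-renormalization step:

import torch
import torch.nn as nn

criterion = nn.BCELoss(reduction='none')
pred = torch.rand(4, 1, 24, 24)                   # patch-level D outputs
target = torch.ones_like(pred)
weights = (torch.rand_like(pred) > 0.3).float()   # 1 = valid background patch

loss = (criterion(pred, target) * weights).mean()
if weights.sum() > 0:
    loss = loss * (weights.numel() / weights.sum())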
Example #26
    def train(self):
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # EMA update of G's weights (decay 0.999); newer PyTorch
                # spells this add_(p.data, alpha=0.001).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = \
                        self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      '''

                  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), kl_loss.item(), end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()
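compute_inception_score and negative_log_posterior_probability are imported elsewhere; the standard Inception Score splits the softmax predictions into chunks and averages exp(KL(p(y|x) || p(y))) per chunk. A sketch of that computation, assuming preds holds softmax outputs:

import numpy as np

def inception_score_sketch(preds, splits=10):
    # preds: (N, num_classes) softmax probabilities
    scores = []
    for part in np.array_split(preds, splits):
        py = part.mean(axis=0, keepdims=True)             # marginal p(y)
        kl = part * (np.log(part + 1e-12) - np.log(py + 1e-12))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return np.mean(scores), np.std(scores)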
Example #27
    def train(self, data_loader, stage=1):
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        if cfg.TRAIN.ADAM:
            optimizerD = \
                optim.Adam(netD.parameters(),
                           lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
            optimizerG = optim.Adam(netG_para,
                                    lr=cfg.TRAIN.GENERATOR_LR,
                                    betas=(0.5, 0.999))
        else:
            optimizerD = \
                optim.RMSprop(netD.parameters(),
                           lr=cfg.TRAIN.DISCRIMINATOR_LR)
            optimizerG = \
                optim.RMSprop(netG_para,
                                    lr=cfg.TRAIN.GENERATOR_LR)

        cnn = models.vgg19(pretrained=True).features
        cnn = nn.Sequential(*list(cnn.children())[0:28])
        gram = GramMatrix()
        if cfg.CUDA:
            cnn.cuda()
            gram.cuda()
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                if cfg.CUDA:
                    _, fake_imgs, mu, logvar = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                else:
                    _, fake_imgs, mu, logvar = netG(txt_embedding, noise)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus, cfg.CUDA)
                errD.backward()
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus, cfg.CUDA)
                kl_loss = KL_loss(mu, logvar)
                pixel_loss = PIXEL_loss(real_imgs, fake_imgs)
                if cfg.CUDA:
                    fake_features = nn.parallel.data_parallel(
                        cnn, fake_imgs.detach(), self.gpus)
                    real_features = nn.parallel.data_parallel(
                        cnn, real_imgs.detach(), self.gpus)
                else:
                    fake_features = cnn(fake_imgs)
                    real_features = cnn(real_imgs)
                active_loss = ACT_loss(fake_features, real_features)
                text_loss = TEXT_loss(gram, fake_features, real_features,
                                      cfg.TRAIN.COEFF.TEXT)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + \
                    pixel_loss * cfg.TRAIN.COEFF.PIX + \
                    active_loss * cfg.TRAIN.COEFF.ACT + \
                    text_loss
                errG_total.backward()
                optimizerG.step()
                count = count + 1
                if i % 100 == 0:

                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    summary_Pix = summary.scalar('Pixel_loss',
                                                 pixel_loss.item())
                    summary_Act = summary.scalar('Act_loss',
                                                 active_loss.item())
                    summary_Text = summary.scalar('Text_loss',
                                                  text_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)
                    self.summary_writer.add_summary(summary_Pix, count)
                    self.summary_writer.add_summary(summary_Act, count)
                    self.summary_writer.add_summary(summary_Text, count)

                    # save the image result for each epoch
                    inputs = (txt_embedding, fixed_noise)
                    if cfg.CUDA:
                        lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                    else:
                        lr_fake, fake, _, _ = netG(txt_embedding, fixed_noise)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_Pixel: %.4f
                                     Loss_Activ: %.4f Loss_Text: %.4f
                                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                                     Total Time: %.2fsec
                                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), pixel_loss.item(),
                 active_loss.item(), text_loss.item(), errD_real, errD_wrong,
                 errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()
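GramMatrix and TEXT_loss are not shown in this excerpt; the texture loss presumably compares Gram matrices of the VGG feature maps, as in neural style transfer. A sketch of the usual Gram-matrix module under that assumption:

import torch
import torch.nn as nn

class GramMatrixSketch(nn.Module):
    def forward(self, x):
        # x: (B, C, H, W) feature maps
        b, c, h, w = x.size()
        f = x.view(b, c, h * w)
        return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)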