Exemplo n.º 1
0
 def test_scalar_new_style(self):
     """Exercise ``summary.scalar`` with ``new_style=True``.

     A plain float must produce a summary that ``compare_proto`` accepts,
     while a three-element tensor must raise ``AssertionError``.
     """
     # Happy path: a single float value.
     proto = summary.scalar('test_scalar', 1.0, new_style=True)
     accepted = compare_proto(proto, self)
     self.assertTrue(accepted)

     # Failure path: a tensor holding more than one element.
     vector = torch.Tensor([1, 2, 3])
     with self.assertRaises(AssertionError):
         summary.scalar('test_scalar2', vector, new_style=True)
    def train_Gnet(self, count):
        """Run one generator update and return the total generator loss.

        The loss is the sum, over all ``self.num_Ds`` discriminators, of
        ``self.criterion`` applied to the first discriminator output for the
        matching fake image and the "real" labels. When
        ``cfg.TRAIN.COEFF.COLOR_LOSS`` is positive, MSE penalties between the
        mean/covariance statistics of neighbouring fake-image outputs are
        added (the second image of each pair is detached).

        Args:
            count: Global iteration counter. Summaries are written when it
                is a multiple of 100; it is also used as the summary step.

        Returns:
            The accumulated generator loss that was back-propagated.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]

        for i in range(self.num_Ds):
            netD = self.netsD[i]
            outputs = netD(self.fake_imgs[i])
            errG = criterion(outputs[0], real_labels)
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.data)
                self.summary_writer.add_summary(summary_G, count)

        # Compute color preserve losses
        color_coeff = cfg.TRAIN.COEFF.COLOR_LOSS
        if color_coeff > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = color_coeff * nn.MSELoss()(mu1, mu2)
                like_cov2 = color_coeff * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = color_coeff * nn.MSELoss()(mu1, mu2)
                like_cov1 = color_coeff * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1

            # Log each pair only when it was actually computed. Logging
            # like_mu2/like_cov2 unconditionally raised UnboundLocalError
            # when there is a single discriminator. The write order
            # (mu2, cov2, mu1, cov1) is the same as before.
            if flag == 0:
                if self.num_Ds > 1:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.data)
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.data)
                    self.summary_writer.add_summary(sum_cov, count)
                if self.num_Ds > 2:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data)
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data)
                    self.summary_writer.add_summary(sum_cov, count)

        errG_total.backward()
        self.optimizerG.step()
        return errG_total
Exemplo n.º 3
0
    def train_Gnet(self, count):
        """Perform one generator step for the conditional discriminators.

        Each discriminator ``self.netsD[i]`` is called with the matching fake
        image and ``self.onehot_t``. Its first output is scored against the
        "real" labels; a second output, when present and when
        ``cfg.TRAIN.COEFF.UNCOND_LOSS`` is positive, adds a weighted extra
        term. Optional colour-consistency penalties are added when
        ``cfg.TRAIN.COEFF.COLOR_LOSS`` is positive.

        Args:
            count: Global iteration counter; summaries are written when it is
                a multiple of 100.

        Returns:
            A ``(None, errG_total)`` tuple. The first slot is always ``None``
            (the KL term that would fill it is commented out in this
            version).
        """
        self.netG.zero_grad()
        should_log = (count % 100) == 0
        n_samples = self.real_imgs[0].size(0)
        loss_fn = self.criterion
        targets = self.real_labels[:n_samples]

        errG_total = 0
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], self.onehot_t)
            stage_loss = loss_fn(outputs[0], targets)
            if len(outputs) > 1:
                if cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                    patch_loss = cfg.TRAIN.COEFF.UNCOND_LOSS * loss_fn(
                        outputs[1], targets)
                    stage_loss = stage_loss + patch_loss
            errG_total = errG_total + stage_loss
            if should_log:
                stage_summary = summary.scalar('G_loss%d' % i,
                                               stage_loss.item())
                self.summary_writer.add_summary(stage_summary, count)

        # Colour-consistency terms between neighbouring fake-image outputs.
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            mse = nn.MSELoss()
            if self.num_Ds > 1:
                mean_a, cov_a = compute_mean_covariance(self.fake_imgs[-1])
                mean_b, cov_b = compute_mean_covariance(
                    self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * mse(mean_a, mean_b)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * mse(cov_a, cov_b)
                errG_total = errG_total + like_mu2 + like_cov2
                if should_log:
                    mu_summary = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(mu_summary,
                                                    global_step=count)
                    cov_summary = summary.scalar('G_like_cov2',
                                                 like_cov2.item())
                    self.summary_writer.add_summary(cov_summary,
                                                    global_step=count)
            if self.num_Ds > 2:
                mean_a, cov_a = compute_mean_covariance(self.fake_imgs[-2])
                mean_b, cov_b = compute_mean_covariance(
                    self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * mse(mean_a, mean_b)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * mse(cov_a, cov_b)
                errG_total = errG_total + like_mu1 + like_cov1
                if should_log:
                    mu_summary = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(mu_summary, count)
                    cov_summary = summary.scalar('G_like_cov1',
                                                 like_cov1.item())
                    self.summary_writer.add_summary(cov_summary, count)

        # The KL term (KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL) is disabled
        # here, so nothing is added to the total before the backward pass.
        errG_total.backward()
        self.optimizerG.step()
        return None, errG_total
Exemplo n.º 4
0
    def train_Dnet(self, idx, count):
        """Run one optimisation step for discriminator ``idx``.

        The loss is ``self.criterion`` on the first output for the real
        images (against the "real" labels) plus the same criterion on the
        first output for the detached fake images (against the "fake"
        labels).

        Args:
            idx: Index into ``self.netsD``, ``self.optimizersD``,
                ``self.real_imgs`` and ``self.fake_imgs``.
            count: Global iteration counter; a summary is written when it is
                a multiple of 100.

        Returns:
            The discriminator loss that was back-propagated.
        """
        n_samples = self.real_imgs[0].size(0)
        loss_fn = self.criterion
        discriminator = self.netsD[idx]
        optimizer = self.optimizersD[idx]

        reals = self.real_imgs[idx]
        fakes = self.fake_imgs[idx]
        ones = self.real_labels[:n_samples]
        zeros = self.fake_labels[:n_samples]

        discriminator.zero_grad()

        # The fake batch is detached so no gradient reaches the generator.
        logits_on_real = discriminator(reals)
        logits_on_fake = discriminator(fakes.detach())

        loss_on_real = loss_fn(logits_on_real[0], ones)
        loss_on_fake = loss_fn(logits_on_fake[0], zeros)
        errD = loss_on_real + loss_on_fake

        errD.backward()
        optimizer.step()

        if count % 100 == 0:
            loss_summary = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(loss_summary, count)
        return errD
Exemplo n.º 5
0
    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        """Write a single scalar value to the summary.

        Args:
            tag (string): Data identifier
            scalar_value (float or string/blobname): Value to save; it is
              forwarded unchanged to ``scalar``
            global_step (int): Global step value to record
            walltime (float): Optional override of the default walltime
              (time.time()), in seconds after epoch of the event

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter()
            x = range(100)
            for i in x:
                writer.add_scalar('y=2x', i * 2, i)
            writer.close()

        Expected result:

        .. image:: _static/img/tensorboard/add_scalar.png
           :scale: 50 %

        """
        file_writer = self._get_file_writer()
        scalar_summary = scalar(tag, scalar_value)
        file_writer.add_summary(scalar_summary, global_step, walltime)
Exemplo n.º 6
0
    def add_scalars(self,
                    main_tag,
                    tag_scalar_dict,
                    global_step=None,
                    walltime=None):
        """Write several related scalar values to the summary.

        Every entry of ``tag_scalar_dict`` gets its own ``FileWriter``,
        cached in ``self.all_writers`` and rooted at
        ``<logdir>/<main_tag with '/' replaced by '_'>_<tag>``. Each value is
        recorded under ``main_tag``.

        Args:
            main_tag (string): The parent name for the tags
            tag_scalar_dict (dict): Key-value pair storing the tag and
              corresponding values
            global_step (int): Global step value to record
            walltime (float): Optional override of the default walltime
              (time.time()), in seconds after epoch of the event

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter()
            r = 5
            for i in range(100):
                writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
                                                'xcosx':i*np.cos(i/r),
                                                'tanx': np.tan(i/r)}, i)
            writer.close()
            # This call adds three values to the same scalar plot with the tag
            # 'run_14h' in TensorBoard's scalar section.

        Expected result:

        .. image:: _static/img/tensorboard/add_scalars.png
           :scale: 50 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_scalars")
        # Use one timestamp for every value written by this call.
        if walltime is None:
            walltime = time.time()
        base_logdir = self._get_file_writer().get_logdir()
        for tag, scalar_value in tag_scalar_dict.items():
            writer_key = (base_logdir + "/" + main_tag.replace("/", "_") +
                          "_" + tag)
            # Create the per-tag writer on first use, then reuse it.
            if writer_key not in self.all_writers:
                self.all_writers[writer_key] = FileWriter(
                    writer_key, self.max_queue, self.flush_secs,
                    self.filename_suffix)
            tag_writer = self.all_writers[writer_key]
            if self._check_caffe2_blob(scalar_value):
                scalar_value = workspace.FetchBlob(scalar_value)
            entry = scalar(main_tag, scalar_value)
            tag_writer.add_summary(entry, global_step, walltime)
    def train_Dnet(self, idx, count):
        """Run one optimisation step for conditional discriminator ``idx``.

        The discriminator is evaluated on real, "wrong" and (detached) fake
        images, each together with a detached ``self.mu``. The first output
        is scored with real images as positives and wrong/fake images as
        negatives. When the discriminator returns a second output and
        ``cfg.TRAIN.COEFF.UNCOND_LOSS`` is positive, a weighted extra term is
        added per image kind and the three totals are summed. Otherwise the
        wrong and fake terms are averaged.

        Args:
            idx: Index into the per-scale lists on ``self``.
            count: Global iteration counter; a summary is written when it is
                a multiple of 100.

        Returns:
            The discriminator loss that was back-propagated.
        """
        n_samples = self.real_imgs[0].size(0)
        loss_fn = self.criterion
        condition = self.mu.detach()

        discriminator = self.netsD[idx]
        optimizer = self.optimizersD[idx]
        reals = self.real_imgs[idx]
        wrongs = self.wrong_imgs[idx]
        fakes = self.fake_imgs[idx]

        discriminator.zero_grad()

        ones = self.real_labels[:n_samples]
        zeros = self.fake_labels[:n_samples]

        # Forward passes; the fake batch is detached from the generator.
        out_real = discriminator(reals, condition)
        out_wrong = discriminator(wrongs, condition)
        out_fake = discriminator(fakes.detach(), condition)

        errD_real = loss_fn(out_real[0], ones)
        errD_wrong = loss_fn(out_wrong[0], zeros)
        errD_fake = loss_fn(out_fake[0], zeros)

        if len(out_real) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            # Note the targets of the second output: wrong images are scored
            # against the *real* labels here, unlike the first output.
            errD_real = errD_real + cfg.TRAIN.COEFF.UNCOND_LOSS * \
                loss_fn(out_real[1], ones)
            errD_wrong = errD_wrong + cfg.TRAIN.COEFF.UNCOND_LOSS * \
                loss_fn(out_wrong[1], ones)
            errD_fake = errD_fake + cfg.TRAIN.COEFF.UNCOND_LOSS * \
                loss_fn(out_fake[1], zeros)
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)

        errD.backward()
        optimizer.step()

        if count % 100 == 0:
            loss_summary = summary.scalar('D_loss%d' % idx, errD.data)
            self.summary_writer.add_summary(loss_summary, count)
        return errD
Exemplo n.º 8
0
    def train(self):
        """Run the full training loop for the generator and discriminators.

        Networks come from ``load_network`` and optimizers from
        ``define_optimizers``. For every batch the method generates fake
        images, updates each discriminator via ``train_Dnet``, updates the
        generator via ``train_Gnet``, blends the generator parameters into a
        running copy (``avg_param_G``) and stores the inception model's
        predictions for the newest fake images.

        Scalar summaries are written whenever ``count`` is a multiple of 100.
        Every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` iterations the model is saved,
        sample images are rendered with the averaged parameters, and, once
        more than 500 prediction batches have accumulated, inception/NLPP
        statistics are logged. The model is saved once more after the last
        epoch and the summary writer is closed.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        # Separate copy of the generator parameters, updated as a running
        # blend after every generator step below.
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        # Target vectors for the BCE criterion: all ones / all zeros.
        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        # `noise` is re-sampled in place every iteration; `fixed_noise` is
        # sampled once and reused for the snapshot images.
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        # Resume the iteration counter and derive the epoch to restart from.
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                # NOTE(review): kl_loss has .item() called on it below, so
                # train_Gnet must return a tensor in the first slot, not None.
                kl_loss, errG_total = self.train_Gnet(count)
                # Running blend of the generator weights (0.999 old + 0.001
                # new).
                # NOTE(review): add_(scalar, tensor) is the deprecated
                # positional overload; newer PyTorch spells this
                # add_(p.data, alpha=0.001) - confirm the target version.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    # Save images: temporarily swap in the averaged
                    # parameters, render with the fixed noise, then restore
                    # the live training parameters.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = \
                        self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                     count, self.image_dir, self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score once more than 500 prediction
                    # batches are buffered, then reset the buffer.
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        # NOTE(review): the meaning of the literal 10 is
                        # defined by the helpers - presumably a split count;
                        # confirm.
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            # Per-epoch report using the losses of the last processed batch.
            # NOTE(review): these names are unbound if the data loader
            # yielded no batches in this epoch.
            end_t = time.time()
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      '''  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     kl_loss.item(), end_t - start_t))

        # Final checkpoint after the last epoch.
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()
Exemplo n.º 9
0
 def test_scalar_new_style(self):
     """A new-style scalar summary for a float must pass ``compare_proto``."""
     proto = summary.scalar('test_scalar', 1.0, new_style=True)
     accepted = compare_proto(proto, self)
     self.assertTrue(accepted)
Exemplo n.º 10
0
    def train(self, data_loader, stage=1, max_objects=3):
        """Train the generator/discriminator pair on ``data_loader``.

        Each batch is unpacked into images, bounding boxes, labels and a
        text embedding. Bounding boxes are turned into transformation
        matrices, labels into an 81-way one-hot encoding, and the
        discriminator and generator are each updated once. Learning rates
        are halved every ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs. Scalars and
        sample images are written every 500 batches, and a second writer
        (``self.drive_summary_writer``) is fed every 100 batches. A
        checkpoint is saved every ``self.snapshot_interval`` epochs and once
        more at the end.

        Args:
            data_loader: Iterable yielding
                ``(real_img_cpu, bbox, label, txt_embedding)`` batches.
            stage: Chooses which loader builds the networks (stage I when 1,
                stage II otherwise).
                NOTE(review): all per-batch branching uses ``cfg.STAGE``,
                not this argument - confirm the two are always kept in sync.
            max_objects: Number of object slots per image used when
                reshaping the transformation matrices and the one-hot labels.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # NOTE(review): fixed_noise is created (and moved to the GPU) but is
        # not used anywhere else in this method.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only generator parameters that require gradients are optimised.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        # Resume: when a generator checkpoint path is configured, restore
        # both optimizer states and continue after the stored epoch.
        startpoint = -1
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            optimizerD.load_state_dict(state_dict["optimD"])
            optimizerG.load_state_dict(state_dict["optimG"])
            startpoint = state_dict["epoch"]
            print(startpoint)
            print('Load Optim and optimizers as : ', cfg.NET_G)

        # `count` is the step for self.summary_writer (advanced every 500
        # batches); `drive_count` is the step for self.drive_summary_writer
        # (advanced every 100 batches).
        count = 0
        drive_count = 0
        for epoch in range(startpoint + 1, self.max_epoch):
            print('epoch : ', epoch, ' drive_count : ', drive_count)
            epoch_start_time = time.time()
            print(epoch)
            start_t = time.time()
            # NOTE(review): start_t500 is assigned but never read.
            start_t500 = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            time_to_i = time.time()
            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    # In stage 2 `bbox` is indexed as a pair of tensors.
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                # Flatten the boxes to rows of 4 values, compute the
                # transformation matrices and reshape them to
                # (batch, max_objects, 2, 3).
                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(bbox)
                    transf_matrices = transf_matrices.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting (negative labels
                # are mapped to the extra class index 80)
                _labels[_labels < 0] = 80
                # NOTE(review): the one-hot buffer is sized by noise.shape[0]
                # (the configured batch size) while the matrices above use
                # real_imgs.shape[0]; presumably the loader drops the last
                # partial batch - confirm.
                if cfg.CUDA:
                    label_one_hot = torch.cuda.FloatTensor(
                        noise.shape[0], max_objects, 81).fill_(0)
                else:
                    label_one_hot = torch.FloatTensor(noise.shape[0],
                                                      max_objects, 81).fill_(0)
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

                #######################################################
                # (2) Generate fake images
                ######################################################

                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                if cfg.CUDA:
                    _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)
                else:
                    # NOTE(review): leftover debug print. The CPU path also
                    # always passes the stage-1 argument list, regardless of
                    # cfg.STAGE.
                    print('Hiiiiiiiiiiii')
                    _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise,
                                                       transf_matrices_inv,
                                                       label_one_hot)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()

                if cfg.STAGE == 1:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices, transf_matrices_inv,
                                                   mu, self.gpus)
                elif cfg.STAGE == 2:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices_s2, transf_matrices_inv_s2,
                                                   mu, self.gpus)
                # NOTE(review): retain_graph=True presumably keeps the
                # generator graph alive for the generator update below -
                # confirm it is still required.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                if cfg.STAGE == 1:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices,
                                                  transf_matrices_inv, mu,
                                                  self.gpus)
                elif cfg.STAGE == 2:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices_s2,
                                                  transf_matrices_inv_s2, mu,
                                                  self.gpus)
                # Total generator objective: adversarial loss plus the
                # weighted KL term.
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                end_t = time.time()
                # Every 500 batches: build the scalar summaries, print
                # progress, write to the main writer and save sample images.
                if i % 500 == 0:
                    count += 1
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    print('epoch     :  ', epoch)
                    print('count     :  ', count)
                    print('  i       :  ', i)
                    print('Time to i : ', time.time() - time_to_i)
                    time_to_i = time.time()
                    print('D_loss : ', errD.item())
                    print('D_loss_real : ', errD_real)
                    print('D_loss_wrong : ', errD_wrong)
                    print('D_loss_fake : ', errD_fake)
                    print('G_loss : ', errG.item())
                    print('KL_loss : ', kl_loss.item())
                    print('generator_lr : ', generator_lr)
                    print('discriminator_lr : ', discriminator_lr)
                    print('lr_decay_step : ', lr_decay_step)

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # Save sample images for the current batch; they are
                    # filed under the epoch number.
                    with torch.no_grad():
                        if cfg.STAGE == 1:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, label_one_hot)
                        elif cfg.STAGE == 2:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, transf_matrices_s2,
                                      transf_matrices_inv_s2, label_one_hot)

                        if cfg.CUDA:
                            lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                                netG, inputs, self.gpus)
                        else:
                            lr_fake, fake, _, _, _ = netG(
                                txt_embedding, noise, transf_matrices_inv,
                                label_one_hot)

                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
                # NOTE(review): the summary_* objects are only rebuilt in the
                # `i % 500` branch above, so on batches 100, 200, 300, 400
                # this re-logs the values from the last multiple of 500
                # under a new step. Confirm whether fresh values were
                # intended.
                if i % 100 == 0:
                    drive_count += 1
                    self.drive_summary_writer.add_summary(
                        summary_D, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_r, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_w, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_D_f, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_G, drive_count)
                    self.drive_summary_writer.add_summary(
                        summary_KL, drive_count)

            # End of epoch: render images from the last batch of the epoch.
            # NOTE(review): unlike the in-loop version, this always uses
            # data_parallel, even when cfg.CUDA is false. The names used
            # below are also unbound if the loader yielded no batches.
            with torch.no_grad():
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)

                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            # Per-epoch report using the losses of the last processed batch.
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)

            print("keyTime |||||||||||||||||||||||||||||||")
            print("epoch_time : ", time.time() - epoch_start_time)
            print("KeyTime |||||||||||||||||||||||||||||||")

        # Final checkpoint.
        # NOTE(review): `epoch` is unbound here if the epoch range was empty
        # (e.g. the resumed epoch already reached max_epoch).
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()
Exemplo n.º 11
0
 def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
     """Write one scalar value to the summary.

     A ``scalar_value`` that ``self._check_caffe2_blob`` accepts is first
     replaced by the result of ``workspace.FetchBlob``. The value is then
     wrapped with ``scalar`` and passed to the file writer together with
     ``global_step`` and ``walltime``.
     """
     torch._C._log_api_usage_once("tensorboard.logging.add_scalar")
     is_blob = self._check_caffe2_blob(scalar_value)
     value = workspace.FetchBlob(scalar_value) if is_blob else scalar_value
     file_writer = self._get_file_writer()
     scalar_summary = scalar(tag, value)
     file_writer.add_summary(scalar_summary, global_step, walltime)