Esempio n. 1
0
    def sample(self,
               datapath,
               num_samples=25,
               stage=1,
               draw_bbox=True,
               max_objects=3):
        """Generate image grids for randomly chosen validation captions.

        For each of ``num_samples`` randomly drawn validation captions, nine
        images are generated from nine fresh noise vectors (same caption,
        same layout) and written as one PNG grid named after the caption.

        Args:
            datapath: Directory containing ``val_captions.t7`` and whatever
                ``load_validation_data`` reads.
            num_samples: Number of captions to sample (with replacement).
            stage: Selects the generator to load (1 -> stage I, otherwise
                stage II) and the output size (64 for stage 1, else 256).
            draw_bbox: If true, draw the outline of every object bounding
                box on top of the generated images.
            max_objects: Number of object slots per image in the validation
                labels / bounding boxes.

        Raises:
            ValueError: If ``cfg.STAGE`` is neither 1 nor 2.

        NOTE(review): ``stage`` selects the network and image size while
        ``cfg.STAGE`` selects the generator inputs; the two are presumably
        expected to agree -- confirm against the caller.
        """
        import torchvision.utils as vutils

        if cfg.STAGE not in (1, 2):
            raise ValueError(
                "cfg.STAGE must be 1 or 2, got {!r}".format(cfg.STAGE))

        num_variants = 9  # images generated per caption
        num_classes = 81  # width of the one-hot label encoding

        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(os.path.join(datapath, "val_captions.t7"))
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        label, bbox = load_validation_data(datapath)

        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)

        # path to save generated samples; only strip the extension when it
        # is actually present (find() returns -1 otherwise)
        net_g_path = cfg.NET_G
        ext_pos = net_g_path.find('.pth')
        prefix = net_g_path[:ext_pos] if ext_pos >= 0 else net_g_path
        save_dir = prefix + "_visualize_bbox"
        print("saving to:", save_dir)
        mkdir_p(save_dir)

        if cfg.CUDA:
            bbox = bbox.cuda()
            label = label.cuda()

        # Un-flattened copy of the boxes, used only for drawing.
        bbox_ = bbox.clone()

        # The validation set provides a single set of boxes, so both the
        # stage-I and the stage-II matrices are derived from the same tensor.
        flat_bbox = bbox.view(-1, 4)
        transf_matrices_inv = compute_transformation_matrix_inverse(
            flat_bbox).view(num_embeddings, max_objects, 2, 3)
        if cfg.STAGE == 2:
            transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                flat_bbox).view(num_embeddings, max_objects, 2, 3)
            transf_matrices_s2 = compute_transformation_matrix(
                flat_bbox).view(num_embeddings, max_objects, 2, 3)

        # produce one-hot encodings of the labels
        # (assumes label has shape (N, max_objects, 1) -- TODO confirm)
        _labels = label.long()
        # negative (padding) labels are mapped to the last class index so
        # that scatter_ gets valid indices
        _labels[_labels < 0] = num_classes - 1
        # allocate on the labels' device: scatter_ needs index and target on
        # the same device
        label_one_hot = torch.zeros(num_embeddings,
                                    max_objects,
                                    num_classes,
                                    dtype=torch.float,
                                    device=_labels.device)
        label_one_hot.scatter_(2, _labels, 1)

        noise = Variable(torch.FloatTensor(num_variants, cfg.Z_DIM))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64 if stage == 1 else 256
        last_px = imsize - 1

        saved = 0
        for _ in range(num_samples):
            index = int(np.random.randint(0, num_embeddings))

            # Repeat the caption embedding, layout and labels once per
            # noise vector.
            embeddings_batch = np.reshape(embeddings[index],
                                          (1, -1)).repeat(num_variants, 0)
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            transf_matrices_inv_batch = transf_matrices_inv[index].view(
                1, max_objects, 2, 3).repeat(num_variants, 1, 1, 1)
            label_one_hot_batch = label_one_hot[index].view(
                1, max_objects, num_classes).repeat(num_variants, 1, 1)

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          label_one_hot_batch)
            else:
                transf_matrices_s2_batch = transf_matrices_s2[index].view(
                    1, max_objects, 2, 3).repeat(num_variants, 1, 1, 1)
                transf_matrices_inv_s2_batch = transf_matrices_inv_s2[
                    index].view(1, max_objects, 2,
                                3).repeat(num_variants, 1, 1, 1)
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          transf_matrices_s2_batch,
                          transf_matrices_inv_s2_batch, label_one_hot_batch)
            with torch.no_grad():
                _, fake_imgs, _, _, _ = netG(*inputs)

            # CPU canvas holding exactly the generated images (one slot per
            # noise vector); copy_ also moves the data off the GPU.
            data_img = torch.FloatTensor(num_variants, 3, imsize,
                                         imsize).fill_(0)
            data_img.copy_(fake_imgs)

            if draw_bbox:
                for obj in range(max_objects):
                    x, y, w, h = [
                        int(imsize * v) for v in bbox_[index, obj]
                    ]
                    if x < 0:
                        # negative coordinates mark an empty object slot
                        break
                    # keep every drawn pixel inside the image
                    x = min(x, last_px)
                    y = min(max(y, 0), last_px)
                    w = max(0, min(w, last_px - x))
                    h = max(0, min(h, last_px - y))
                    data_img[:, :, y, x:x + w] = 1
                    data_img[:, :, y:y + h, x] = 1
                    data_img[:, :, y + h, x:x + w] = 1
                    data_img[:, :, y:y + h, x + w] = 1

            caption = captions_list[index].decode('utf-8')
            print('Caption: ', caption)
            # a '/' in the caption would otherwise be read as a directory
            file_name = caption.replace('/', '_')
            vutils.save_image(data_img,
                              '{}/{}.png'.format(save_dir, file_name),
                              normalize=True,
                              nrow=5)
            saved += 1

        print("Saved {} files to {}".format(saved, save_dir))
Esempio n. 2
0
    def train(self, data_loader, stage=1, max_objects=3):
        """Run the adversarial training loop for the stage I / stage II GAN.

        Every batch performs one discriminator update followed by one
        generator update (adversarial loss plus weighted KL loss). Both
        learning rates are halved every ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs,
        scalar summaries and sample images are written every 500 batches and
        at the end of each epoch, and the model is saved every
        ``self.snapshot_interval`` epochs and once more after training.

        Args:
            data_loader: Iterable yielding
                ``(real_imgs, bbox, label, txt_embedding)`` batches. For
                ``cfg.STAGE == 2`` ``bbox`` is indexed as a pair
                (``bbox[0]`` and ``bbox[1]``).
            stage: Selects which networks to load (1 -> stage I, otherwise
                stage II).
            max_objects: Number of object slots per image.

        Raises:
            ValueError: If ``cfg.STAGE`` is neither 1 nor 2, or if
                ``data_loader`` yields no batch during an epoch.

        NOTE(review): the noise and one-hot buffers are sized with
        ``self.batch_size``, so every batch is presumably expected to be
        full (``drop_last=True``) -- confirm against the data loader.
        """
        if cfg.STAGE not in (1, 2):
            raise ValueError(
                "cfg.STAGE must be 1 or 2, got {!r}".format(cfg.STAGE))

        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        def save_snapshot(gen_inputs, real_img_cpu, epoch):
            # Generate images for gen_inputs and write them, next to the
            # real batch, to the image directory.
            with torch.no_grad():
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, gen_inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Not used below; kept because drawing it advances the global RNG,
        # so removing it would change the random stream of a seeded run.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only optimise generator parameters that are not frozen.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerD = optim.Adam(netD.parameters(),
                                lr=discriminator_lr,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=generator_lr,
                                betas=(0.5, 0.999))

        count = 0
        epoch = None  # stays None when max_epoch is 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            inputs = None
            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    else:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                n_imgs = real_imgs.shape[0]

                # produce one-hot encodings of the labels
                _labels = label.long()
                # negative (padding) labels are mapped to index 80 so that
                # scatter_ gets valid indices
                _labels[_labels < 0] = 80
                # allocate on the labels' device: scatter_ needs index and
                # target on the same device
                label_one_hot = torch.zeros(noise.shape[0],
                                            max_objects,
                                            81,
                                            dtype=torch.float,
                                            device=_labels.device)
                label_one_hot.scatter_(2, _labels, 1)

                if cfg.STAGE == 1:
                    flat_bbox = bbox.view(-1, 4)
                    transf_matrices_inv = \
                        compute_transformation_matrix_inverse(flat_bbox).view(
                            n_imgs, max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(
                        flat_bbox).view(n_imgs, max_objects, 2, 3)
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                    # matrices handed to the loss functions
                    loss_transf = transf_matrices
                    loss_transf_inv = transf_matrices_inv
                else:
                    flat_bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = \
                        compute_transformation_matrix_inverse(flat_bbox).view(
                            n_imgs, max_objects, 2, 3)

                    flat_bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = \
                        compute_transformation_matrix_inverse(flat_bbox).view(
                            n_imgs, max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(
                        flat_bbox).view(n_imgs, max_objects, 2, 3)
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                    loss_transf = transf_matrices_s2
                    loss_transf_inv = transf_matrices_inv_s2

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, loss_transf,
                                               loss_transf_inv, mu, self.gpus)
                errD.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot, loss_transf,
                                              loss_transf_inv, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                ############################
                # (5) Log results
                ###########################
                count += 1
                if i % 500 == 0:
                    scalars = (
                        ('D_loss', errD.item()),
                        ('D_loss_real', errD_real),
                        ('D_loss_wrong', errD_wrong),
                        ('D_loss_fake', errD_fake),
                        ('G_loss', errG.item()),
                        ('KL_loss', kl_loss.item()),
                    )
                    for tag, value in scalars:
                        self.summary_writer.add_summary(
                            summary.scalar(tag, value), count)

                    # save sample images for the current batch
                    save_snapshot(inputs, real_img_cpu, epoch)

            if inputs is None:
                raise ValueError("data_loader yielded no batches")

            # save the image result for each epoch (last batch of the epoch)
            save_snapshot(inputs, real_img_cpu, epoch)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)

        # Final checkpoint; skipped when no epoch ran at all.
        if epoch is not None:
            save_model(netG, netD, optimizerG, optimizerD, epoch,
                       self.model_dir)
        self.summary_writer.close()
Esempio n. 3
0
    def sample(self, datapath, num_samples=25, draw_bbox=True, num_digits_per_img=3, change_bbox_size=False):
        """Generate sample grids from randomly chosen validation layouts.

        Each saved PNG has two rows of ten tiles: the real validation image
        followed by nine generated images (same layout and labels, different
        noise), and below them a strip with the digit labels rendered as
        text.

        Args:
            datapath: Dataset root; images are read from
                ``<datapath>/normal/imgs/`` and file names from
                ``<datapath>/normal/filenames.pickle``.
            num_samples: Number of grids to write (indices are drawn with
                replacement).
            draw_bbox: If true, draw the bounding-box outlines on the real
                and generated images.
            num_digits_per_img: Number of digits per generated image. The
                validation data provides three; fewer truncates the layout,
                more appends randomly sampled labels and boxes.
            change_bbox_size: If true, randomly shrink width and height
                (factor in [0.5, 1)) of one randomly chosen box slot.
        """
        from PIL import Image, ImageDraw
        try:
            import cPickle as pickle  # Python 2
        except ImportError:
            import pickle  # Python 3
        import torchvision
        import torchvision.utils as vutils

        num_variants = 9  # generated images per grid
        tiles_per_row = 10
        imsize = 64

        img_dir = os.path.join(datapath, "normal", "imgs/")
        netG, _ = self.load_network_stageI()
        netG.eval()

        label, bbox = load_validation_data(datapath)
        # derive the set size from the data instead of assuming 10000
        test_set_size = bbox.shape[0]

        if num_digits_per_img < 3:
            label = label[:, :num_digits_per_img, :]
            bbox = bbox[:, :num_digits_per_img, ...]
        elif num_digits_per_img > 3:
            num_extra = num_digits_per_img - 3

            def get_one_hot(targets, nb_classes):
                res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
                return res.reshape(list(targets.shape) + [nb_classes])

            # random labels for the additional digits
            labels_sample = np.random.randint(0, 10, size=(test_set_size, num_extra))
            labels_sample = get_one_hot(labels_sample, 10)
            labels_new = np.zeros((label.shape[0], num_digits_per_img, 10))
            labels_new[:, :3, :] = label
            labels_new[:, 3:, :] = labels_sample
            label = torch.from_numpy(labels_new)

            # random boxes for the additional digits
            # (width 10-19 px, height 16-19 px, as a fraction of 64)
            bboxes_x = np.random.random((test_set_size, num_extra, 1))
            bboxes_y = np.random.random((test_set_size, num_extra, 1))
            bboxes_w = np.random.randint(10, 20, size=(test_set_size, num_extra, 1)) / 64.0
            bboxes_h = np.random.randint(16, 20, size=(test_set_size, num_extra, 1)) / 64.0

            bbox_new = np.zeros([test_set_size, num_digits_per_img, 4])
            bbox_new[:, :3, :] = bbox
            bbox_new[:, 3:, :] = np.concatenate((bboxes_x, bboxes_y, bboxes_w, bboxes_h), axis=2)
            bbox = torch.from_numpy(bbox_new)

        if change_bbox_size:
            bbox_idx = np.random.randint(0, bbox.shape[1])
            scale_x = np.maximum(np.random.random(bbox.shape[0]), 0.5)
            scale_y = np.maximum(np.random.random(bbox.shape[0]), 0.5)

            # numpy produces float64; match bbox's dtype for the in-place op
            bbox[:, bbox_idx, 2] *= torch.from_numpy(scale_x).type_as(bbox)
            bbox[:, bbox_idx, 3] *= torch.from_numpy(scale_y).type_as(bbox)

        filepath = os.path.join(datapath, "normal", 'filenames.pickle')
        with open(filepath, 'rb') as f:
            # NOTE: pickle can execute arbitrary code while loading; only
            # use dataset files from a trusted source.
            filenames = pickle.load(f)

        # path to save generated samples; only strip the extension when it
        # is actually present (find() returns -1 otherwise)
        net_g_path = cfg.NET_G
        ext_pos = net_g_path.find('.pth')
        prefix = net_g_path[:ext_pos] if ext_pos >= 0 else net_g_path
        save_dir = prefix + "_samples_" + str(num_digits_per_img) + "_digits"
        if change_bbox_size:
            save_dir += "_change_bbox_size"
        print("Saving {} to {}:".format(num_samples, save_dir))
        mkdir_p(save_dir)

        # Defined on CPU and GPU alike.
        label_one_hot = label.float()
        if cfg.CUDA:
            bbox = bbox.cuda()
            label_one_hot = label_one_hot.cuda()

        #######################################
        bbox_ = bbox.clone()  # un-flattened copy, used only for drawing
        bbox = bbox.view(-1, 4)
        transf_matrices_inv = compute_transformation_matrix_inverse(bbox).float()
        transf_matrices_inv = transf_matrices_inv.view(test_set_size, num_digits_per_img, 2, 3)
        #######################################

        noise = Variable(torch.FloatTensor(num_variants, cfg.Z_DIM))
        if cfg.CUDA:
            noise = noise.cuda()

        saved = 0
        for count in range(num_samples):
            index = int(np.random.randint(0, test_set_size))
            key = filenames[index].split("/")[-1]
            # close the image file as soon as it has been converted
            with Image.open(img_dir + key) as img:
                val_image = torchvision.transforms.functional.to_tensor(img)
            val_image = val_image.view(1, 1, imsize, imsize)
            val_image = (val_image - 0.5) * 2  # rescale [0, 1] -> [-1, 1]

            transf_matrices_inv_batch = transf_matrices_inv[index].view(
                1, num_digits_per_img, 2, 3).repeat(num_variants, 1, 1, 1)
            label_one_hot_batch = label_one_hot[index].view(
                1, num_digits_per_img, 10).repeat(num_variants, 1, 1)

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (noise, transf_matrices_inv_batch, label_one_hot_batch, num_digits_per_img)
            # sampling only: no autograd graph needed
            with torch.no_grad():
                _, fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)

            # tile 0: real image, tiles 1-9: generated, tiles 10-19: text
            data_img = torch.FloatTensor(2 * tiles_per_row, 1, imsize, imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:tiles_per_row] = fake_imgs

            if draw_bbox:
                for idx in range(num_digits_per_img):
                    x, y, w, h = [int(imsize * v) for v in bbox_[index, idx]]
                    w = min(w, imsize - 1)
                    h = min(h, imsize - 1)
                    # shift and shrink boxes that reach past the image edge
                    while x + w >= imsize:
                        x -= 1
                        w -= 1
                    while y + h >= imsize:
                        y -= 1
                        h -= 1
                    if x <= -1:
                        # negative coordinates mark an empty slot
                        break
                    data_img[:tiles_per_row, :, y, x:x + w] = 1
                    data_img[:tiles_per_row, :, y:y + h, x] = 1
                    data_img[:tiles_per_row, :, y + h, x:x + w] = 1
                    data_img[:tiles_per_row, :, y:y + h, x + w] = 1

            # write digit identities into image
            text_img = Image.new('L', (imsize * tiles_per_row, imsize), color='white')
            digits = np.argmax(label_one_hot_batch[0].cpu().numpy(), axis=1)
            digit_text = ", ".join([str(digits[k]) for k in range(num_digits_per_img)])
            ImageDraw.Draw(text_img).text((10, 10), digit_text)
            text_img = torchvision.transforms.functional.to_tensor(text_img)
            # cut the wide strip into one tile per grid column
            text_tiles = torch.chunk(text_img, tiles_per_row, 2)
            data_img[tiles_per_row:] = torch.cat(
                [tile.view(1, 1, imsize, imsize) for tile in text_tiles], 0)
            vutils.save_image(data_img, '{}/vis_{}.png'.format(save_dir, count), normalize=True, nrow=tiles_per_row)
            saved += 1
        print("Saved {} files to {}".format(saved, save_dir))
Esempio n. 4
0
    def train(self, data_loader):
        """Run the adversarial training loop for the stage I GAN.

        Every batch performs one discriminator update followed by one
        generator update. Both learning rates are halved every
        ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs, scalar summaries and sample
        images are written every 500 batches and at the end of each epoch,
        and the model is saved every ``self.snapshot_interval`` epochs and
        once more after training.

        Args:
            data_loader: Iterable yielding ``(real_imgs, bbox, label)``
                batches; ``label`` is used directly as the one-hot
                conditioning after conversion to float.

        Raises:
            ValueError: If ``data_loader`` yields no batch during an epoch.
        """
        netG, netD = self.load_network_stageI()

        def save_snapshot(gen_inputs, real_img_cpu, epoch):
            # Generate images for gen_inputs and write them, next to the
            # real batch, to the image directory. The padded images are
            # passed straight on so the caller's batch is never padded
            # twice.
            with torch.no_grad():
                lr_fake, fake = nn.parallel.data_parallel(netG, gen_inputs, self.gpus)
                save_img_results(pad_imgs(real_img_cpu), pad_imgs(fake), epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))

        # Not used below; kept because drawing it advances the global RNG,
        # so removing it would change the random stream of a seeded run.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1), requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))

        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only optimise generator parameters that are not frozen.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerD = optim.Adam(netD.parameters(), lr=discriminator_lr, betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para, lr=generator_lr, betas=(0.5, 0.999))

        print("Starting training...")
        count = 0
        epoch = None  # stays None when max_epoch is 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            inputs = None
            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label = data

                real_imgs = Variable(real_img_cpu)
                # Defined on CPU and GPU alike.
                label_one_hot = label.float()
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    bbox = bbox.cuda()
                    label_one_hot = label_one_hot.cuda()

                n_imgs = real_imgs.shape[0]
                bbox = bbox.view(-1, 4)
                transf_matrices_inv = compute_transformation_matrix_inverse(bbox).float()
                transf_matrices_inv = transf_matrices_inv.view(n_imgs, self.max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox).float()
                transf_matrices = transf_matrices.view(n_imgs, self.max_objects, 2, 3)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, transf_matrices_inv, label_one_hot)
                _, fake_imgs = netG(*inputs)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, transf_matrices, transf_matrices_inv, self.gpus)
                errD.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, label_one_hot,
                                              transf_matrices, transf_matrices_inv, self.gpus)
                errG_total = errG
                errG_total.backward()
                optimizerG.step()

                ############################
                # (5) Log results
                ###########################
                count += 1
                if i % 500 == 0:
                    scalars = (
                        ('D_loss', errD.item()),
                        ('D_loss_real', errD_real),
                        ('D_loss_wrong', errD_wrong),
                        ('D_loss_fake', errD_fake),
                        ('G_loss', errG.item()),
                    )
                    for tag, value in scalars:
                        self.summary_writer.add_summary(summary.scalar(tag, value), count)

                    # save sample images for the current batch
                    save_snapshot(inputs, real_img_cpu, epoch)

            if inputs is None:
                raise ValueError("data_loader yielded no batches")

            # save the image result for each epoch (last batch of the epoch)
            save_snapshot(inputs, real_img_cpu, epoch)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)

        # Final checkpoint; skipped when no epoch ran at all.
        if epoch is not None:
            save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        self.summary_writer.close()
Esempio n. 5
0
    def train(self, data_loader, stage=1, max_objects=3):
        """Run the adversarial training loop for the generator and discriminator.

        Each batch updates the discriminator first and then the generator
        (adversarial loss plus ``cfg.TRAIN.COEFF.KL`` times the KL loss).
        Both learning rates are halved every ``cfg.TRAIN.LR_DECAY_EPOCH``
        epochs. Loss summaries go to ``self.drive_summary_writer`` every 100
        batches and to ``self.summary_writer`` every 500 batches; sample
        images are saved every 500 batches and at the end of every epoch.
        Checkpoints are written every ``self.snapshot_interval`` epochs and
        once more after the last epoch.

        Args:
            data_loader: iterable yielding
                ``(real_img_cpu, bbox, label, txt_embedding)`` batches. When
                ``cfg.STAGE == 2``, ``bbox`` is indexed as ``bbox[0]`` and
                ``bbox[1]``.
            stage: 1 loads the networks with ``load_network_stageI``; any
                other value uses ``load_network_stageII``.
            max_objects: number of bounding-box slots per image; used to
                reshape the transformation matrices and to size the one-hot
                label tensor.

        NOTE(review): network loading depends on ``stage`` while every branch
        inside the loop depends on ``cfg.STAGE`` -- callers presumably keep
        the two in agreement; confirm.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        def run_generator(gen_inputs):
            """Forward netG on ``self.gpus`` when CUDA is on, else directly."""
            if cfg.CUDA:
                return nn.parallel.data_parallel(netG, gen_inputs, self.gpus)
            # Unpack the full tuple so the stage-2 inputs are honoured on CPU
            # as well (data_parallel also calls the module with *inputs).
            return netG(*gen_inputs)

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Not consumed below; kept because its normal_() draw advances the
        # RNG, so dropping it would change seeded runs.
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only optimise generator parameters that are not frozen.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerD = optim.Adam(netD.parameters(),
                                lr=discriminator_lr,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=generator_lr,
                                betas=(0.5, 0.999))

        # Resume optimizer state and epoch counter from the checkpoint, if any.
        startpoint = -1
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            optimizerD.load_state_dict(state_dict["optimD"])
            optimizerG.load_state_dict(state_dict["optimG"])
            startpoint = state_dict["epoch"]
            print(startpoint)
            print('Load Optim and optimizers as : ', cfg.NET_G)
            # NOTE(review): generator_lr / discriminator_lr still hold the cfg
            # values here, while the restored optimizers may carry a different
            # lr; the next decay step overwrites the restored lr with a value
            # derived from cfg -- confirm this is intended.

        count = 0          # step index for self.summary_writer
        drive_count = 0    # step index for self.drive_summary_writer
        last_epoch = None  # stays None when the epoch range is empty
        for epoch in range(startpoint + 1, self.max_epoch):
            last_epoch = epoch
            print('epoch : ', epoch, ' drive_count : ', drive_count)
            epoch_start_time = time.time()
            print(epoch)
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            time_to_i = time.time()
            i = -1  # remains -1 if the loader yields no batch this epoch
            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                n_imgs = real_imgs.shape[0]
                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        bbox).view(n_imgs, max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(
                        bbox).view(n_imgs, max_objects, 2, 3)
                    # Matrices handed to the D/G loss functions.
                    loss_tm, loss_tm_inv = transf_matrices, transf_matrices_inv
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot_placeholder := None)[:3]
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        _bbox).view(n_imgs, max_objects, 2, 3)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                        _bbox).view(n_imgs, max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(
                        _bbox).view(n_imgs, max_objects, 2, 3)
                    # Matrices handed to the D/G loss functions.
                    loss_tm, loss_tm_inv = (transf_matrices_s2,
                                            transf_matrices_inv_s2)

                # produce one-hot encodings of the labels
                _labels = label.long()
                # remove -1 to enable one-hot converting: negative labels are
                # mapped to the extra index 80 of the 81-wide encoding
                _labels[_labels < 0] = 80
                label_one_hot = torch.zeros(noise.shape[0], max_objects, 81,
                                            dtype=torch.float,
                                            device=_labels.device)
                label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

                ######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                _, fake_imgs, mu, logvar, _ = run_generator(inputs)

                ############################
                # (3) Update D network
                ############################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, loss_tm,
                                               loss_tm_inv, mu, self.gpus)
                # The generator loss below is backpropagated separately, so
                # the graph is retained here as in the original flow.
                errD.backward(retain_graph=True)
                optimizerD.step()

                ############################
                # (4) Update G network
                ############################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot, loss_tm,
                                              loss_tm_inv, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                # Build the summaries from the *current* batch whenever either
                # writer is due (every multiple of 500 is a multiple of 100),
                # so the drive writer never re-logs stale values.
                if i % 100 == 0:
                    loss_summaries = [
                        summary.scalar('D_loss', errD.item()),
                        summary.scalar('D_loss_real', errD_real),
                        summary.scalar('D_loss_wrong', errD_wrong),
                        summary.scalar('D_loss_fake', errD_fake),
                        summary.scalar('G_loss', errG.item()),
                        summary.scalar('KL_loss', kl_loss.item()),
                    ]

                if i % 500 == 0:
                    count += 1

                    print('epoch     :  ', epoch)
                    print('count     :  ', count)
                    print('  i       :  ', i)
                    print('Time to i : ', time.time() - time_to_i)
                    time_to_i = time.time()
                    print('D_loss : ', errD.item())
                    print('D_loss_real : ', errD_real)
                    print('D_loss_wrong : ', errD_wrong)
                    print('D_loss_fake : ', errD_fake)
                    print('G_loss : ', errG.item())
                    print('KL_loss : ', kl_loss.item())
                    print('generator_lr : ', generator_lr)
                    print('discriminator_lr : ', discriminator_lr)
                    print('lr_decay_step : ', lr_decay_step)

                    for loss_summary in loss_summaries:
                        self.summary_writer.add_summary(loss_summary, count)

                    # save sample images produced by the updated generator
                    with torch.no_grad():
                        lr_fake, fake, _, _, _ = run_generator(inputs)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)

                if i % 100 == 0:
                    drive_count += 1
                    for loss_summary in loss_summaries:
                        self.drive_summary_writer.add_summary(
                            loss_summary, drive_count)

            if i < 0:
                # Nothing was trained this epoch, so there is no batch to
                # sample from or report on.
                print('No batches received from data_loader in epoch', epoch)
                continue

            # End-of-epoch samples, generated from the last batch's inputs.
            with torch.no_grad():
                lr_fake, fake, _, _, _ = run_generator(inputs)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)

            print("keyTime |||||||||||||||||||||||||||||||")
            print("epoch_time : ", time.time() - epoch_start_time)
            print("KeyTime |||||||||||||||||||||||||||||||")

        # Final checkpoint; skipped when no epoch ran (e.g. resuming from a
        # checkpoint that already reached max_epoch).
        if last_epoch is not None:
            save_model(netG, netD, optimizerG, optimizerD, last_epoch,
                       self.model_dir)
        self.summary_writer.close()