Пример #1
0
class SRGAN():
    """Train a tag-conditioned GAN on ``AnimeFaceDataset`` avatars.

    The generator ``G`` maps ``torch.cat([noise, tags])`` (both produced by
    ``fake_generator``) to images. The discriminator ``D`` returns a real/fake
    logit and tag logits for each image. The discriminator loss additionally
    carries a DRAGAN-style gradient penalty. Hyper-parameters and helpers
    (``batch_size``, ``device``, ``lambda_adv``, ``lambda_gp``,
    ``adjust_learning_rate`` ...) are module-level names defined elsewhere.
    """

    def __init__(self):
        """Create the data loader, networks, optimizers and loss criteria.

        When ``load_checkpoint`` finds a checkpoint in ``model_dump_path``,
        network weights, optimizer state and the epoch counter are restored
        from it. Otherwise the networks are freshly initialised with
        ``initital_network_weights`` and training starts at epoch 0.
        """
        logger.info('Set Data Loader')
        self.dataset = AnimeFaceDataset(
            avatar_tag_dat_path=avatar_tag_dat_path,
            transform=transforms.Compose([ToTensor()]))
        # drop_last=True guarantees every batch has exactly `batch_size`
        # samples, which the fixed-size `label` buffer in `train` relies on.
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)
        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        if checkpoint is None:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.')
            logger.info('Set Generator and Discriminator')
            self.G = Generator().to(device)
            self.D = Discriminator().to(device)
            logger.info('Initialize Weights')
            self.G.apply(initital_network_weights).to(device)
            self.D.apply(initital_network_weights).to(device)
            logger.info('Set Optimizers')
            self._build_optimizers()
            self.epoch = 0
        else:
            logger.info('Load Generator and Discriminator')
            self.G = Generator().to(device)
            self.D = Discriminator().to(device)
            logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            logger.info('Load Optimizers')
            self._build_optimizers()
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            # 'epoch' is the epoch that `train` starts (or restarts) at.
            self.epoch = checkpoint['epoch']
        logger.info('Set Criterion')
        self.label_criterion = nn.BCEWithLogitsLoss().to(device)
        self.tag_criterion = nn.MultiLabelSoftMarginLoss().to(device)

    def _build_optimizers(self):
        """Create one Adam optimizer each for ``G`` and ``D``."""
        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))

    def load_checkpoint(self, model_dir):
        """Load the newest checkpoint listed by ``read_newest_model``.

        Args:
            model_dir: directory that holds the checkpoint files.

        Returns:
            ``(checkpoint_dict, path)``, or ``(None, None)`` when
            ``read_newest_model`` reports no files. "Newest" means the last
            name in lexicographic order.
        """
        models_path = read_newest_model(model_dir)
        if not models_path:
            return None, None
        models_path.sort()
        new_model_path = os.path.join(model_dir, models_path[-1])
        # Map tensors to the CPU when CUDA is unavailable so that checkpoints
        # written on a GPU machine can still be loaded.
        checkpoint = torch.load(
            new_model_path,
            map_location=None if torch.cuda.is_available() else 'cpu')
        return checkpoint, new_model_path

    def train(self):
        """Run one D update and one G update per batch until ``self.epoch`` exceeds ``max_epoch``.

        A checkpoint tagged with ``self.epoch`` is written at the start of each
        epoch. It therefore holds the state *before* that epoch is trained,
        and ``__init__`` resumes at its ``'epoch'`` value. One more checkpoint
        is written after the loop so the last epoch's training is kept.
        """
        iteration = -1
        # Target buffer for the adversarial loss. It is refilled with 1s or 0s
        # before every use, so its initial contents do not matter. Tensor
        # sizes must be ints.
        label = Variable(torch.FloatTensor(batch_size, 1)).to(device)
        logger.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            self._save_checkpoint()

            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)
            for i, (avatar_tag, avatar_img) in enumerate(self.data_loader):
                iteration += 1
                log_now = verbose and iteration % verbose_T == 0
                if log_now:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration
                avatar_img = Variable(avatar_img).to(device)
                avatar_tag = Variable(torch.FloatTensor(avatar_tag)).to(device)

                # NOTE(review): an earlier comment claimed "D : G = 2 : 1", but
                # each batch performs exactly one D step and one G step.
                real_loss, fake_loss, penalty = self._train_discriminator(
                    avatar_img, avatar_tag, label)
                loss_g = self._train_generator(label)

                if log_now:
                    msg['discriminator real loss'] = float(real_loss)
                    msg['discriminator fake loss'] = float(fake_loss)
                    msg['discriminator gradient penalty'] = float(penalty)
                    msg['generator loss'] = float(loss_g)
                    logger.info('------------------------------------------')
                    for key, value in msg.items():
                        logger.info('{} : {}'.format(key, value))
                # Preview images follow the verbose_T cadence even when
                # `verbose` is off.
                if iteration % verbose_T == 0:
                    self._save_preview(avatar_img, iteration)
            self.epoch += 1
        # The loop only checkpoints *before* each epoch; persist the weights
        # produced by the final epoch as well.
        self._save_checkpoint()

    def _train_discriminator(self, avatar_img, avatar_tag, label):
        """Run one D update on a real batch, a detached fake batch and the penalty.

        Returns:
            Detached ``(real loss, fake loss, gradient penalty)`` tensors.
        """
        self.D.zero_grad()

        # Real images: target label 1 plus the dataset's tags.
        label_p, tag_p = self.D(avatar_img)
        label.data.fill_(1.0)
        real_label_loss = self.label_criterion(label_p, label)
        real_tag_loss = self.tag_criterion(tag_p, avatar_tag)
        real_loss_sum = (real_label_loss * lambda_adv / 2.0 +
                         real_tag_loss * lambda_adv / 2.0)
        real_loss_sum.backward()

        # Fake images: detached so this step does not push gradients into G.
        g_noise, fake_tag = fake_generator()
        fake_feat = torch.cat([g_noise, fake_tag], dim=1)
        fake_img = self.G(fake_feat).detach()
        fake_label_p, fake_tag_p = self.D(fake_img)
        label.data.fill_(.0)
        fake_label_loss = self.label_criterion(fake_label_p, label)
        fake_tag_loss = self.tag_criterion(fake_tag_p, fake_tag)
        fake_loss_sum = (fake_label_loss * lambda_adv / 2.0 +
                         fake_tag_loss * lambda_adv / 2.0)
        fake_loss_sum.backward()

        gradient_penalty = self._gradient_penalty(avatar_img)
        gradient_penalty.backward()

        self.optimizer_D.step()
        return (real_loss_sum.detach(), fake_loss_sum.detach(),
                gradient_penalty.detach())

    def _gradient_penalty(self, avatar_img):
        """Compute ``lambda_gp * mean((||dD/dx_hat||_2 - 1)^2)``.

        ``x_hat`` interpolates, with a per-sample random ``alpha``, between
        the real batch and a copy perturbed by
        ``0.5 * std * U[0, 1)`` noise. This follows
        https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py
        """
        real = avatar_img.data
        # One alpha per sample, broadcast over the remaining dimensions.
        alpha_size = [real.size(0)] + [1] * (real.dim() - 1)
        alpha = torch.rand(alpha_size).to(device)
        noise = torch.rand(real.size()).to(device)
        perturbed = real + 0.5 * real.std() * noise
        x_hat = Variable(alpha * real + (1 - alpha) * perturbed,
                         requires_grad=True).to(device)
        pred_hat, _ = self.D(x_hat)
        # create_graph=True so the penalty itself can be back-propagated.
        gradients = grad(outputs=pred_hat,
                         inputs=x_hat,
                         grad_outputs=torch.ones(pred_hat.size()).to(device),
                         create_graph=True,
                         retain_graph=True)[0].view(x_hat.size(0), -1)
        return lambda_gp * ((gradients.norm(2, dim=1) - 1)**2).mean()

    def _train_generator(self, label):
        """Run one G update so that D labels fresh fakes as real with their tags.

        Returns:
            The detached generator loss.
        """
        self.G.zero_grad()
        g_noise, fake_tag = fake_generator()
        fake_feat = torch.cat([g_noise, fake_tag], dim=1)
        fake_img = self.G(fake_feat)
        fake_label_p, fake_tag_p = self.D(fake_img)
        label.data.fill_(1.0)

        label_loss_g = self.label_criterion(fake_label_p, label)
        tag_loss_g = self.tag_criterion(fake_tag_p, fake_tag)
        loss_g = label_loss_g * lambda_adv / 2.0 + tag_loss_g * lambda_adv / 2.0
        loss_g.backward()
        self.optimizer_G.step()
        return loss_g.detach()

    def _save_checkpoint(self):
        """Save G/D and optimizer state to ``checkpoint_<epoch:04d>.tar``."""
        path = '{}/checkpoint_{}.tar'.format(model_dump_path,
                                             str(self.epoch).zfill(4))
        torch.save(
            {
                'epoch': self.epoch,
                'D': self.D.state_dict(),
                'G': self.G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
            }, path)
        logger.info('Checkpoint saved in: {}'.format(path))

    def _save_preview(self, avatar_img, iteration):
        """Write the current real batch and a freshly sampled fake batch to ``tmp_path``."""
        suffix = str(iteration).zfill(8)
        height, width = avatar_img.size(2), avatar_img.size(3)
        vutils.save_image(
            avatar_img.data.view(batch_size, 3, height, width),
            os.path.join(tmp_path, 'real_image_{}.png'.format(suffix)))
        # Sampling for visualisation only, so no autograd graph is needed.
        with torch.no_grad():
            g_noise, fake_tag = fake_generator()
            fake_feat = torch.cat([g_noise, fake_tag], dim=1)
            fake_img = self.G(fake_feat)
        fake_path = os.path.join(tmp_path, 'fake_image_{}.png'.format(suffix))
        vutils.save_image(fake_img.data.view(batch_size, 3, height, width),
                          fake_path)
        logger.info('Saved intermediate file in {}'.format(fake_path))
Пример #2
0
    hr_test, lr_test = next(iter(test_dataloader))
    vutils.save_image(hr_test, f"{results_dir}/hr.png", normalize=True)
    vutils.save_image(lr_test, f"{results_dir}/lr.png", normalize=True)
    with open(f"{models_dir}/args.yml", "w") as f:
        yaml.dump(args, f)
    for epoch in range(args.epochs):
        training_bar = tqdm(train_dataloader)
        stats: defaultdict = defaultdict(float)
        for i, (hr_sample, lr_sample) in enumerate(training_bar, 1):
            gen_net = gen_net.train()

            if epoch >= args.start_discriminator:
                hr_sample = hr_sample.to(device)
                lr_sample = lr_sample.to(device)
                sr_sample = gen_net(lr_sample)
                dis_net.zero_grad()
                sr_dis = dis_net(sr_sample).mean()
                hr_dis = dis_net(hr_sample).mean()

                gradient_penalty = (
                    calc_grad_pen(dis_net, hr_sample, sr_sample, device) *
                    args.wgan_gp_lambda)

                d_loss = sr_dis + 1 - hr_dis
                d_loss_total = d_loss + gradient_penalty
                d_loss_total.backward()
                opt_dis.step()
            else:
                d_loss = 0
                gradient_penalty = 0
            gen_net.zero_grad()
Пример #3
0
            #         realLabelGT, fakeLabelGT = fakeLabelGT, realLabelGT

            # Prepare a (batch_size * 1 * w * h) tensor indicates class for discriminator of CGAN
            repeatRealClass = None
            if args.GAN_TYPE == "CGAN":
                repeatRealClass = F.one_hot(realClass, num_classes=args.NUM_CLASSES)
                repeatRealClass = repeatRealClass.unsqueeze(-1).unsqueeze(-1)
                repeatRealClass = repeatRealClass.repeat(1, 1, *realImage.size()[-2:])
                repeatRealClass = repeatRealClass.type(torch.FloatTensor)

                repeatRealClass = repeatRealClass.to(device)

            ### Update Discriminator ### 

            # Train with real
            discriminator.zero_grad()
            realImage = realImage.to(device)            
            
            pred = discriminator(realImage, repeatRealClass)
            if args.GAN_TYPE == "ACGAN":
                predLabel, predClass = pred
            else:
                predLabel = pred
            realClass = realClass.to(device)
            realLabel = realLabel.to(device)
            
            # Compute loss of D with real inputs
            lossRealLabelD = criterionLabel(predLabel, realLabel)
            lossRealClassD = criterionClass(predClass, realClass) if args.GAN_TYPE == "ACGAN" else 0
            lossRealD = lossRealLabelD + lossRealClassD
Пример #4
0
class SRGAN():
    """Train a GAN on ``FoodDataset`` images with tag-conditioned sampling.

    ``G`` maps ``torch.cat([noise, tags])`` from ``utils.fake_generator`` to
    images. ``D`` returns one real/fake score per image, which is scored with
    ``F.binary_cross_entropy``, so ``D`` is expected to output probabilities.
    An auxiliary AlexNet ``a_D`` is also built, but its outputs do not feed
    any loss that is back-propagated. Hyper-parameters are module-level names
    defined elsewhere.
    """

    def __init__(self):
        """Create the data loader, networks, optimizers and the auxiliary AlexNet.

        When ``load_checkpoint`` finds a checkpoint in ``model_dump_path``,
        network weights, optimizer state and the epoch counter are restored
        from it. Otherwise the networks are freshly initialised with
        ``initital_network_weights`` and training starts at epoch 0.
        """
        logger.info('Set Data Loader')
        self.dataset = FoodDataset(transform=transforms.Compose([ToTensor()]))
        # drop_last=True guarantees every batch has exactly `batch_size`
        # samples, which the fixed-size `label` buffer in `train` relies on.
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)
        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        if checkpoint is None:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.')
            logger.info('Set Generator and Discriminator')
            self.G = Generator(tag=tag_size).to(device)
            self.D = Discriminator(tag=tag_size).to(device)
            logger.info('Initialize Weights')
            self.G.apply(initital_network_weights).to(device)
            self.D.apply(initital_network_weights).to(device)
            logger.info('Set Optimizers')
            self._build_optimizers()
            self.epoch = 0
        else:
            logger.info('Load Generator and Discriminator')
            self.G = Generator(tag=tag_size).to(device)
            self.D = Discriminator(tag=tag_size).to(device)
            logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            logger.info('Load Optimizers')
            self._build_optimizers()
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            # 'epoch' is the epoch that `train` starts (or restarts) at.
            self.epoch = checkpoint['epoch']
        logger.info('Set Criterion')
        # NOTE(review): a_D and optimizer_a_D are created here, but
        # optimizer_a_D is never stepped anywhere in this class.
        self.a_D = alexnet.alexnet(num_classes=tag_size).to(device)
        self.optimizer_a_D = torch.optim.Adam(self.a_D.parameters(),
                                              lr=learning_rate,
                                              betas=(beta_1, .999))

    def _build_optimizers(self):
        """Create one Adam optimizer each for ``G`` and ``D``."""
        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))

    def load_checkpoint(self, model_dir):
        """Load the newest checkpoint listed by ``utils.read_newest_model``.

        Args:
            model_dir: directory that holds the checkpoint files.

        Returns:
            ``(checkpoint_dict, path)``, or ``(None, None)`` when no files
            are reported. "Newest" means the last name in lexicographic order.
        """
        models_path = utils.read_newest_model(model_dir)
        if not models_path:
            return None, None
        models_path.sort()
        new_model_path = os.path.join(model_dir, models_path[-1])
        # Map tensors to the CPU when CUDA is unavailable so that checkpoints
        # written on a GPU machine can still be loaded.
        checkpoint = torch.load(
            new_model_path,
            map_location=None if torch.cuda.is_available() else 'cpu')
        return checkpoint, new_model_path

    def train(self):
        """Run one D update and one G update per batch until ``self.epoch`` exceeds ``max_epoch``.

        Every 10000 iterations a mid-epoch checkpoint is written that records
        the in-progress epoch, so a resume restarts that epoch. At the end of
        each epoch a checkpoint is written that records the *next* epoch, so
        a resume does not re-train an epoch that has already completed.
        """
        iteration = -1
        # Target buffer for the adversarial loss, refilled with 1s or 0s
        # before every use.
        label = Variable(torch.FloatTensor(batch_size, 1)).to(device)
        logger.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)
            for i, (food_tag, food_img) in enumerate(self.data_loader):
                iteration += 1
                # Guard for the fixed-size `label` buffer (not expected to
                # trigger with drop_last=True).
                if food_img.shape[0] != batch_size:
                    logger.warning('Batch size not satisfied. Ignoring.')
                    continue
                log_now = verbose and iteration % verbose_T == 0
                if log_now:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration

                food_img = Variable(food_img).to(device)
                # NOTE(review): the assistant discriminator runs on the real
                # batch, but its output is unused and it is never trained.
                self.a_D.zero_grad()
                a_D_feat = self.a_D(food_img)

                real_loss, fake_loss, fake_label_p = self._train_discriminator(
                    food_img, label)
                if log_now:
                    logger.info('predicted fake label: {}'.format(fake_label_p))
                loss_g = self._train_generator(label)

                if log_now:
                    msg['discriminator real loss'] = float(real_loss)
                    msg['discriminator fake loss'] = float(fake_loss)
                    msg['generator loss'] = float(loss_g)
                    logger.info('------------------------------------------')
                    for key, value in msg.items():
                        logger.info('{} : {}'.format(key, value))

                if iteration % 10000 == 0:
                    # NOTE(review): these 8-digit names sort lexicographically
                    # against the 4-digit per-epoch names in load_checkpoint,
                    # so the file picked as "newest" may not be the latest.
                    self._save_checkpoint(str(iteration).zfill(8), self.epoch)

                if iteration % verbose_T == 0:
                    self._save_preview(food_img, iteration)

            # This epoch has finished: record the next epoch as the resume
            # point so that __init__ does not train it a second time.
            self._save_checkpoint(str(self.epoch).zfill(4), self.epoch + 1)
            self.epoch += 1

    def _train_discriminator(self, food_img, label):
        """Run one D update on a real batch and a detached fake batch.

        Returns:
            Detached ``(real loss, fake loss, D scores for the fake batch)``.
        """
        self.D.zero_grad()

        # Real images: target label 1.
        label_p = self.D(food_img)
        label.data.fill_(1.0)
        real_loss_sum = F.binary_cross_entropy(label_p, label)
        real_loss_sum.backward()

        # Fake images: detached so this step does not push gradients into G.
        g_noise, fake_tag = utils.fake_generator(batch_size, noise_size, device)
        fake_feat = torch.cat([g_noise, fake_tag], dim=1)
        fake_img = self.G(fake_feat).detach()
        fake_label_p = self.D(fake_img)
        label.data.fill_(.0)
        fake_loss_sum = F.binary_cross_entropy(fake_label_p, label)
        fake_loss_sum.backward()

        self.optimizer_D.step()
        return (real_loss_sum.detach(), fake_loss_sum.detach(),
                fake_label_p.detach())

    def _train_generator(self, label):
        """Run one G update so that D scores fresh fakes as real.

        Returns:
            The detached generator loss.
        """
        self.G.zero_grad()
        g_noise, fake_tag = utils.fake_generator(batch_size, noise_size, device)
        fake_feat = torch.cat([g_noise, fake_tag], dim=1)
        fake_img = self.G(fake_feat)
        fake_label_p = self.D(fake_img)
        label.data.fill_(1.0)

        # NOTE(review): feat_loss is computed but not added to loss_g, so it
        # does not influence G's update. F.binary_cross_entropy also requires
        # a_D's outputs to lie in [0, 1]; confirm alexnet ends in a sigmoid.
        a_D_feat = self.a_D(fake_img)
        feat_loss = F.binary_cross_entropy(a_D_feat, fake_tag)

        loss_g = F.binary_cross_entropy(fake_label_p, label)
        loss_g.backward()
        self.optimizer_G.step()
        return loss_g.detach()

    def _save_checkpoint(self, suffix, epoch):
        """Save G/D and optimizer state to ``checkpoint_<suffix>.tar``.

        Args:
            suffix: zero-padded tag used in the file name.
            epoch: value stored under ``'epoch'``, i.e. where a resume starts.
        """
        path = '{}/checkpoint_{}.tar'.format(model_dump_path, suffix)
        torch.save(
            {
                'epoch': epoch,
                'D': self.D.state_dict(),
                'G': self.G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
            }, path)
        logger.info('Checkpoint saved in: {}'.format(path))

    def _save_preview(self, food_img, iteration):
        """Write the current real batch and a freshly sampled fake batch to ``tmp_path``."""
        suffix = str(iteration).zfill(8)
        height, width = food_img.size(2), food_img.size(3)
        vutils.save_image(
            food_img.data.view(batch_size, 3, height, width),
            os.path.join(tmp_path, 'real_image_{}.png'.format(suffix)))
        # Sampling for visualisation only, so no autograd graph is needed.
        with torch.no_grad():
            g_noise, fake_tag = utils.fake_generator(batch_size, noise_size,
                                                     device)
            fake_feat = torch.cat([g_noise, fake_tag], dim=1)
            fake_img = self.G(fake_feat)
        fake_path = os.path.join(tmp_path, 'fake_image_{}.png'.format(suffix))
        vutils.save_image(fake_img.data.view(batch_size, 3, height, width),
                          fake_path)
        logger.info('Saved intermediate file in {}'.format(fake_path))