示例#1
0
class SRGAN():
    def __init__(self):
        """Build the data loader, networks, optimizers and loss criteria.

        When ``load_checkpoint`` finds a checkpoint under ``model_dump_path``
        the network weights, optimizer states and epoch counter are restored
        from it.  Otherwise the networks are initialised with
        ``initital_network_weights`` and training starts at epoch 0.
        """
        logger.info('Set Data Loader')
        self.dataset = AnimeFaceDataset(
            avatar_tag_dat_path=avatar_tag_dat_path,
            transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)

        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        resuming = checkpoint is not None

        if resuming:
            logger.info('Load Generator and Discriminator')
        else:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.')
            logger.info('Set Generator and Discriminator')
        self.G = Generator().to(device)
        self.D = Discriminator().to(device)

        if resuming:
            # The placeholder is required, otherwise the checkpoint path is
            # silently dropped from the log line.
            logger.info('Load Pre-Trained Weights From Checkpoint: {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
        else:
            logger.info('Initialize Weights')
            # The modules already live on ``device``; ``apply`` works in place.
            self.G.apply(initital_network_weights)
            self.D.apply(initital_network_weights)

        logger.info('Load Optimizers' if resuming else 'Set Optimizers')
        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        if resuming:
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            self.epoch = checkpoint['epoch']
        else:
            self.epoch = 0

        logger.info('Set Criterion')
        # Real/fake logit loss and multi-label tag loss used by ``train``.
        self.label_criterion = nn.BCEWithLogitsLoss().to(device)
        self.tag_criterion = nn.MultiLabelSoftMarginLoss().to(device)

    def load_checkpoint(self, model_dir):
        """Load the newest checkpoint found in ``model_dir``.

        Args:
            model_dir: Directory handed to ``read_newest_model``.  The names
                it returns are resolved against this same directory
                (presumably they are bare file names; verify against
                ``read_newest_model``).

        Returns:
            ``(checkpoint, path)`` where ``checkpoint`` is the object returned
            by ``torch.load``, or ``(None, None)`` when no checkpoint exists.
        """
        models_path = read_newest_model(model_dir)
        if not models_path:
            return None, None
        # The lexicographically greatest name is treated as the newest.
        newest_name = sorted(models_path)[-1]
        # Resolve against the directory that was actually searched rather
        # than the module-level default.
        new_model_path = os.path.join(model_dir, newest_name)
        # Without CUDA, remap GPU-saved tensors to the CPU so the load does
        # not fail; with CUDA, keep torch's default device restoration.
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(new_model_path, map_location=map_location)
        return checkpoint, new_model_path

    def train(self):
        """Run the adversarial training loop until ``max_epoch`` is passed.

        At the start of every epoch a checkpoint is dumped and the learning
        rates are adjusted.  Each batch then performs one discriminator
        update (real loss, fake loss and a gradient penalty) followed by one
        generator update.  Every ``verbose_T`` iterations a real and a fake
        image grid are written to ``tmp_path``; when ``verbose`` is set the
        collected losses are logged at the same interval.
        """
        iteration = -1
        # Size arguments must be integers; the buffer is refilled with 1.0 or
        # 0.0 before every use, so its initial content is irrelevant.
        label = Variable(torch.FloatTensor(batch_size, 1)).to(device)
        logger.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            # dump checkpoint
            checkpoint_path = '{}/checkpoint_{}.tar'.format(
                model_dump_path, str(self.epoch).zfill(4))
            torch.save(
                {
                    'epoch': self.epoch,
                    'D': self.D.state_dict(),
                    'G': self.G.state_dict(),
                    'optimizer_D': self.optimizer_D.state_dict(),
                    'optimizer_G': self.optimizer_G.state_dict(),
                }, checkpoint_path)
            logger.info('Checkpoint saved in: {}'.format(checkpoint_path))

            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)
            for i, (avatar_tag, avatar_img) in enumerate(self.data_loader):
                iteration += 1
                on_interval = iteration % verbose_T == 0
                log_step = verbose and on_interval
                if log_step:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration
                avatar_img = Variable(avatar_img).to(device)
                avatar_tag = Variable(torch.FloatTensor(avatar_tag)).to(device)

                # 1. Training D
                # 1.1. discriminate real images
                self.D.zero_grad()
                label_p, tag_p = self.D(avatar_img)
                label.data.fill_(1.0)

                # 1.2. real image's loss
                real_label_loss = self.label_criterion(label_p, label)
                real_tag_loss = self.tag_criterion(tag_p, avatar_tag)
                real_loss_sum = real_label_loss * lambda_adv / 2.0 + real_tag_loss * lambda_adv / 2.0
                real_loss_sum.backward()
                if log_step:
                    msg['discriminator real loss'] = float(real_loss_sum)

                # 1.3. discriminate fake images; detach so that no gradient
                # reaches the generator during the discriminator update
                g_noise, fake_tag = fake_generator()
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat).detach()
                fake_label_p, fake_tag_p = self.D(fake_img)
                label.data.fill_(.0)

                # 1.4. fake image's loss
                fake_label_loss = self.label_criterion(fake_label_p, label)
                fake_tag_loss = self.tag_criterion(fake_tag_p, fake_tag)
                fake_loss_sum = fake_label_loss * lambda_adv / 2.0 + fake_tag_loss * lambda_adv / 2.0
                fake_loss_sum.backward()
                if log_step:
                    msg['discriminator fake loss'] = float(fake_loss_sum)

                # 1.5. gradient penalty
                # https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py
                # One interpolation coefficient per sample, broadcast over
                # the remaining dimensions.
                alpha_size = [1] * avatar_img.dim()
                alpha_size[0] = avatar_img.size(0)
                alpha = torch.rand(alpha_size).to(device)
                noise = Variable(torch.rand(avatar_img.size())).to(device)
                perturbed = avatar_img.data + 0.5 * avatar_img.data.std() * noise
                x_hat = Variable(alpha * avatar_img.data + (1 - alpha) * perturbed,
                                 requires_grad=True).to(device)
                pred_hat, pred_tag = self.D(x_hat)
                gradients = grad(outputs=pred_hat,
                                 inputs=x_hat,
                                 grad_outputs=torch.ones(
                                     pred_hat.size()).to(device),
                                 create_graph=True,
                                 retain_graph=True,
                                 only_inputs=True)[0].view(x_hat.size(0), -1)
                # Penalise per-sample gradient norms that deviate from 1.
                gradient_penalty = lambda_gp * (
                    (gradients.norm(2, dim=1) - 1)**2).mean()
                gradient_penalty.backward()
                if log_step:
                    msg['discriminator gradient penalty'] = float(
                        gradient_penalty)

                # 1.6. update optimizer
                self.optimizer_D.step()

                # 2. Training G
                # 2.1. generate fake image
                self.G.zero_grad()
                g_noise, fake_tag = fake_generator()
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat)
                fake_label_p, fake_tag_p = self.D(fake_img)
                # The generator is rewarded when D labels its output as real.
                label.data.fill_(1.0)

                # 2.2. calc loss
                label_loss_g = self.label_criterion(fake_label_p, label)
                tag_loss_g = self.tag_criterion(fake_tag_p, fake_tag)
                loss_g = label_loss_g * lambda_adv / 2.0 + tag_loss_g * lambda_adv / 2.0
                loss_g.backward()
                if log_step:
                    msg['generator loss'] = float(loss_g)

                # 2.3. update optimizer
                self.optimizer_G.step()

                if log_step:
                    logger.info(
                        '------------------------------------------')
                    for key in msg.keys():
                        logger.info('{} : {}'.format(key, msg[key]))

                # save intermediate files (independent of ``verbose``)
                if on_interval:
                    step_tag = str(iteration).zfill(8)
                    height, width = avatar_img.size(2), avatar_img.size(3)
                    vutils.save_image(
                        avatar_img.data.view(batch_size, 3, height, width),
                        os.path.join(tmp_path,
                                     'real_image_{}.png'.format(step_tag)))
                    g_noise, fake_tag = fake_generator()
                    fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                    # Preview only: no autograd graph is needed.
                    with torch.no_grad():
                        fake_img = self.G(fake_feat)
                    fake_path = os.path.join(
                        tmp_path, 'fake_image_{}.png'.format(step_tag))
                    vutils.save_image(
                        fake_img.data.view(batch_size, 3, height, width),
                        fake_path)
                    logger.info(
                        'Saved intermediate file in {}'.format(fake_path))
            self.epoch += 1
class Trainer():
    def __init__(self, opt):
        """Create the networks, optimizers and LR schedulers from ``opt``.

        Args:
            opt: Nested option mapping.  Keys read here are ``network_G``,
                ``network_D``, ``path.pretrain_G`` and, under ``train``,
                ``lr_G``, ``b1_G``, ``b2_G``, ``lr_D``, ``b1_D``, ``b2_D``,
                ``lr_steps`` and ``lr_gamma``.
        """
        # Use the GPU when present, but stay usable on CPU-only hosts.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.opt = opt

        self.G = Generator(self.opt['network_G']).to(self.device)
        util.init_weights(self.G, init_type='kaiming', scale=0.1)
        if self.opt['path']['pretrain_G']:
            # Pretrained weights replace the fresh initialisation above.
            # ``map_location`` keeps the load working whatever device the
            # file was saved from.
            self.G.load_state_dict(torch.load(self.opt['path']['pretrain_G'],
                                              map_location=self.device),
                                   strict=True)
        self.D = Discriminator(self.opt['network_D']).to(self.device)
        util.init_weights(self.D, init_type='kaiming', scale=1)
        self.FE = VGGFeatureExtractor().to(self.device)
        self.G.train()
        self.D.train()
        # The feature extractor is only used to compute a loss, so it is
        # kept in eval mode.
        self.FE.eval()

        self.log_dict = OrderedDict()

        # Only generator parameters that require gradients are optimised.
        self.optim_params = [
            param for param in self.G.parameters() if param.requires_grad
        ]
        self.opt_G = torch.optim.Adam(self.optim_params,
                                      lr=self.opt['train']['lr_G'],
                                      betas=(self.opt['train']['b1_G'],
                                             self.opt['train']['b2_G']))
        self.opt_D = torch.optim.Adam(self.D.parameters(),
                                      lr=self.opt['train']['lr_D'],
                                      betas=(self.opt['train']['b1_D'],
                                             self.opt['train']['b2_D']))

        # Generator first: other methods index these lists positionally.
        self.optimizers = [self.opt_G, self.opt_D]
        self.schedulers = [
            lr_scheduler.MultiStepLR(optimizer, self.opt['train']['lr_steps'],
                                     self.opt['train']['lr_gamma'])
            for optimizer in self.optimizers
        ]

    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()

    def get_current_log(self):
        return self.log_dict

    def get_current_learning_rate(self):
        return self.schedulers[0].get_lr()[0]

    def load_model(self, step, strict=True):
        self.G.load_state_dict(torch.load(
            f"{self.opt['path']['checkpoints']['models']}/{step}_G.pth"),
                               strict=strict)
        self.D.load_state_dict(torch.load(
            f"{self.opt['path']['checkpoints']['models']}/{step}_D.pth"),
                               strict=strict)

    def resume_training(self, resume_state):
        '''Resume the optimizers and schedulers for training'''

        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(
            self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(
            self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)

    def save_network(self, network, network_label, iter_step):
        """Write a network's weights to ``<models dir>/<iter_step>_<label>.pth``.

        A ``nn.DataParallel`` wrapper is unwrapped first, and every tensor is
        moved to the CPU before saving.
        """
        models_dir = self.opt['path']['checkpoints']['models']
        util.mkdir(models_dir)
        target_path = os.path.join(
            self.opt['path']['checkpoints']['models'],
            '{}_{}.pth'.format(iter_step, network_label))

        module = network.module if isinstance(network,
                                              nn.DataParallel) else network
        weights = module.state_dict()
        # Replace entries in place so the mapping object returned by
        # ``state_dict`` is the one that gets serialised.
        for name in list(weights):
            weights[name] = weights[name].cpu()
        torch.save(weights, target_path)

    def save_model(self, epoch, current_step):
        """Save both networks and the training state for ``current_step``."""
        for network, label in ((self.G, 'G'), (self.D, 'D')):
            self.save_network(network, label, current_step)
        self.save_training_state(epoch, current_step)

    def save_training_state(self, epoch, iter_step):
        '''Write the resumable training state to ``<states dir>/<iter_step>.state``.

        The saved mapping holds ``epoch``, ``iter`` and the ``state_dict`` of
        every scheduler and optimizer, in list order.
        '''
        snapshot = {
            'epoch': epoch,
            'iter': iter_step,
            'schedulers': [sched.state_dict() for sched in self.schedulers],
            'optimizers': [optim.state_dict() for optim in self.optimizers],
        }
        file_name = '{}.state'.format(iter_step)
        states_dir = self.opt['path']['checkpoints']['states']
        util.mkdir(states_dir)
        torch.save(snapshot, os.path.join(states_dir, file_name))

    def train(self, train_batch, step):
        """Run one generator update followed by one discriminator update.

        Args:
            train_batch: Mapping providing ``'LR'`` and ``'HR'`` tensors; both
                are moved to ``self.device``.
            step: Current iteration index.  Not used by this method.

        Side effects:
            Stores the batch and the generator output on ``self.lr``,
            ``self.hr`` and ``self.sr`` and refreshes ``self.log_dict`` with
            the latest loss values as Python floats.
        """
        self.lr = train_batch['LR'].to(self.device)
        self.hr = train_batch['HR'].to(self.device)

        # ---- Generator update -------------------------------------------
        # Freeze D so the generator backward pass does not accumulate
        # gradients in the discriminator's parameters.
        for p in self.D.parameters():
            p.requires_grad = False

        self.opt_G.zero_grad()
        self.sr = self.G(self.lr)

        # pixel loss
        l_g_pix = self.opt['train']['wt_pix'] * cri_pix(self.sr, self.hr)

        # feature loss; the target features are constants
        real_fea = self.FE(self.hr).detach()
        fake_fea = self.FE(self.sr)
        l_g_fea = self.opt['train']['wt_fea'] * cri_fea(fake_fea, real_fea)

        # ragan loss: each prediction is compared against the mean
        # prediction of the opposite set
        pred_g_fake = self.D(self.sr)
        pred_d_real = self.D(self.hr).detach()
        l_g_gan = self.opt['train']['wt_gan'] * (
            cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
            cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2

        l_g_total = l_g_pix + l_g_fea + l_g_gan
        l_g_total.backward()
        self.opt_G.step()

        # ---- Discriminator update ---------------------------------------
        for p in self.D.parameters():
            p.requires_grad = True

        self.opt_D.zero_grad()

        pred_d_real = self.D(self.hr)
        pred_d_fake = self.D(self.sr.detach())  # detach to avoid BP to G

        l_d_real = cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
        l_d_fake = cri_gan(pred_d_fake - torch.mean(pred_d_real), False)

        l_d_total = (l_d_real + l_d_fake) / 2
        l_d_total.backward()
        self.opt_D.step()

        # ---- Logging ----------------------------------------------------
        # Everything is stored as a plain float so the log holds no tensors.
        # G
        self.log_dict['l_g_pix'] = l_g_pix.item()
        self.log_dict['l_g_fea'] = l_g_fea.item()
        self.log_dict['l_g_gan'] = l_g_gan.item()

        # D
        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()

        # D outputs
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach()).item()
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()).item()

    def validate(self, val_batch, current_step):
        avg_psnr = 0.0
        avg_ssim = 0.0
        idx = 0
        for _, val_data in enumerate(val_batch):
            idx += 1
            img_name = os.path.splitext(
                os.path.basename(val_data['LR_path'][0]))[0]
            img_dir = os.path.join(
                self.opt['path']['checkpoints']['val_image_dir'], img_name)
            util.mkdir(img_dir)

            self.val_lr = val_data['LR'].to(self.device)
            self.val_hr = val_data['HR'].to(self.device)

            self.G.eval()
            with torch.no_grad():
                self.val_sr = self.G(self.val_lr)
            self.G.train()

            val_LR = self.val_lr.detach()[0].float().cpu()
            val_SR = self.val_sr.detach()[0].float().cpu()
            val_HR = self.val_hr.detach()[0].float().cpu()

            sr_img = util.tensor2img(val_SR)  # uint8
            gt_img = util.tensor2img(val_HR)  # uint8

            # Save SR images for reference
            save_img_path = os.path.join(
                img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
            cv2.imwrite(save_img_path, sr_img)

            # calculate PSNR
            crop_size = 4
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.
            cropped_sr_img = sr_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
            cropped_gt_img = gt_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
            avg_psnr += PSNR(cropped_sr_img * 255, cropped_gt_img * 255)
            avg_ssim += SSIM(cropped_sr_img * 255, cropped_gt_img * 255)

        avg_psnr = avg_psnr / idx
        avg_ssim = avg_ssim / idx
        return avg_psnr, avg_ssim
示例#3
0
    discriminator = Discriminator(in_channels=3, pretrained=GEN_PRETRAINED, weights_path=disc_weights_path)
    discriminator = discriminator.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(generator)
        cudnn.benchmark = True

    # Loss functions
    criterion = nn.MSELoss()
    adversarial_loss = nn.BCELoss()
    perceptual_loss = vgg16().to(device)
    # categorical_loss = torch.nn.CrossEntropyLoss()
    # continuous_loss = torch.nn.MSELoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=LR)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR)

    if DATASET.lower() == 'ntu':
        data_root_dir, train_split, test_split, param_file, gen_weight_file, disc_weight_file = ntu_config()

        # data
        trainset = NTUDataset(root_dir=data_root_dir, data_file=train_split, param_file=param_file,
                              resize_height=HEIGHT, resize_width=WIDTH,
                              clip_len=FRAMES, skip_len=SKIP_LEN,
                              random_all=RANDOM_ALL, precrop=PRECROP)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

        testset = NTUDataset(root_dir=data_root_dir, data_file=test_split, param_file=param_file,
                             resize_height=HEIGHT, resize_width=WIDTH,
                             clip_len=FRAMES, skip_len=SKIP_LEN,
                             random_all=RANDOM_ALL, precrop=PRECROP)
示例#4
0
# Initialize Generator: build it from the CLI options, apply the custom
# weight initialiser to every sub-module, move it to the target device
# (``Module.to`` works in place) and print its architecture.
generator = Generator(args.GAN_TYPE, args.ZDIM, args.NUM_CLASSES)
generator.apply(weights_init)
generator.to(device)
print(generator)

# Initialize Discriminator the same way as the generator.
discriminator = Discriminator(args.GAN_TYPE, args.NUM_CLASSES)
discriminator.apply(weights_init)
discriminator.to(device)
print(discriminator)

# Initialize loss function and optimizer.
# criterionLabel is a binary cross-entropy loss and criterionClass a
# cross-entropy loss; presumably they score real/fake and class outputs
# respectively (confirm in the training loop).
criterionLabel = nn.BCELoss()
criterionClass = nn.CrossEntropyLoss()
# Both networks use Adam with the same learning rate and betas.
optimizerD = Adam(discriminator.parameters(), lr=args.LR, betas=(0.5, 0.999))
optimizerG = Adam(generator.parameters(), lr=args.LR, betas=(0.5, 0.999))

# Prepare the noise for evaluation during training phase.
# One latent vector per batch slot, drawn once from N(0, 1).
fixedNoise = torch.FloatTensor(args.BATCH_SIZE, args.ZDIM, 1, 1).normal_(0, 1)
if args.GAN_TYPE in ["CGAN", "ACGAN"]:
    # Conditional variants: cycle through the classes and append a one-hot
    # class code to each noise vector along the channel dimension.
    fixedClass = F.one_hot(torch.LongTensor([i % args.NUM_CLASSES for i in range(args.BATCH_SIZE)]), num_classes=args.NUM_CLASSES)
    fixedConstraint = fixedClass.unsqueeze(-1).unsqueeze(-1)
    # ``F.one_hot`` yields int64; cast explicitly so the concatenation does
    # not depend on ``torch.cat`` dtype promotion.
    fixed_z = torch.cat((fixedNoise, fixedConstraint.float()), 1)
else:
    fixed_z = fixedNoise
fixed_z = fixed_z.to(device)

# Prepare the data generator with proper data transformation
if args.AUGMENTED:
    trainDataset = MNIST(augmented_train_transform(), "train")
示例#5
0
    gen_path = args.from_pretrained_gen
    if gen_path:
        gen_net.load_state_dict(torch.load(gen_path))
    dis_path = args.from_pretrained_dis
    if dis_path:
        dis_net.load_state_dict(torch.load(dis_path))

    perceptual_loss = PerceptualLoss(device=device)
    mse_loss = nn.MSELoss()
    beta1 = 0.9

    opt_gen = optim.Adam(gen_net.parameters(),
                         lr=args.lr,
                         betas=(beta1, 0.999))
    opt_dis = optim.Adam(dis_net.parameters(),
                         lr=args.lr,
                         betas=(beta1, 0.999))
    opt_gen_path = args.from_pretrained_optimizer_gen
    opt_dis_path = args.from_pretrained_optimizer_dis
    if opt_gen_path:
        opt_gen.load_state_dict(torch.load(opt_gen_path))
    if opt_dis_path:
        opt_dis.load_state_dict(torch.load(opt_dis_path))
    hr_test, lr_test = next(iter(test_dataloader))
    vutils.save_image(hr_test, f"{results_dir}/hr.png", normalize=True)
    vutils.save_image(lr_test, f"{results_dir}/lr.png", normalize=True)
    with open(f"{models_dir}/args.yml", "w") as f:
        yaml.dump(args, f)
    for epoch in range(args.epochs):
        training_bar = tqdm(train_dataloader)
示例#6
0
class SRGAN():
    def __init__(self):
        """Build the data loader, networks, optimizers and auxiliary classifier.

        When ``load_checkpoint`` finds a checkpoint under ``model_dump_path``
        the network weights, optimizer states and epoch counter are restored
        from it.  Otherwise the networks are initialised with
        ``initital_network_weights`` and training starts at epoch 0.
        """
        logger.info('Set Data Loader')
        self.dataset = FoodDataset(transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)

        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        resuming = checkpoint is not None

        if resuming:
            logger.info('Load Generator and Discriminator')
        else:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.')
            logger.info('Set Generator and Discriminator')
        self.G = Generator(tag=tag_size).to(device)
        self.D = Discriminator(tag=tag_size).to(device)

        if resuming:
            # The placeholder is required, otherwise the checkpoint path is
            # silently dropped from the log line.
            logger.info('Load Pre-Trained Weights From Checkpoint: {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
        else:
            logger.info('Initialize Weights')
            # The modules already live on ``device``; ``apply`` works in place.
            self.G.apply(initital_network_weights)
            self.D.apply(initital_network_weights)

        logger.info('Load Optimizers' if resuming else 'Set Optimizers')
        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        if resuming:
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            self.epoch = checkpoint['epoch']
        else:
            self.epoch = 0

        logger.info('Set Criterion')
        # Auxiliary AlexNet classifier over ``tag_size`` outputs with its own
        # optimizer.  It is neither saved to nor restored from checkpoints.
        self.a_D = alexnet.alexnet(num_classes=tag_size).to(device)
        self.optimizer_a_D = torch.optim.Adam(self.a_D.parameters(),
                                              lr=learning_rate,
                                              betas=(beta_1, .999))

    def load_checkpoint(self, model_dir):
        """Load the newest checkpoint found in ``model_dir``.

        Args:
            model_dir: Directory handed to ``utils.read_newest_model``.  The
                names it returns are resolved against this same directory
                (presumably they are bare file names; verify against
                ``utils.read_newest_model``).

        Returns:
            ``(checkpoint, path)`` where ``checkpoint`` is the object returned
            by ``torch.load``, or ``(None, None)`` when no checkpoint exists.
        """
        models_path = utils.read_newest_model(model_dir)
        if not models_path:
            return None, None
        # The lexicographically greatest name is treated as the newest.
        newest_name = sorted(models_path)[-1]
        # Resolve against the directory that was actually searched rather
        # than the module-level default.
        new_model_path = os.path.join(model_dir, newest_name)
        # With CUDA, keep torch's default device restoration; without it,
        # remap every tensor to the CPU.
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(new_model_path, map_location=map_location)
        return checkpoint, new_model_path

    def train(self):
        """Run the adversarial training loop until ``max_epoch`` is passed.

        Each batch performs one discriminator update (real and fake binary
        cross-entropy losses) followed by one generator update.  A checkpoint
        is dumped every 10000 iterations and at the end of every epoch.
        Every ``verbose_T`` iterations a real and a fake image grid are
        written to ``tmp_path``; when ``verbose`` is set the collected losses
        are logged at the same interval.
        """
        iteration = -1
        # Target buffer, refilled with 1.0 or 0.0 before every use.
        label = Variable(torch.FloatTensor(batch_size, 1)).to(device)
        logger.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)
            for i, (food_tag, food_img) in enumerate(self.data_loader):
                iteration += 1
                if food_img.shape[0] != batch_size:
                    logger.warning('Batch size not satisfied. Ignoring.')
                    continue
                on_interval = iteration % verbose_T == 0
                log_step = verbose and on_interval
                if log_step:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration

                food_img = Variable(food_img).to(device)
                # 0. training assistant D
                # NOTE(review): this forward pass is never used for a loss
                # and ``optimizer_a_D`` is never stepped, so the assistant
                # network is not actually trained here.
                self.a_D.zero_grad()
                a_D_feat = self.a_D(food_img)

                # 1. Training D
                # 1.1. discriminate real images
                self.D.zero_grad()
                label_p = self.D(food_img)
                label.data.fill_(1.0)

                # 1.2. real image's loss
                real_label_loss = F.binary_cross_entropy(label_p, label)
                real_loss_sum = real_label_loss
                real_loss_sum.backward()
                if log_step:
                    msg['discriminator real loss'] = float(real_loss_sum)

                # 1.3. discriminate fake images; detach so that no gradient
                # reaches the generator during the discriminator update
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat).detach()
                fake_label_p = self.D(fake_img)
                label.data.fill_(.0)

                # 1.4. fake image's loss
                fake_label_loss = F.binary_cross_entropy(fake_label_p, label)
                fake_loss_sum = fake_label_loss
                fake_loss_sum.backward()
                if log_step:
                    print('predicted fake label: {}'.format(fake_label_p))
                    msg['discriminator fake loss'] = float(fake_loss_sum)

                # 1.5. update optimizer
                self.optimizer_D.step()

                # 2. Training G
                # 2.1. generate fake image
                self.G.zero_grad()
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat)
                fake_label_p = self.D(fake_img)
                # The generator is rewarded when D labels its output as real.
                label.data.fill_(1.0)

                # NOTE(review): ``feat_loss`` is computed but not added to
                # ``loss_g``, so the tag condition does not influence the
                # generator update.  Left unchanged pending confirmation of
                # the intended loss.
                a_D_feat = self.a_D(fake_img)
                feat_loss = F.binary_cross_entropy(a_D_feat, fake_tag)

                # 2.2. calc loss
                label_loss_g = F.binary_cross_entropy(fake_label_p, label)
                loss_g = label_loss_g
                loss_g.backward()
                if log_step:
                    msg['generator loss'] = float(loss_g)

                # 2.3. update optimizer
                self.optimizer_G.step()

                if log_step:
                    logger.info(
                        '------------------------------------------')
                    for key in msg.keys():
                        logger.info('{} : {}'.format(key, msg[key]))

                # periodic checkpoint, named after the iteration counter
                if iteration % 10000 == 0:
                    self._dump_checkpoint(str(iteration).zfill(8))

                # save intermediate files (independent of ``verbose``)
                if on_interval:
                    step_tag = str(iteration).zfill(8)
                    height, width = food_img.size(2), food_img.size(3)
                    vutils.save_image(
                        food_img.data.view(batch_size, 3, height, width),
                        os.path.join(tmp_path,
                                     'real_image_{}.png'.format(step_tag)))
                    g_noise, fake_tag = utils.fake_generator(
                        batch_size, noise_size, device)
                    fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                    fake_img = self.G(fake_feat)
                    fake_path = os.path.join(
                        tmp_path, 'fake_image_{}.png'.format(step_tag))
                    vutils.save_image(
                        fake_img.data.view(batch_size, 3, height, width),
                        fake_path)
                    logger.info(
                        'Saved intermediate file in {}'.format(fake_path))
            # end-of-epoch checkpoint, named after the epoch counter
            self._dump_checkpoint(str(self.epoch).zfill(4))
            self.epoch += 1

    def _dump_checkpoint(self, tag):
        """Save networks, optimizers and epoch to ``checkpoint_<tag>.tar``."""
        checkpoint_path = '{}/checkpoint_{}.tar'.format(model_dump_path, tag)
        torch.save(
            {
                'epoch': self.epoch,
                'D': self.D.state_dict(),
                'G': self.G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
            }, checkpoint_path)
        logger.info('Checkpoint saved in: {}'.format(checkpoint_path))
示例#7
0
    run_name = 'correlation-GAN_{}'.format(config.version)
    wandb.init(name=run_name,
               dir=config.checkpoint_dir,
               notes=config.description)
    wandb.config.update(config.__dict__)

    device = torch.device('cuda')

    use_dropout = [True, True, False]
    drop_prob = [0.5, 0.5, 0.5]
    use_ac_func = [True, True, False]
    activation = 'relu'
    latent_dim = 10

    gen_fc_layers = [latent_dim, 16, 32, 2]
    generator = Generator(gen_fc_layers, use_dropout, drop_prob, use_ac_func,
                          activation).to(device)

    disc_fc_layers = [2, 32, 16, 1]
    discriminator = Discriminator(disc_fc_layers, use_dropout, drop_prob,
                                  use_ac_func, activation).to(device)

    wandb.watch([generator, discriminator])

    g_optimizer = Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
    d_optimizer = Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

    wgan_gp = WGAN_GP(config, generator, discriminator, g_optimizer,
                      d_optimizer, latent_shape)
    wgan_gp.train(dataloader, 200)