Ejemplo n.º 1
0
def get_prediction(model_checkpoint, resnet_type):
    """Run the trained model over ``target_loader`` and return class predictions.

    Parameters
    ----------
    model_checkpoint : str
        Path to a checkpoint containing 'feature_extractor' and
        'label_predictor' state dicts.
    resnet_type :
        Backbone selector forwarded to FeatureExtractor / LabelPredictor.

    Returns
    -------
    numpy.ndarray
        1-D array of argmax class indices, one per sample in
        ``target_loader`` (module-level loader — TODO confirm it is set up
        before calling).
    """
    # Build the two sub-networks and restore their trained weights.
    F = FeatureExtractor(resnet=resnet_type).to(device)
    C = LabelPredictor(resnet=resnet_type).to(device)

    checkpoint = torch.load(model_checkpoint)
    F.load_state_dict(checkpoint['feature_extractor'])
    C.load_state_dict(checkpoint['label_predictor'])

    # Inference mode: freezes dropout / batch-norm statistics.
    F.eval()
    C.eval()
    result = []
    # Fix: wrap the loop in no_grad() — the original tracked gradients
    # during pure inference, needlessly holding autograd graphs in GPU
    # memory for every batch.
    with torch.no_grad():
        for i, (data, _) in enumerate(target_loader):
            print(i + 1, len(target_loader), end='\r')
            data = data.to(device)

            logits = C(F(data))

            # .detach() is redundant under no_grad().
            x = torch.argmax(logits, dim=1).cpu().numpy()
            result.append(x)

    # Drop the models and release cached GPU memory before returning.
    del F
    del C
    torch.cuda.empty_cache()

    return np.concatenate(result)
Ejemplo n.º 2
0
 def network_initializers(self, hr_shape, use_LeakyReLU_Mish=False):
     """Create the three ESRGAN networks on ``self.device``.

     The feature extractor is switched to eval mode because it is only
     used to compute perceptual features, never trained.

     Returns a ``(discriminator, feature_extractor, generator)`` tuple.
     """
     dev = self.device
     opts = self.opt

     generator = GeneratorRRDB(
         opts.channels,
         filters=64,
         num_res_blocks=opts.residual_blocks,
         use_LeakyReLU_Mish=use_LeakyReLU_Mish,
     ).to(dev, non_blocking=True)

     discriminator = Discriminator(
         input_shape=(opts.channels, *hr_shape),
         use_LeakyReLU_Mish=use_LeakyReLU_Mish,
     ).to(dev, non_blocking=True)

     feature_extractor = FeatureExtractor().to(dev, non_blocking=True)
     # Inference-only network: freeze dropout / batch-norm behaviour.
     feature_extractor.eval()

     return discriminator, feature_extractor, generator
Ejemplo n.º 3
0
    return loss


def get_sync_loss(mel, g):
    """Cosine sync loss between mel features and generated frames.

    The bottom half of each frame is kept (presumably the mouth region —
    TODO confirm), ``syncnet_T`` consecutive frames are stacked along the
    channel axis, and the module-level ``syncnet`` embeddings are scored
    with ``cosine_loss`` against an all-ones target.
    """
    # Keep the lower half of every frame along the last spatial axis.
    lower_half = g[:, :, :, g.size(3) // 2:]
    # Stack syncnet_T frames channel-wise -> (B, 3 * T, H // 2, W).
    frame_list = [lower_half[:, :, t] for t in range(syncnet_T)]
    stacked = torch.cat(frame_list, dim=1)
    audio_emb, video_emb = syncnet(mel, stacked)
    target = torch.ones(stacked.size(0), 1).float().to(device)
    return cosine_loss(audio_emb, video_emb, target)


# L1 reconstruction loss used by the content loss below.
recon_loss = nn.L1Loss()
# Pretrained feature network for perceptual features; eval() freezes
# dropout / batch-norm because it is only ever used for loss computation,
# never trained.
feature_extractor = FeatureExtractor()
feature_extractor.eval()


# --------- Add content loss here ---------------
# --------- Add content loss here ---------------
def get_content_loss(g, gt):
    """Perceptual (content) loss: L1 distance in feature space.

    Compares features of the generated image ``g`` against detached
    features of the ground truth ``gt`` using the module-level
    ``feature_extractor`` and ``recon_loss``.
    """
    generated_features = feature_extractor(g)
    target_features = feature_extractor(gt).detach()
    return recon_loss(generated_features, target_features)


def train(device,
          model,
          disc,
Ejemplo n.º 4
0
class ESRGAN():
    """ESRGAN training wrapper.

    Builds the generator/discriminator/feature-extractor, their losses
    and optimizers, and runs the relativistic-average-GAN training loop.

    Bug fixes versus the original: every method now uses ``self.opt`` and
    the instance attributes (``self.generator``, ``self.optimizer_G``,
    ...) instead of accidentally referencing bare module-level names,
    which raised NameError unless identically-named globals happened to
    exist.

    NOTE(review): several module-level names are still required at run
    time (``mlflow``, ``weight_save_dir``, ``image_train_save_dir``,
    ``denormalize``, ``save_image``, ``Variable``, ``np``, ``sys``,
    ``os``, ``osp``, ``nn``, ``torch``) — confirm they are defined in the
    enclosing module.
    """

    def __init__(self, opt):
        self.opt = opt
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        hr_shape = (self.opt.hr_height, self.opt.hr_width)
        self._set_model(device, hr_shape)

    def _set_model(self, device, hr_shape):
        """Instantiate the networks and loss criteria on ``device``."""
        # Initialize generator and discriminator
        # (fix: self.opt, not the global opt)
        self.generator = GeneratorRRDB(
            self.opt.channels, filters=64,
            num_res_blocks=self.opt.residual_blocks).to(device)
        self.discriminator = Discriminator(
            input_shape=(self.opt.channels, *hr_shape)).to(device)
        self.feature_extractor = FeatureExtractor().to(device)

        # Set feature extractor to inference mode — it only supplies
        # features for the content loss and is never trained.
        self.feature_extractor.eval()

        # Losses
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
        self.criterion_content = torch.nn.L1Loss().to(device)
        self.criterion_pixel = torch.nn.L1Loss().to(device)

    def _set_param(self):
        """Log every training option as an mlflow parameter."""
        for key, value in vars(self.opt).items():
            mlflow.log_param(key, value)

    def _load_weigth(self):  # (sic) name kept for backward compatibility
        """Optionally resume from epoch checkpoints, then build optimizers.

        When ``self.opt.epoch`` is non-zero, generator/discriminator
        weights for that epoch are loaded from ``weight_save_dir``.
        """
        if self.opt.epoch != 0:
            # Load pretrained models
            load_g_weight_path = osp.join(
                weight_save_dir, "generator_%d.pth" % self.opt.epoch)
            load_d_weight_path = osp.join(
                weight_save_dir, "discriminator_%d.pth" % self.opt.epoch)

            self.generator.load_state_dict(torch.load(load_g_weight_path))
            self.discriminator.load_state_dict(torch.load(load_d_weight_path))

        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=self.opt.lr,
                                            betas=(self.opt.b1, self.opt.b2))

    # ----------
    #  Training
    # ----------
    def train(self, dataloader, opt):
        """Run the ESRGAN training loop over ``dataloader``.

        Warm-up phase (first ``opt.warmup_batches`` batches) trains the
        generator with pixel loss only; afterwards the full relativistic
        adversarial + content + pixel objective is used and the
        discriminator is updated. Progress is written to stdout and
        mlflow; images/checkpoints are saved at the configured intervals.
        """
        # Loop-invariant: pick the tensor type once, not once per batch.
        Tensor = (torch.cuda.FloatTensor
                  if torch.cuda.is_available() else torch.Tensor)
        for epoch in range(opt.epoch + 1, opt.n_epochs + 1):
            for batch_num, imgs in enumerate(dataloader):
                batches_done = (epoch - 1) * len(dataloader) + batch_num

                # Configure model input
                imgs_lr = Variable(imgs["lr"].type(Tensor))
                imgs_hr = Variable(imgs["hr"].type(Tensor))

                # Adversarial ground truths (fix: self.discriminator)
                valid = Variable(
                    Tensor(np.ones((imgs_lr.size(0),
                                    *self.discriminator.output_shape))),
                    requires_grad=False)
                fake = Variable(
                    Tensor(np.zeros((imgs_lr.size(0),
                                     *self.discriminator.output_shape))),
                    requires_grad=False)

                # ------------------
                #  Train Generators
                # ------------------

                self.optimizer_G.zero_grad()

                # Generate a high resolution image from low resolution input
                gen_hr = self.generator(imgs_lr)

                # Measure pixel-wise loss against ground truth
                loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)

                # Warm-up (pixel-wise loss only)
                if batches_done <= opt.warmup_batches:
                    loss_pixel.backward()
                    self.optimizer_G.step()
                    log_info = "[Epoch {}/{}] [Batch {}/{}] [G pixel: {}]".format(
                        epoch, opt.n_epochs, batch_num, len(dataloader),
                        loss_pixel.item())

                    sys.stdout.write("\r{}".format(log_info))
                    sys.stdout.flush()

                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(),
                                      step=batches_done)
                else:
                    # Extract validity predictions from discriminator;
                    # pred_real is detached so the generator update does
                    # not backprop into the discriminator's real branch.
                    pred_real = self.discriminator(imgs_hr).detach()
                    pred_fake = self.discriminator(gen_hr)

                    # Adversarial loss (relativistic average GAN)
                    loss_GAN = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), valid)

                    # Content loss (target features detached: the frozen
                    # extractor's real-image features are constants)
                    gen_features = self.feature_extractor(gen_hr)
                    real_features = self.feature_extractor(imgs_hr).detach()
                    loss_content = self.criterion_content(gen_features,
                                                          real_features)

                    # Total generator loss
                    loss_G = (loss_content + opt.lambda_adv * loss_GAN +
                              opt.lambda_pixel * loss_pixel)

                    loss_G.backward()
                    self.optimizer_G.step()

                    # ---------------------
                    #  Train Discriminator
                    # ---------------------

                    self.optimizer_D.zero_grad()

                    pred_real = self.discriminator(imgs_hr)
                    pred_fake = self.discriminator(gen_hr.detach())

                    # Adversarial loss for real and fake images (relativistic average GAN)
                    loss_real = self.criterion_GAN(
                        pred_real - pred_fake.mean(0, keepdim=True), valid)
                    loss_fake = self.criterion_GAN(
                        pred_fake - pred_real.mean(0, keepdim=True), fake)

                    # Total loss
                    loss_D = (loss_real + loss_fake) / 2

                    loss_D.backward()
                    self.optimizer_D.step()

                    # --------------
                    #  Log Progress
                    # --------------

                    log_info = "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, content: {}, adv: {}, pixel: {}]".format(
                        epoch,
                        opt.n_epochs,
                        batch_num,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_content.item(),
                        loss_GAN.item(),
                        loss_pixel.item(),
                    )

                    # First batch of an epoch starts a fresh console line;
                    # later batches overwrite it with \r.
                    if batch_num == 1:
                        sys.stdout.write("\n{}".format(log_info))
                    else:
                        sys.stdout.write("\r{}".format(log_info))

                    sys.stdout.flush()

                    if batches_done % opt.sample_interval == 0:
                        # Save image grid with upsampled inputs and ESRGAN outputs
                        imgs_lr = nn.functional.interpolate(imgs_lr,
                                                            scale_factor=4)
                        img_grid = denormalize(torch.cat((imgs_lr, gen_hr),
                                                         -1))

                        image_batch_save_dir = osp.join(
                            image_train_save_dir, '{:07}'.format(batches_done))
                        os.makedirs(osp.join(image_batch_save_dir, "hr_image"),
                                    exist_ok=True)
                        save_image(img_grid,
                                   osp.join(image_batch_save_dir, "hr_image",
                                            "%d.png" % batches_done),
                                   nrow=1,
                                   normalize=False)

                    if batches_done % opt.checkpoint_interval == 0:
                        # Save model checkpoints (fix: instance attributes)
                        torch.save(
                            self.generator.state_dict(),
                            osp.join(weight_save_dir,
                                     "generator_%d.pth" % epoch))
                        torch.save(
                            self.discriminator.state_dict(),
                            osp.join(weight_save_dir,
                                     "discriminator_%d.pth" % epoch))

                    mlflow.log_metric('train_{}'.format('loss_D'),
                                      loss_D.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_G'),
                                      loss_G.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_content'),
                                      loss_content.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_GAN'),
                                      loss_GAN.item(),
                                      step=batches_done)
                    mlflow.log_metric('train_{}'.format('loss_pixel'),
                                      loss_pixel.item(),
                                      step=batches_done)