Example #1
0
    def train_epoch(self):
        """Run one training epoch of the SR model.

        Iterates over ``self.train_loader``; for each mini-batch computes
        the combined criterion (total, pixel, and feature losses), performs
        one optimizer step, and reports running averages through the tqdm
        progress bar and ``self.logger``.
        """
        losses = Metric()
        pixel_loss = Metric()
        feat_loss = Metric()
        len_train = len(self.train_loader)
        pb = tqdm(self.train_loader)

        self.model.train()
        # The criterion may contain stateful submodules, so it must be
        # switched into training mode alongside the model.
        self.criterion.train()

        for i, (lr, hr) in enumerate(pb):
            # Note that lr here means low-resolution (images),
            # not learning rate.
            lr, hr = lr.cuda(), hr.cuda()
            sr = self.model(lr)

            loss, pl, fl = self.criterion(sr, hr)

            # Use .detach() rather than the legacy .data attribute:
            # .data silently bypasses autograd's version tracking and can
            # mask in-place modification errors; .detach() is the
            # sanctioned way to record a value without building graph.
            losses.update(loss.detach(), n=self.batch_size)
            pixel_loss.update(pl.detach(), n=self.batch_size)
            feat_loss.update(fl.detach(), n=self.batch_size)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            desc = (
                "[{}/{}] Loss {loss.val:.4f} ({loss.avg:.4f}) "
                "PL {pixel.val:.4f} ({pixel.avg:.4f}) "
                "FL {feat.val:.6f} ({feat.avg:.6f})"
            ).format(i + 1, len_train, loss=losses,
                     pixel=pixel_loss, feat=feat_loss)
            pb.set_description(desc)
            self.logger.dump(desc)
Example #2
0
    def train_epoch(self, discr_interval=1):
        """Run one adversarial training epoch (SR model + IQA discriminator).

        Each mini-batch first optionally updates the IQA (discriminator)
        model on a good-quality/bad-quality pair, then updates the SR model
        with gradient-value clipping.

        Args:
            discr_interval (int): update the IQA model every this many
                mini-batches. Defaults to 1 (every batch), which matches
                the previous hard-coded ``i % 1 == 0`` behavior.
        """
        losses = Metric()
        pixel_loss = Metric()
        feat_loss = Metric()
        discr_loss = Metric()
        len_train = len(self.train_loader)
        pb = tqdm(self.train_loader)

        self.model.train()
        # The criterion holds the IQA submodule, so it must also be put
        # into training mode.
        self.criterion.train()

        for i, (lr, hr) in enumerate(pb):
            # Note that lr here means low-resolution (images),
            # not learning rate.
            lr, hr = lr.cuda(), hr.cuda()
            sr = self.model(lr)

            if i % discr_interval == 0:
                with self.criterion.iqa_loss.learner():
                    # Train the IQA model on one real and one generated pair.
                    # NOTE(review): the first call passes an extra third
                    # argument (0.0) that the second call omits — presumably
                    # a target label with a default; confirm against the
                    # signature of discr_learn.
                    dl = self.discr_learn(hr, hr, 0.0)  # Good-quality images
                    dl += self.discr_learn(sr.detach(),
                                           hr)  # Bad-quality images
                    dl /= 2
                    discr_loss.update(dl, n=self.batch_size)

            # Train the SR model.
            loss, pl, fl = self.criterion(sr, hr)
            self.optimizer.zero_grad()
            loss.backward()
            # Clip by value to keep adversarial gradients from exploding.
            torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1)
            self.optimizer.step()

            # Record metrics; .detach() replaces the legacy .data attribute,
            # which bypasses autograd tracking.
            losses.update(loss.detach(), n=self.batch_size)
            pixel_loss.update(pl.detach(), n=self.batch_size)
            feat_loss.update(fl.detach(), n=self.batch_size)

            # Log for this mini-batch.
            desc = (
                "[{}/{}] Loss {loss.val:.4f} ({loss.avg:.4f}) "
                "DL {discr.val:.4f} ({discr.avg:.4f}) "
                "PL {pixel.val:.4f} ({pixel.avg:.4f}) "
                "FL {feat.val:.6f} ({feat.avg:.6f})"
            ).format(i + 1, len_train, loss=losses,
                     discr=discr_loss,
                     pixel=pixel_loss, feat=feat_loss)
            pb.set_description(desc)
            self.logger.dump(desc)
Example #3
0
    def validate_epoch(self, epoch=0, store=False):
        """Validate the model over the validation set.

        Computes the criterion loss on tensors and PSNR/SSIM on the
        converted images, logging per-image results. During the training
        phase only the first 16 images are validated to keep epochs fast.

        Args:
            epoch (int): epoch index used for logging and stored filenames.
            store (bool): if True, save each super-resolved image to disk.

        Returns:
            The average PSNR over the validated images.
        """
        self.logger.show_nl("Epoch: [{0}]".format(epoch))
        losses = Metric(self.criterion)
        ssim = ShavedSSIM(self.scale)
        psnr = ShavedPSNR(self.scale)
        len_val = len(self.val_loader)
        pb = tqdm(self.val_loader)
        to_image = self.dataset.tensor_to_image

        self.model.eval()
        self.criterion.eval()

        with torch.no_grad():
            for i, (name, lr, hr) in enumerate(pb):
                if self.phase == 'train' and i >= 16:
                    # Validating the full set every epoch is too slow while
                    # training, so cap the number of images at 16.
                    pb.close()
                    self.logger.warning("validation ends early")
                    break

                # The loader yields single images; add a batch dimension.
                lr, hr = lr.unsqueeze(0).cuda(), hr.unsqueeze(0).cuda()

                sr = self.model(lr)

                # Criterion loss is computed on tensors (before conversion).
                losses.update(sr, hr)

                lr = to_image(lr.squeeze(0), 'lr')
                sr = to_image(sr.squeeze(0))
                hr = to_image(hr.squeeze(0))

                # PSNR/SSIM are computed on the converted images.
                psnr.update(sr, hr)
                ssim.update(sr, hr)

                # Build the shared metric summary once; the logger line
                # additionally records the image name. (Previously the
                # format strings lacked a separator after "[i/n]".)
                metrics = (
                    "Loss {loss.val:.4f} ({loss.avg:.4f}) "
                    "PSNR {psnr.val:.4f} ({psnr.avg:.4f}) "
                    "SSIM {ssim.val:.4f} ({ssim.avg:.4f})"
                ).format(loss=losses, psnr=psnr, ssim=ssim)
                prefix = "[{}/{}] ".format(i + 1, len_val)

                pb.set_description(prefix + metrics)
                self.logger.dump("{}{} {}".format(prefix, name, metrics))

                if store:
                    sr_name = self.path_ctrl.add_suffix(name,
                                                        suffix='sr',
                                                        underline=True)
                    self.save_image(sr_name, sr, epoch)

        return psnr.avg