示例#1
0
    def test(self):
        """Evaluate the model on the test set for every scale and checkpoint.

        For each scale in ``self.scale`` the test loader is switched to that
        scale and every batch is run through the model.  When a ground-truth
        HR batch is present, PSNR is accumulated.  The mean PSNR is stored in
        the last row of ``self.ckp.log_test`` and logged together with the
        best value seen so far.  Finally the trainer is saved via
        ``self.ckp.save``, flagged as best when the current epoch holds the
        best result for the first scale.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('Evaluation:')
        # Append a fresh row (one column per scale) for this epoch's results.
        self.ckp.add_log(torch.zeros(1, len(self.scale)), False)
        self.model.eval()

        # Forward pass: self-ensemble (x8_forward), chop_forward, or a plain
        # model call, depending on the run arguments.
        def _test_forward(x, scale):
            if self.args.self_ensemble:
                return utility.x8_forward(x, self.model, self.args.precision)
            elif self.args.chop_forward:
                return utility.chop_forward(x, self.model, scale)
            else:
                return self.model(x)

        set_name = self.args.data_test
        rgb_range = self.args.rgb_range
        best = None  # stays None when there are no scales to evaluate
        for idx_scale, scale in enumerate(self.scale):
            eval_acc = 0
            self._scale_change(idx_scale, self.loader_test)
            for idx_img, (lr, hr, _) in enumerate(self.loader_test):
                # A string where the HR batch should be means there is no
                # ground truth (it carries the filename); only save output.
                # (str, bytes) is what torch._six.string_classes was on
                # Python 3; torch._six no longer exists in PyTorch >= 2.0.
                no_eval = isinstance(hr[0], (str, bytes))
                # NOTE(review): gradient tracking is presumably disabled by
                # prepare(volatile=True); `volatile` has no effect on
                # PyTorch >= 0.4 (torch.no_grad() replaces it) - confirm
                # against prepare().
                if no_eval:
                    lr = self.prepare([lr], volatile=True)[0]
                    filename = hr[0]
                else:
                    lr, hr = self.prepare([lr, hr], volatile=True)
                    filename = idx_img + 1

                sr = _test_forward(lr, scale)
                sr = utility.quantize(sr, rgb_range)

                if no_eval:
                    save_list = [sr]
                else:
                    hr_norm = hr.div(rgb_range)
                    eval_acc += utility.calc_PSNR(sr, hr_norm, set_name, scale)
                    save_list = [sr, lr.div(rgb_range), hr_norm]

                if self.args.save_results:
                    self.ckp.save_results(filename, save_list, scale)

            self.ckp.log_test[-1, idx_scale] = eval_acc / len(self.loader_test)
            # best[0]: best value per scale, best[1]: row index it came from.
            best = self.ckp.log_test.max(0)
            performance = 'PSNR: {:.3f}'.format(
                self.ckp.log_test[-1, idx_scale])
            self.ckp.write_log(
                '[{} x{}]\t{} (Best: {:.3f} from epoch {})'.format(
                    set_name, scale, performance, best[0][idx_scale],
                    best[1][idx_scale] + 1))

        # Best checkpoint == this epoch holds the best result for scale 0.
        is_best = best is not None and int(best[1][0]) + 1 == epoch
        self.ckp.save(self, epoch, is_best=is_best)
示例#2
0
    def test(self, test_only=False, starttime=0):
        """Evaluate the model on the test set for every scale and checkpoint.

        For each scale in ``self.scale`` the test loader is switched to that
        scale and every batch is run through the model.  When a ground-truth
        HR batch is present the score is accumulated: SSIM when
        ``self.args.loss == '1*SSIM'``, PSNR otherwise.  The mean score is
        stored in the last row of ``self.ckp.log_test`` and logged together
        with the best value so far.  Forward-pass wall time is accumulated in
        ``self.test_time`` (``[0]``: number of forward passes, ``[1]``: total
        seconds).  Finally the trainer is saved via ``self.ckp.save``, flagged
        as best when this epoch holds the best result for the first scale.

        Args:
            test_only: when true, log the total and average forward time.
            starttime: a ``datetime.datetime`` (it is subtracted from
                ``datetime.datetime.now()``); when non-zero and ``test_only``
                is false, log the elapsed time and print an estimated finish
                time.  ``0`` disables this report.
        """
        epoch = self.scheduler.last_epoch + 1
        self.ckp.write_log('Evaluation:')
        # Append a fresh row (one column per scale) for this epoch's results.
        self.ckp.add_log(torch.zeros(1, len(self.scale)), False)
        self.model.eval()

        # Forward pass: self-ensemble (x8_forward), chop_forward, or a plain
        # model call, depending on the run arguments.
        def _test_forward(x, scale):
            if self.args.self_ensemble:
                return utility.x8_forward(x, self.model, self.args.precision)
            elif self.args.chop_forward:
                return utility.chop_forward(x, self.model, scale)
            else:
                return self.model(x)

        set_name = self.args.data_test
        rgb_range = self.args.rgb_range
        use_ssim = self.args.loss == '1*SSIM'
        metric = 'SSIM' if use_ssim else 'PSNR'
        best = None  # stays None when there are no scales to evaluate
        # Evaluation never back-propagates; without no_grad() autograd would
        # record a graph for every forward pass, wasting memory and time.
        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                eval_acc = 0
                self._scale_change(idx_scale, self.loader_test)
                for idx_img, (lr, hr, _) in enumerate(self.loader_test):
                    # A string where the HR batch should be means there is no
                    # ground truth (it carries the filename); only save.
                    # (str, bytes) is what torch._six.string_classes was on
                    # Python 3; torch._six no longer exists in PyTorch >= 2.0.
                    no_eval = isinstance(hr[0], (str, bytes))
                    if no_eval:
                        lr = self.prepare([lr])[0]
                        filename = hr[0]
                    else:
                        lr, hr = self.prepare([lr, hr])
                        filename = idx_img + 1

                    timer_test = utility.timer()
                    sr = _test_forward(lr, scale)
                    # CUDA kernels run asynchronously; wait for them so the
                    # measured time covers the actual computation.
                    if isinstance(sr, torch.Tensor) and sr.is_cuda:
                        torch.cuda.synchronize()
                    self.test_time[1] += timer_test.toc()
                    self.test_time[0] += 1
                    sr = utility.quantize(sr, rgb_range)

                    if no_eval:
                        save_list = [sr]
                    else:
                        hr_norm = hr.div(rgb_range)
                        if use_ssim:
                            eval_acc += pytorch_ssim.ssim(sr, hr_norm).item()
                        else:
                            eval_acc += utility.calc_PSNR(
                                sr, hr_norm, set_name, scale)
                        save_list = [sr, lr.div(rgb_range), hr_norm]

                    if self.args.save_results:
                        self.ckp.save_results(filename, save_list, scale)

                self.ckp.log_test[-1, idx_scale] = (
                    eval_acc / len(self.loader_test))
                # best[0]: best value per scale, best[1]: row it came from.
                best = self.ckp.log_test.max(0)
                performance = '{}: {:.3f}'.format(
                    metric, self.ckp.log_test[-1, idx_scale])
                self.ckp.write_log(
                    '[{} x{}]\t{} (Best: {:.3f} from epoch {})'.format(
                        set_name, scale, performance, best[0][idx_scale],
                        best[1][idx_scale] + 1))

        # Best checkpoint == this epoch holds the best result for scale 0.
        is_best = best is not None and int(best[1][0]) + 1 == epoch
        if test_only:
            n_timed = self.test_time[0]
            # Guard against dividing by zero when nothing was timed.
            avg_time = self.test_time[1] / n_timed if n_timed else 0.0
            self.ckp.write_log(
                'Total time: {:.3f}s\r\nAvg. time: {:.3f}s\n'.format(
                    self.test_time[1], avg_time),
                refresh=True)
        elif starttime != 0:
            now = datetime.datetime.now()
            elapsed = now - starttime
            # Linear extrapolation of the per-epoch time to remaining epochs.
            est = now + (elapsed / epoch) * (self.args.epochs - epoch)
            self.ckp.write_log("Elapsed: {}\n".format(str(elapsed)))
            print('Will finish: {}\n'.format(
                est.strftime('%d-%m-%Y-%H:%M:%S')))

        self.ckp.save(self, epoch, is_best=is_best)