Example #1
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Run single-process validation over every sample in ``dataloader``.

        Each sample is fed to the model, run through ``self.test()`` and the
        result (and GT, when present) is converted with ``tensor2img``. The
        restored image is optionally written to disk and scored with every
        metric listed in ``self.opt['val']['metrics']``. The averaged scores
        are stored in ``self.metric_results`` and logged.

        Args:
            dataloader: Validation dataloader. Only ``val_data['lq_path'][0]``
                is used to name outputs, so a batch size of 1 is presumably
                expected (TODO confirm).
            current_iter (int): Current iteration. It is used in the saved
                file name during training and for logging.
            tb_logger: TensorBoard logger forwarded to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to write the restored image to disk.

        Raises:
            ValueError: If metrics are configured but a sample has no GT.
        """
        dataset_name = dataloader.dataset.opt['name']
        metrics_opt = self.opt['val'].get('metrics')
        with_metrics = metrics_opt is not None
        if with_metrics:
            self.metric_results = {metric: 0 for metric in metrics_opt.keys()}
        pbar = tqdm(total=len(dataloader), unit='image')

        num_images = 0
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            # Reset for every sample. Otherwise a sample without GT would be
            # silently scored against the previous sample's GT image.
            gt_img = None
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             img_name,
                                             f'{img_name}_{current_iter}.png')
                elif self.opt['val'].get('suffix'):
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                if gt_img is None:
                    raise ValueError(
                        'Metrics are configured but no ground truth is '
                        f'available for image "{img_name}".')
                # deepcopy so that popping 'type' does not mutate self.opt.
                opt_metric = deepcopy(metrics_opt)
                for name, opt_ in opt_metric.items():
                    metric_type = opt_.pop('type')
                    self.metric_results[name] += getattr(
                        metric_module, metric_type)(sr_img, gt_img, **opt_)
            num_images += 1
            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
        pbar.close()

        if with_metrics:
            # Average over the processed samples. An empty dataloader leaves
            # the zero-initialised results untouched instead of failing.
            if num_images > 0:
                for metric in self.metric_results:
                    self.metric_results[metric] /= num_images

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #2
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Run single-process validation over every sample in ``dataloader``.

        Each sample is fed to the model, run through ``self.test()`` and the
        result (and GT, when present) is converted with ``tensor2img``. The
        restored image is optionally saved with ``mmcv.imwrite`` and scored
        with every metric in ``self.opt['val']['metrics']``. The averaged
        scores are stored in ``self.metric_results`` and logged.

        Args:
            dataloader: Validation dataloader. Only ``val_data['lq_path'][0]``
                is used to name outputs, so a batch size of 1 is presumably
                expected (TODO confirm).
            current_iter (int): Current iteration. It is used in the saved
                file name during training and for logging.
            tb_logger: TensorBoard logger forwarded to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to write the restored image to disk.

        Raises:
            ValueError: If metrics are configured but a sample has no GT.
        """
        dataset_name = dataloader.dataset.opt['name']
        metrics_opt = self.opt['val'].get('metrics')
        with_metrics = metrics_opt is not None
        if with_metrics:
            self.metric_results = {metric: 0 for metric in metrics_opt.keys()}
        pbar = ProgressBar(len(dataloader))

        num_images = 0
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            # Reset for every sample so a missing GT never reuses a stale one.
            gt_img = None
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # NOTE: in this variant self.lq / self.output are not deleted and
            # the CUDA cache is not emptied (that code was disabled).

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             img_name,
                                             f'{img_name}_{current_iter}.png')
                elif self.opt['val'].get('suffix'):
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_{self.opt["name"]}.png')
                mmcv.imwrite(sr_img, save_img_path)

            if with_metrics:
                if gt_img is None:
                    raise ValueError(
                        'Metrics are configured but no ground truth is '
                        f'available for image "{img_name}".')
                # deepcopy so that popping 'type' does not mutate self.opt.
                opt_metric = deepcopy(metrics_opt)
                for name, opt_ in opt_metric.items():
                    metric_type = opt_.pop('type')
                    self.metric_results[name] += getattr(
                        metric_module, metric_type)(sr_img, gt_img, **opt_)
            num_images += 1
            pbar.update(f'Test {img_name}')

        if with_metrics:
            # An empty dataloader leaves the zero-initialised results as-is.
            if num_images > 0:
                for metric in self.metric_results:
                    self.metric_results[metric] /= num_images

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #3
0
    def enhance(self,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True):
        """Restore the faces in ``img`` with GFPGAN.

        Args:
            img: Input image. When ``has_aligned`` is True, it is treated as a
                single already-aligned face and resized to 512x512. Otherwise
                ``self.face_helper`` detects, aligns and crops the faces.
                Faces are converted with ``img2tensor(..., bgr2rgb=True)``,
                so BGR channel order is expected.
            has_aligned (bool): Whether ``img`` is already an aligned face.
            only_center_face (bool): Forwarded to the landmark detection.
            paste_back (bool): Whether to paste the restored faces back onto
                the (optionally background-upsampled) input image.

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)``.
            ``restored_img`` is None unless the faces were detected here
            (``has_aligned`` is False) and ``paste_back`` is True.
        """
        import torch  # local import: used only to disable autograd below

        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            self.face_helper.get_face_landmarks_5(
                only_center_face=only_center_face, eye_dist_threshold=5)
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data: scale to [0, 1], BGR -> RGB, normalise to [-1, 1]
            cropped_face_t = img2tensor(cropped_face / 255.,
                                        bgr2rgb=True,
                                        float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                      inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                # Inference only: do not build an autograd graph per face.
                with torch.no_grad():
                    output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
                    # convert to image
                    restored_face = tensor2img(output.squeeze(0),
                                               rgb2bgr=True,
                                               min_max=(-1, 1))
            except RuntimeError as error:
                # Best effort: keep the unrestored crop if inference fails.
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        restored_img = None
        if not has_aligned and paste_back:
            # upsample the background
            if self.bg_upsampler is not None:
                # Now only support RealESRGAN for upsampling background
                bg_img = self.bg_upsampler.enhance(img,
                                                   outscale=self.upscale)[0]
            else:
                bg_img = None

            self.face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=bg_img)
        return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
Example #4
0
 def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
     """Generate one sample with ``self.test()``, save it and optionally log it.

     This validation needs no data, so ``dataloader`` must be None.
     ``save_img`` is accepted for interface compatibility but is not
     consulted: the sample is always written to disk.

     Args:
         dataloader: Must be None.
         current_iter (int): Current iteration. It is used in the file name
             during training and as the TensorBoard global step.
         tb_logger: Optional TensorBoard logger. When given, the sample is
             logged as an RGB float image under the tag 'samples'.
         save_img (bool): Unused.

     Raises:
         AssertionError: If ``dataloader`` is not None.
     """
     # Explicit check instead of ``assert`` so it still runs under ``python -O``.
     if dataloader is not None:
         raise AssertionError('Validation dataloader should be None.')
     self.test()
     result = tensor2img(self.output, min_max=(-1, 1))
     if self.opt['is_train']:
         save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png')
     else:
         save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png')
     imwrite(result, save_img_path)
     # add sample images to tb_logger (convert only when it will be used)
     if tb_logger is not None:
         result = (result / 255.).astype(np.float32)
         result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
         tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC')
Example #5
0
    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate a video model, sharding folders (clips) across ranks.

        Rank ``r`` processes dataset items ``r, r + world_size, ...``. The
        folder count is padded up to a multiple of ``world_size`` so that every
        rank runs the same number of iterations (to avoid a deadlock in the
        collective calls). Results of padded iterations are discarded.
        Per-frame metric values are accumulated in
        ``self.metric_results[folder]``, a ``num_frame x num_metrics`` CUDA
        tensor. When ``self.opt['dist']`` is set, they are summed onto rank 0
        with ``dist.reduce``, and then logged on rank 0.

        Args:
            dataloader: Validation dataloader. Only ``dataloader.dataset`` is
                used, indexed directly with one folder per item.
            current_iter (int): Current iteration, used for logging.
            tb_logger: TensorBoard logger forwarded to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to save restored frames.

        Raises:
            NotImplementedError: If ``save_img`` is set during training.
        """
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        # ``.get`` so a config without a 'metrics' entry simply skips metrics,
        # consistent with the non-distributed validation.
        with_metrics = self.opt['val'].get('metrics') is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics:
            if not hasattr(self,
                           'metric_results'):  # only execute in the first run
                self.metric_results = {}
                num_frame_each_folder = Counter(dataset.data_info['folder'])
                for folder, num_frame in num_frame_each_folder.items():
                    self.metric_results[folder] = torch.zeros(
                        num_frame,
                        len(self.opt['val']['metrics']),
                        dtype=torch.float32,
                        device='cuda')
            # initialize the best metric results
            self._initialize_best_metric_results(dataset_name)
        rank, world_size = get_dist_info()
        # zero self.metric_results
        if with_metrics:
            for tensor in self.metric_results.values():
                tensor.zero_()

        num_folders = len(dataset)
        num_pad = (world_size - (num_folders % world_size)) % world_size
        if rank == 0:
            pbar = tqdm(total=len(dataset), unit='folder')
        # Will evaluate (num_folders + num_pad) times, but only the first num_folders results will be recorded.
        # (To avoid wait-dead)
        for i in range(rank, num_folders + num_pad, world_size):
            idx = min(i, num_folders - 1)
            val_data = dataset[idx]
            folder = val_data['folder']
            has_gt = 'gt' in val_data

            # compute outputs: feed_data needs a batch dimension, which is
            # added in place and removed again right after.
            val_data['lq'].unsqueeze_(0)
            if has_gt:
                val_data['gt'].unsqueeze_(0)
            self.feed_data(val_data)
            val_data['lq'].squeeze_(0)
            if has_gt:
                val_data['gt'].squeeze_(0)

            self.test()
            visuals = self.get_current_visuals()

            # tentative for out of GPU memory
            del self.lq
            del self.output
            if 'gt' in visuals:
                del self.gt
            torch.cuda.empty_cache()

            if self.center_frame_only:
                # add a frame axis so the loop below handles a single frame
                visuals['result'] = visuals['result'].unsqueeze(1)
                if 'gt' in visuals:
                    visuals['gt'] = visuals['gt'].unsqueeze(1)

            # evaluate (padded iterations, i >= num_folders, are discarded)
            if i < num_folders:
                for frame_idx in range(visuals['result'].size(1)):
                    result = visuals['result'][0, frame_idx, :, :, :]
                    result_img = tensor2img([result])  # uint8, bgr
                    # Fresh per frame so a missing GT can never reuse a stale one.
                    metric_data = {'img': result_img}
                    if 'gt' in visuals:
                        gt = visuals['gt'][0, frame_idx, :, :, :]
                        metric_data['img2'] = tensor2img([gt])  # uint8, bgr

                    if save_img:
                        if self.opt['is_train']:
                            raise NotImplementedError(
                                'saving image is not supported during training.'
                            )
                        if self.center_frame_only:  # vimeo-90k
                            clip_ = val_data['lq_path'].split('/')[-3]
                            seq_ = val_data['lq_path'].split('/')[-2]
                            name_ = f'{clip_}_{seq_}'
                            img_path = osp.join(
                                self.opt['path']['visualization'],
                                dataset_name, folder,
                                f"{name_}_{self.opt['name']}.png")
                        else:  # others
                            img_path = osp.join(
                                self.opt['path']['visualization'],
                                dataset_name, folder,
                                f"{frame_idx:08d}_{self.opt['name']}.png")
                        imwrite(result_img, img_path)

                    # calculate metrics
                    if with_metrics:
                        for metric_idx, opt_ in enumerate(
                                self.opt['val']['metrics'].values()):
                            result = calculate_metric(metric_data, opt_)
                            self.metric_results[folder][frame_idx,
                                                        metric_idx] += result

                # progress bar: rank 0 advances on behalf of all ranks
                if rank == 0:
                    for _ in range(world_size):
                        pbar.update(1)
                        pbar.set_description(f'Folder: {folder}')

        if rank == 0:
            pbar.close()

        if with_metrics:
            if self.opt['dist']:
                # collect data among GPUs
                for tensor in self.metric_results.values():
                    dist.reduce(tensor, 0)
                dist.barrier()

            if rank == 0:
                self._log_validation_metric_values(current_iter, dataset_name,
                                                   tb_logger)
Example #6
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Run single-process validation over every sample in ``dataloader``.

        Each sample is fed to the model, run through ``self.test()`` and
        converted with ``tensor2img``. The restored image is optionally
        written to disk and scored with ``calculate_metric`` for every entry
        in ``self.opt['val']['metrics']``. Averaged scores are stored in
        ``self.metric_results``, used to update the best-metric record for
        ``dataset_name``, and logged.

        Args:
            dataloader: Validation dataloader. Only ``val_data["lq_path"][0]``
                is used to name outputs, so a batch size of 1 is presumably
                expected (TODO confirm).
            current_iter (int): Current iteration. It is used in the saved
                file name during training, for the best-metric record, and
                for logging.
            tb_logger: TensorBoard logger forwarded to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to write the restored image to disk.
        """
        dataset_name = dataloader.dataset.opt["name"]
        with_metrics = self.opt["val"].get("metrics") is not None
        use_pbar = self.opt["val"].get("pbar", False)

        if with_metrics:
            if not hasattr(self,
                           "metric_results"):  # only execute in the first run
                self.metric_results = {
                    metric: 0
                    for metric in self.opt["val"]["metrics"].keys()
                }
            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
            # zero self.metric_results
            self.metric_results = {metric: 0 for metric in self.metric_results}

        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit="image")

        num_images = 0
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data["lq_path"][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals["result"]])
            # Fresh for every image so a sample without GT never reuses the
            # previous image's GT.
            metric_data = {"img": sr_img}
            if "gt" in visuals:
                metric_data["img2"] = tensor2img([visuals["gt"]])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt["is_train"]:
                    save_img_path = osp.join(
                        self.opt["path"]["visualization"],
                        img_name,
                        f"{img_name}_{current_iter}.png",
                    )
                elif self.opt["val"].get("suffix"):
                    save_img_path = osp.join(
                        self.opt["path"]["visualization"],
                        dataset_name,
                        f'{img_name}_{self.opt["val"]["suffix"]}.png',
                    )
                else:
                    save_img_path = osp.join(
                        self.opt["path"]["visualization"],
                        dataset_name,
                        f'{img_name}_{self.opt["name"]}.png',
                    )
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt["val"]["metrics"].items():
                    self.metric_results[name] += calculate_metric(
                        metric_data, opt_)
            num_images += 1
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f"Test {img_name}")
        if use_pbar:
            pbar.close()

        if with_metrics:
            # Skip averaging and best-result updates for an empty dataloader,
            # so that a meaningless 0 is never recorded as a "best" value.
            if num_images > 0:
                for metric in self.metric_results.keys():
                    self.metric_results[metric] /= num_images
                    # update the best metric result
                    self._update_best_metric_result(dataset_name, metric,
                                                    self.metric_results[metric],
                                                    current_iter)

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #7
0
        # face restoration for each cropped face
        for idx, (cropped_face, landmarks) in enumerate(
                zip(cropped_faces, face_helper.all_landmarks_68)):
            if landmarks is None:
                print(f'Landmarks is None, skip cropped faces with idx {idx}.')
            else:
                # prepare data
                part_locations = get_part_location(landmarks)
                cropped_face = transforms.ToTensor()(cropped_face)
                cropped_face = transforms.Normalize(
                    (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(cropped_face)
                cropped_face = cropped_face.unsqueeze(0).to(device)

                with torch.no_grad():
                    output = net(cropped_face, part_locations)
                    im = tensor2img(output, min_max=(-1, 1))
                    del output
                torch.cuda.empty_cache()
                path, ext = os.path.splitext(
                    os.path.join(save_restore_root, img_name))
                save_path = f'{path}_{idx:02d}{ext}'
                mmcv.imwrite(im, save_path)
                face_helper.add_restored_face(im)

        print('\tGenerate the final result ...')
        # paste each restored face to the input image
        face_helper.paste_faces_to_input_image(
            os.path.join(save_final_root, img_name))

        # clean all the intermediate results to process the next image
        face_helper.clean_all()
Example #8
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Run single-process validation over every sample in ``dataloader``.

        Each sample is fed to the model, run through ``self.test()`` and
        converted with ``tensor2img``. The restored image is optionally saved
        as PNG (``save_img``). When ``self.opt['val']['save_npy']`` is truthy,
        it is also saved as ``.npy`` via ``tensor2npy``. Every metric in
        ``self.opt['val']['metrics']`` is computed, and the averaged scores
        are stored in ``self.metric_results`` and logged.

        Args:
            dataloader: Validation dataloader. Only ``val_data['lq_path'][0]``
                is used to name outputs, so a batch size of 1 is presumably
                expected (TODO confirm).
            current_iter (int): Current iteration. It is used in output file
                names during training and for logging.
            tb_logger: TensorBoard logger forwarded to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to write the restored image as PNG.

        Raises:
            ValueError: If metrics are configured but a sample has no GT.
        """
        dataset_name = dataloader.dataset.opt['name']
        metrics_opt = self.opt['val'].get('metrics')
        with_metrics = metrics_opt is not None
        save_npy = self.opt['val'].get('save_npy', None)
        if with_metrics:
            self.metric_results = {metric: 0 for metric in metrics_opt.keys()}
        pbar = ProgressBar(len(dataloader))

        def _output_path(img_name, ext):
            """Build the visualization path for ``img_name`` ending in ``ext``."""
            if self.opt['is_train']:
                return osp.join(self.opt['path']['visualization'], img_name,
                                f'{img_name}_{current_iter}{ext}')
            if self.opt['val'].get('suffix'):
                return osp.join(
                    self.opt['path']['visualization'], dataset_name,
                    f'{img_name}_{self.opt["val"]["suffix"]}{ext}')
            return osp.join(self.opt['path']['visualization'], dataset_name,
                            f'{img_name}{ext}')

        num_images = 0
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            # Reset for every sample so a missing GT never reuses a stale one.
            # (For raw data, tensor2raw was used here instead.)
            gt_img = None
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                mmcv.imwrite(sr_img, _output_path(img_name, '.png'))

            if save_npy:
                # saving as .npy format.
                np.save(_output_path(img_name, '.npy'),
                        tensor2npy([visuals['result']]))

            if with_metrics:
                if gt_img is None:
                    raise ValueError(
                        'Metrics are configured but no ground truth is '
                        f'available for image "{img_name}".')
                # deepcopy so that popping 'type' does not mutate self.opt.
                opt_metric = deepcopy(metrics_opt)
                for name, opt_ in opt_metric.items():
                    metric_type = opt_.pop('type')
                    self.metric_results[name] += getattr(
                        metric_module, metric_type)(sr_img, gt_img, **opt_)
            num_images += 1
            pbar.update(f'Test {img_name}')

        if with_metrics:
            # An empty dataloader leaves the zero-initialised results as-is.
            if num_images > 0:
                for metric in self.metric_results:
                    self.metric_results[metric] /= num_images

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #9
0
    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate a video model frame by frame, sharding frames across ranks.

        Rank ``r`` processes dataset items ``r, r + world_size, ...``. Metric
        values are written to ``self.metric_results[folder][frame_idx, k]``,
        a ``num_frame x num_metrics`` CUDA tensor per folder. When
        ``self.opt['dist']`` is set, they are summed onto rank 0 with
        ``dist.reduce``, and then logged on rank 0.

        Args:
            dataloader: Validation dataloader. Only ``dataloader.dataset`` is
                used, indexed directly with one frame per item.
                ``val_data['idx']`` is expected as ``'<frame>/<max>'``.
            current_iter (int): Current iteration, used for logging.
            tb_logger: TensorBoard logger forwarded to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to save restored frames.

        Raises:
            NotImplementedError: If ``save_img`` is set during training.
            ValueError: If metrics are configured but a frame has no GT.
        """
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        # ``.get`` so a config without a 'metrics' entry simply skips metrics.
        with_metrics = self.opt['val'].get('metrics') is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics and not hasattr(self, 'metric_results'):
            self.metric_results = {}
            num_frame_each_folder = Counter(dataset.data_info['folder'])
            for folder, num_frame in num_frame_each_folder.items():
                self.metric_results[folder] = torch.zeros(
                    num_frame,
                    len(self.opt['val']['metrics']),
                    dtype=torch.float32,
                    device='cuda')

        rank, world_size = get_dist_info()
        # Zero the accumulators. This is guarded because self.metric_results
        # only exists when metrics are configured; zeroing unconditionally
        # raised AttributeError on a metric-less first run.
        if with_metrics:
            for tensor in self.metric_results.values():
                tensor.zero_()
        # record all frames (border and center frames)
        if rank == 0:
            pbar = tqdm(total=len(dataset), unit='frame')
        for idx in range(rank, len(dataset), world_size):
            val_data = dataset[idx]
            # add the batch dimension expected by feed_data
            val_data['lq'].unsqueeze_(0)
            if 'gt' in val_data:
                val_data['gt'].unsqueeze_(0)
            folder = val_data['folder']
            frame_idx, max_idx = val_data['idx'].split('/')
            lq_path = val_data['lq_path']

            self.feed_data(val_data)
            self.test()
            visuals = self.get_current_visuals()
            result_img = tensor2img([visuals['result']])
            # Reset for every frame so a missing GT never reuses a stale one.
            gt_img = None
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    raise NotImplementedError(
                        'saving image is not supported during training.')
                if 'vimeo' in dataset_name.lower():  # vimeo90k dataset
                    split_result = lq_path.split('/')
                    img_name = (f'{split_result[-3]}_{split_result[-2]}_'
                                f'{split_result[-1].split(".")[0]}')
                else:  # other datasets, e.g., REDS, Vid4
                    img_name = osp.splitext(osp.basename(lq_path))[0]

                if self.opt['val'].get('suffix'):
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        folder,
                        f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        folder, f'{img_name}_{self.opt["name"]}.png')
                imwrite(result_img, save_img_path)

            if with_metrics:
                if gt_img is None:
                    raise ValueError(
                        'Metrics are configured but no ground truth is '
                        f'available for {folder} frame {frame_idx}.')
                # deepcopy so that popping 'type' does not mutate self.opt.
                opt_metric = deepcopy(self.opt['val']['metrics'])
                for metric_idx, opt_ in enumerate(opt_metric.values()):
                    metric_type = opt_.pop('type')
                    result = getattr(metric_module,
                                     metric_type)(result_img, gt_img, **opt_)
                    self.metric_results[folder][int(frame_idx),
                                                metric_idx] += result

            # progress bar: rank 0 advances on behalf of all ranks
            if rank == 0:
                for _ in range(world_size):
                    pbar.update(1)
                    pbar.set_description(
                        f'Test {folder}:'
                        f'{int(frame_idx) + world_size}/{max_idx}')
        if rank == 0:
            pbar.close()

        if with_metrics:
            if self.opt['dist']:
                # collect data among GPUs
                for tensor in self.metric_results.values():
                    dist.reduce(tensor, 0)
                dist.barrier()
            # otherwise: assume use one gpu in non-dist testing

            if rank == 0:
                self._log_validation_metric_values(current_iter, dataset_name,
                                                   tb_logger)
Example #10
0
        x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)

        y, cb, cr = self.compress(x, factor=factor)
        recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
        recovered = recovered[:, :, 0:h, 0:w]
        return recovered


if __name__ == '__main__':
    # Manual comparison: JPEG-compress test.png with OpenCV (quality 20) and
    # with DiffJPEG (qualities 20 and 40), writing each result to a PNG.
    import cv2

    from basicsr.utils import img2tensor, tensor2img

    gt_float = cv2.imread('test.png') / 255.

    # -------------- cv2 -------------- #
    jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
    _, jpeg_bytes = cv2.imencode('.jpg', gt_float * 255., jpeg_params)
    cv2_lq = np.float32(cv2.imdecode(jpeg_bytes, 1))
    cv2.imwrite('cv2_JPEG_20.png', cv2_lq)

    # -------------- DiffJPEG -------------- #
    jpeger = DiffJPEG(differentiable=False).cuda()
    gt_tensor = img2tensor(gt_float)
    gt_batch = torch.stack([gt_tensor, gt_tensor]).cuda()
    qualities = [20, 40]  # one JPEG quality per batch element
    compressed = jpeger(gt_batch, quality=gt_batch.new_tensor(qualities))

    for batch_idx, q in enumerate(qualities):
        cv2.imwrite(f'pt_JPEG_{q}.png', tensor2img(compressed[batch_idx]))
Example #11
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Run single-process validation over every sample in ``dataloader``.

        Each sample is fed to the model, run through ``self.test()`` and
        converted with ``tensor2img``. The restored image is optionally saved
        with ``mmcv.imwrite``. Metrics whose name contains ``'cosine'`` are
        computed as the mean row-wise cosine similarity between
        ``visuals['embedding_out']`` and ``visuals['embedding_gt']``. All
        other metrics are looked up in ``metric_module`` by their ``type``
        and applied to the images. Averaged scores are stored in
        ``self.metric_results`` and logged.

        Args:
            dataloader: Validation dataloader. Only ``val_data['lq_path'][0]``
                is used to name outputs, so a batch size of 1 is presumably
                expected (TODO confirm).
            current_iter (int): Current iteration. It is used in the saved
                file name during training and for logging.
            tb_logger: TensorBoard logger forwarded to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to write the restored image to disk.

        Raises:
            ValueError: If an image metric is configured but a sample has no GT.
        """
        dataset_name = dataloader.dataset.opt['name']
        metrics_opt = self.opt['val'].get('metrics')
        with_metrics = metrics_opt is not None
        if with_metrics:
            self.metric_results = {metric: 0 for metric in metrics_opt.keys()}
        pbar = ProgressBar(len(dataloader))

        num_images = 0
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            # Reset for every sample so a missing GT never reuses a stale one.
            gt_img = None
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory (in this variant self.lq and
            # self.output are not deleted; only the CUDA cache is emptied)
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             img_name,
                                             f'{img_name}_{current_iter}.png')
                elif self.opt['val'].get('suffix'):
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_{self.opt["name"]}.png')
                mmcv.imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics; deepcopy so popping 'type' keeps self.opt intact
                opt_metric = deepcopy(metrics_opt)
                for name, opt_ in opt_metric.items():
                    if 'cosine' in name:
                        out_emb = visuals['embedding_out']
                        gt_emb = visuals['embedding_gt']
                        # L2-normalise each row. This assumes 2-D (N, D)
                        # embeddings, as the original reshape(-1, 1) implied.
                        gt = gt_emb / torch.norm(gt_emb, 2, 1, keepdim=True)
                        out = out_emb / torch.norm(out_emb, 2, 1, keepdim=True)
                        cos_similarity = torch.mean(torch.sum(gt * out, 1))
                        # .item(): accumulate a Python float rather than a
                        # (possibly GPU-resident) tensor.
                        self.metric_results[name] += cos_similarity.item()
                    else:
                        if gt_img is None:
                            raise ValueError(
                                'Metrics are configured but no ground truth '
                                f'is available for image "{img_name}".')
                        metric_type = opt_.pop('type')
                        self.metric_results[name] += getattr(
                            metric_module, metric_type)(sr_img, gt_img, **opt_)
            num_images += 1
            # The progress bar was created but never advanced before.
            pbar.update(f'Test {img_name}')

        if with_metrics:
            # An empty dataloader leaves the zero-initialised results as-is.
            if num_images > 0:
                for metric in self.metric_results:
                    self.metric_results[metric] /= num_images

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #12
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Validate on a single process and log the averaged metrics.

        Each batch is fed to the model and run through ``self.test()``. The
        output, plus the GT when the model holds one, is converted to an image
        with ``tensor2img(..., min_max=(-1, 1))``. The restored image is
        optionally written to disk, and every configured metric is accumulated
        via ``calculate_metric``. The sums are then averaged over the number
        of evaluated images, the per-dataset best results are updated, and
        everything is logged.

        Args:
            dataloader: Validation dataloader whose dataset exposes
                ``opt['name']``. Only ``val_data['lq_path'][0]`` is used for
                naming, so batch size 1 is presumably expected.
            current_iter (int): Current iteration, used in training-time file
                names and for logging.
            tb_logger: Tensorboard logger passed to
                ``_log_validation_metric_values``.
            save_img (bool): Whether to write restored images to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        use_pbar = self.opt['val'].get('pbar', False)

        if with_metrics:
            if not hasattr(self,
                           'metric_results'):  # only execute in the first run
                self.metric_results = {
                    metric: 0
                    for metric in self.opt['val']['metrics'].keys()
                }
            # initialize the best metric results for each dataset_name
            # (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
            # zero self.metric_results
            self.metric_results = {metric: 0 for metric in self.metric_results}

        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')

        num_imgs = 0
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            # Build a fresh dict for every image. Otherwise a sample without
            # GT could be scored against the previous sample's GT image.
            metric_data = dict()
            sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
            metric_data['img'] = sr_img
            if hasattr(self, 'gt'):
                metric_data['img2'] = tensor2img(
                    self.gt.detach().cpu(), min_max=(-1, 1))
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], dataset_name,
                            f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], dataset_name,
                            f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    self.metric_results[name] += calculate_metric(
                        metric_data, opt_)
            num_imgs += 1
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')
        if use_pbar:
            pbar.close()

        if with_metrics:
            for metric in self.metric_results.keys():
                # max(..., 1) avoids a ZeroDivisionError on an empty dataloader
                self.metric_results[metric] /= max(num_imgs, 1)
                if num_imgs:
                    # Update the best metric result, but only when at least
                    # one image was actually evaluated.
                    self._update_best_metric_result(dataset_name, metric,
                                                    self.metric_results[metric],
                                                    current_iter)

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #13
0
        for idx, (cropped_face, landmarks) in enumerate(zip(cropped_faces, face_helper.all_landmarks_68)):
            if landmarks is None:
                print(f'Landmarks is None, skip cropped faces with idx {idx}.')
                # just copy the cropped faces to the restored faces
                restored_face = cropped_face
            else:
                # prepare data
                part_locations = get_part_location(landmarks)
                cropped_face = transforms.ToTensor()(cropped_face)
                cropped_face = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(cropped_face)
                cropped_face = cropped_face.unsqueeze(0).to(device)

                try:
                    with torch.no_grad():
                        output = net(cropped_face, part_locations)
                        restored_face = tensor2img(output, min_max=(-1, 1))
                    del output
                    torch.cuda.empty_cache()
                except Exception as e:
                    print(f'DFDNet inference fail: {e}')
                    restored_face = tensor2img(cropped_face, min_max=(-1, 1))

            path = os.path.splitext(os.path.join(save_restore_root, img_name))[0]
            save_path = f'{path}_{idx:02d}.png'
            imwrite(restored_face, save_path)
            face_helper.add_restored_face(restored_face)

        print('\tGenerate the final result ...')
        # paste each restored face to the input image
        face_helper.paste_faces_to_input_image(os.path.join(save_final_root, img_name))
Example #14
0
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img,
                           save_h5):
        """Validate on a single process, optionally dumping results to HDF5.

        For each batch the model is fed and tested, and the visuals are
        converted with ``tensor2img``. Restored images are optionally saved,
        and configured metrics from ``metric_module`` are accumulated. At the
        end they are averaged over the evaluated batches and logged.

        Args:
            dataloader: Validation dataloader whose dataset exposes
                ``opt['name']``. Only ``val_data['lq_path'][0]`` is used for
                naming.
            current_iter (int): Current iteration, used in training-time file
                names and for logging.
            tb_logger: Tensorboard logger.
            save_img (bool): Whether to write restored images to disk.
            save_h5 (bool): Whether to store every ``visuals['result']`` batch
                in ``<visualization>/recon_img.hdf5`` (dataset ``'data'``).
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {
                metric: 0
                for metric in self.opt['val']['metrics'].keys()
            }
        pbar = ProgressBar(len(dataloader))

        # Set up h5 file, if save
        h5_file = None
        if save_h5:
            h5_file = h5py.File(
                osp.join(self.opt['path']['visualization'], 'recon_img.hdf5'),
                'w')
            # NOTE(review): the per-sample shape (3, 256, 256) is hard-coded.
            # Results of another size presumably cannot be written; confirm.
            h5_dataset = h5_file.create_dataset('data',
                                                shape=(len(dataloader.dataset),
                                                       3, 256, 256),
                                                dtype=np.float32,
                                                fillvalue=0)
            counter = 0

        num_imgs = 0
        try:
            for val_data in dataloader:
                img_name = osp.splitext(
                    osp.basename(val_data['lq_path'][0]))[0]
                self.feed_data(val_data)
                self.test()

                visuals = self.get_current_visuals()

                # Save to h5 file, if save
                if save_h5:
                    batch_size = val_data['lq'].shape[0]
                    h5_dataset[counter:counter +
                               batch_size] = visuals['result'].numpy()
                    counter += batch_size

                sr_img = tensor2img([visuals['result']])
                if 'gt' in visuals:
                    gt_img = tensor2img([visuals['gt']])
                    del self.gt

                # tentative for out of GPU memory
                del self.lq
                del self.output
                torch.cuda.empty_cache()

                if save_img:
                    if self.opt['is_train']:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], img_name,
                            f'{img_name}_{current_iter}.png')
                    else:
                        if self.opt['val']['suffix']:
                            save_img_path = osp.join(
                                self.opt['path']['visualization'],
                                dataset_name,
                                f'{img_name}_{self.opt["val"]["suffix"]}.png')
                        else:
                            save_img_path = osp.join(
                                self.opt['path']['visualization'],
                                dataset_name,
                                f'{img_name}_{self.opt["name"]}.png')
                    mmcv.imwrite(sr_img, save_img_path)

                if with_metrics:
                    # calculate metrics
                    opt_metric = deepcopy(self.opt['val']['metrics'])
                    for name, opt_ in opt_metric.items():
                        metric_type = opt_.pop('type')
                        self.metric_results[name] += getattr(
                            metric_module, metric_type)(sr_img, gt_img, **opt_)
                num_imgs += 1
                pbar.update(f'Test {img_name}')
        finally:
            # Close the HDF5 file even if validation raises part-way, so the
            # handle is not leaked and buffered data gets flushed.
            if h5_file is not None:
                h5_file.close()

        if with_metrics:
            for metric in self.metric_results.keys():
                # max(..., 1) avoids a ZeroDivisionError on an empty dataloader
                self.metric_results[metric] /= max(num_imgs, 1)

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #15
0
    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run distributed validation that scores embedding cosine similarity.

        Frames are split across ranks: rank ``r`` handles dataset indices
        ``r, r + world_size, ...``. For each frame the model output is
        converted to an image. When ``save_img`` is set, the image and the
        stacked ``embedding_gt`` / ``embedding_out`` / ``embedding_center``
        arrays (as ``.npy``) are written under the visualization directory.

        Every configured metric slot of ``self.metric_results[folder]``
        receives the mean cosine similarity between the L2-normalised output
        and GT embeddings; the configured metric types are not evaluated.
        Results are reduced onto rank 0 when ``opt['dist']`` is set and are
        logged there.

        Args:
            dataloader: Dataloader whose dataset exposes ``opt['name']``,
                ``data_info['folder']`` and per-item ``lq``, ``gt``,
                ``folder``, ``idx`` (``'<frame>/<max>'``) and ``lq_path``.
            current_iter (int): Current iteration, used for logging.
            tb_logger: Tensorboard logger.
            save_img (bool): Whether to save images and embeddings.

        Raises:
            NotImplementedError: If ``save_img`` is set during training.
        """
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        with_metrics = self.opt['val']['metrics'] is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics and not hasattr(self, 'metric_results'):
            self.metric_results = {}
            num_frame_each_folder = Counter(dataset.data_info['folder'])
            for folder, num_frame in num_frame_each_folder.items():
                self.metric_results[folder] = torch.zeros(
                    num_frame,
                    len(self.opt['val']['metrics']),
                    dtype=torch.float32,
                    device='cuda')

        rank, world_size = get_dist_info()
        if with_metrics:
            # self.metric_results only exists when metrics are configured.
            for tensor in self.metric_results.values():
                tensor.zero_()
        # record all frames (border and center frames)
        if rank == 0:
            pbar = ProgressBar(len(dataset))
        for idx in range(rank, len(dataset), world_size):
            val_data = dataset[idx]
            val_data['lq'].unsqueeze_(0)
            val_data['gt'].unsqueeze_(0)
            folder = val_data['folder']
            frame_idx, max_idx = val_data['idx'].split('/')
            lq_path = val_data['lq_path']

            self.feed_data(val_data)
            self.test()
            visuals = self.get_current_visuals()

            result_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                # Scoring uses embeddings, so no GT image is built here; only
                # the GT tensor is released.
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    raise NotImplementedError(
                        'saving image is not supported during training.')
                split_result = lq_path.split('/')
                if self.opt['val']['suffix']:
                    if 'vimeo' in dataset_name.lower():  # vimeo90k dataset
                        img_name = (f'{split_result[-3]}_{split_result[-2]}_'
                                    f'{split_result[-1].split(".")[0]}')
                    else:  # other datasets, e.g., REDS, Vid4
                        img_name = osp.splitext(osp.basename(lq_path))[0]
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        folder,
                        f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    img_name = (f'{split_result[-3]}_{split_result[-2]}_'
                                f'{split_result[-1].split(".")[0]}')
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], folder,
                        f'{img_name}.png')

                # Swap only the extension. str.replace would also rewrite any
                # 'png' appearing inside directory or file names.
                np_save_img_path = osp.splitext(save_img_path)[0] + '.npy'
                # np.save does not create directories, so create the actual
                # parent of the target (it differs between the suffix and
                # no-suffix layouts). exist_ok avoids an exists/makedirs race
                # between ranks.
                os.makedirs(osp.dirname(np_save_img_path), exist_ok=True)
                np.save(
                    np_save_img_path,
                    np.array([
                        visuals['embedding_gt'], visuals['embedding_out'],
                        visuals['embedding_center']
                    ]))
                mmcv.imwrite(result_img, save_img_path)

            if with_metrics:
                # Every configured metric slot gets the same value: the mean
                # cosine similarity between output and GT embeddings,
                # normalised over the last axis. Computed once per frame
                # rather than once per metric.
                out_emb = visuals['embedding_out']
                gt_emb = visuals['embedding_gt']
                gt = gt_emb / np.sqrt(np.sum(gt_emb**2, -1, keepdims=True))
                out = out_emb / np.sqrt(np.sum(out_emb**2, -1, keepdims=True))
                cos_similarity = np.mean(np.sum(gt * out, -1))
                for metric_idx in range(len(self.opt['val']['metrics'])):
                    self.metric_results[folder][int(frame_idx),
                                                metric_idx] += cos_similarity

            # progress bar
            if rank == 0:
                for _ in range(world_size):
                    pbar.update(f'Test {folder} - '
                                f'{int(frame_idx) + world_size}/{max_idx}')

        if with_metrics:
            if self.opt['dist']:
                # collect data among GPUs
                for tensor in self.metric_results.values():
                    dist.reduce(tensor, 0)
                dist.barrier()
            else:
                pass  # assume use one gpu in non-dist testing

            if rank == 0:
                self._log_validation_metric_values(current_iter, dataset_name,
                                                   tb_logger)
Example #16
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Validate on a single process, computing metrics over all outputs.

        TODO: Validation using updated metric system
        The metrics are now evaluated after all images have been tested
        This allows batch processing, and also allows evaluation of
        distributional metrics, such as:

        @ Frechet Inception Distance: FID
        @ Maximum Mean Discrepancy: MMD

        Warning:
            Need careful batch management for different inference settings.

        Args:
            dataloader: Validation dataloader whose dataset exposes
                ``opt['name']``. Only ``val_data['lq_path'][0]`` is used for
                naming.
            current_iter (int): Current iteration, used in training-time file
                names and for logging.
            tb_logger: Tensorboard logger.
            save_img (bool): Whether to write restored images to disk.

        Raises:
            RuntimeError: If metrics are configured but no sample was
                processed, or some sample had no GT tensor.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        # Always defined so the loop never hits a NameError. They are only
        # filled when metrics are enabled, so no memory is spent otherwise.
        sr_tensors = []
        gt_tensors = []
        if with_metrics:
            self.metric_results = {}

        pbar = tqdm(total=len(dataloader), unit='image')
        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            # detached cpu tensor, non-squeeze
            visuals = self.get_current_visuals()
            if with_metrics:
                sr_tensors.append(visuals['result'])
            if 'gt' in visuals:
                if with_metrics:
                    gt_tensors.append(visuals['gt'])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], dataset_name,
                            f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], dataset_name,
                            f'{img_name}_{self.opt["name"]}.png')

                imwrite(tensor2img(visuals['result']), save_img_path)

            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
        pbar.close()

        if with_metrics:
            if not sr_tensors or len(gt_tensors) != len(sr_tensors):
                # torch.cat fails on an empty list, and a partial GT list
                # would leave results and ground truths misaligned.
                raise RuntimeError(
                    'Metric evaluation needs at least one validation sample '
                    'and a GT tensor for every sample.')
            sr_pack = torch.cat(sr_tensors, dim=0)
            gt_pack = torch.cat(gt_tensors, dim=0)
            # calculate metrics
            for name, opt_ in self.opt['val']['metrics'].items():
                # The new metric caller automatically returns mean value
                # FIXME: ERROR: calculate_metric only supports two arguments. Now the codes cannot be successfully run
                self.metric_results[name] = calculate_metric(
                    dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_)
            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
Example #17
0
    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run distributed validation with optional image and video export.

        Frames are split across ranks: rank ``r`` handles dataset indices
        ``r, r + world_size, ...``. Each frame's output is converted to an
        image. It may be saved as PNG (``save_img``, test mode only) and/or
        piped to an ffmpeg process that encodes ``<visualization>/out.avi``
        (``opt['val']['save_vid']``). Configured metrics are computed per
        frame into ``self.metric_results[folder][frame_idx, metric_idx]``,
        reduced onto rank 0 when ``opt['dist']`` is set, and logged there.

        Args:
            dataloader: Dataloader whose dataset exposes ``opt['name']``,
                ``data_info['folder']`` and per-item ``lq``, ``gt``,
                ``folder``, ``idx`` (``'<frame>/<max>'``) and ``lq_path``.
            current_iter (int): Current iteration, used for logging.
            tb_logger: Tensorboard logger.
            save_img (bool): Whether to save per-frame PNGs.

        Raises:
            NotImplementedError: If ``save_img`` is set during training.
        """
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        save_vid = self.opt['val']['save_vid']

        vid = None
        if save_vid:
            # NOTE(review): every rank spawns its own ffmpeg writing the same
            # out.avi and only feeds it its own subset of frames. This is
            # presumably only meaningful with a single process.
            save_path = osp.join(self.opt['path']['visualization'], 'out.avi')
            fps = '30'
            crf = '18'
            # Frames are streamed to ffmpeg's stdin as PNG images. stderr is
            # discarded through DEVNULL, so no file handle is leaked.
            vid = sp.Popen([
                'ffmpeg', '-framerate', fps, '-i', '-', '-c:v', 'libx264',
                '-preset', 'veryslow', '-crf', crf, '-y', save_path
            ],
                           stdin=sp.PIPE,
                           stderr=sp.DEVNULL)

        with_metrics = self.opt['val']['metrics'] is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics and not hasattr(self, 'metric_results'):
            self.metric_results = {}
            num_frame_each_folder = Counter(dataset.data_info['folder'])
            for folder, num_frame in num_frame_each_folder.items():
                self.metric_results[folder] = torch.zeros(
                    num_frame,
                    len(self.opt['val']['metrics']),
                    dtype=torch.float32,
                    device='cuda')
        rank, world_size = get_dist_info()
        if with_metrics:
            for tensor in self.metric_results.values():
                tensor.zero_()
        # record all frames (border and center frames)
        if rank == 0:
            pbar = tqdm(total=len(dataset), unit='frame')
        for idx in range(rank, len(dataset), world_size):
            val_data = dataset[idx]
            val_data['lq'].unsqueeze_(0)
            val_data['gt'].unsqueeze_(0)
            # Needed both for saving images and for indexing metric_results,
            # so they are read for every frame.
            folder = val_data['folder']
            frame_idx, _ = val_data['idx'].split('/')
            lq_path = val_data['lq_path']

            self.feed_data(val_data)
            self.test()
            visuals = self.get_current_visuals()
            result_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    raise NotImplementedError(
                        'saving image is not supported during training.')
                if 'vimeo' in dataset_name.lower():  # vimeo90k dataset
                    split_result = lq_path.split('/')
                    img_name = (f'{split_result[-3]}_{split_result[-2]}_'
                                f'{split_result[-1].split(".")[0]}')
                else:  # other datasets, e.g., REDS, Vid4
                    img_name = osp.splitext(osp.basename(lq_path))[0]

                if self.opt['val']['suffix']:
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        folder,
                        f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        folder, f'{img_name}_{self.opt["name"]}.png')
                imwrite(result_img, save_img_path)
            if vid is not None:
                # The channel axis is reversed before PIL encodes the frame
                # (presumably BGR -> RGB; confirm against tensor2img).
                frame = Image.fromarray(result_img[..., ::-1])
                frame.save(vid.stdin, 'PNG')

            if with_metrics:
                # calculate metrics
                opt_metric = deepcopy(self.opt['val']['metrics'])
                for metric_idx, opt_ in enumerate(opt_metric.values()):
                    metric_type = opt_.pop('type')
                    result = getattr(metric_module,
                                     metric_type)(result_img, gt_img, **opt_)
                    self.metric_results[folder][int(frame_idx),
                                                metric_idx] += result

            # progress bar
            if rank == 0:
                for _ in range(world_size):
                    pbar.update(1)
        if rank == 0:
            pbar.close()

        if vid is not None:
            # Close stdin so ffmpeg sees EOF, then wait for it to finish.
            vid.stdin.close()
            vid.communicate()

        if with_metrics:
            if self.opt['dist']:
                # collect data among GPUs
                for tensor in self.metric_results.values():
                    dist.reduce(tensor, 0)
                dist.barrier()
            else:
                pass  # assume use one gpu in non-dist testing

            if rank == 0:
                self._log_validation_metric_values(current_iter, dataset_name,
                                                   tb_logger)
    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img,
                           rgb2bgr, use_image):
        """Validate on a single process and return the last averaged metric.

        When ``opt['val']['grids']`` is set, each input is split with
        ``self.grids()`` before ``self.test()`` and merged back with
        ``self.grids_inverse()``. Restored images, and GT images when
        available, can be saved. Each configured metric is accumulated either
        on the ``tensor2img`` images (``use_image=True``) or directly on
        ``visuals['result']`` / ``visuals['gt']``.

        Args:
            dataloader: Validation dataloader whose dataset exposes
                ``opt['name']``. Only ``val_data['lq_path'][0]`` is used for
                naming, so batch size 1 is presumably expected.
            current_iter (int): Current iteration, used in training-time file
                names and for logging.
            tb_logger: Tensorboard logger.
            save_img (bool): Whether to write restored and GT images.
            rgb2bgr (bool): Forwarded to ``tensor2img``.
            use_image (bool): Compute metrics on converted images instead of
                on the raw visual tensors.

        Returns:
            The averaged value of the last metric in ``self.metric_results``
            (``0.`` when no metrics are configured).
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {
                metric: 0
                for metric in self.opt['val']['metrics'].keys()
            }
        use_grids = self.opt['val'].get('grids') is not None
        pbar = tqdm(total=len(dataloader), unit='image')

        cnt = 0

        for val_data in dataloader:
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            if use_grids:
                self.grids()

            self.test()

            if use_grids:
                self.grids_inverse()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']], rgb2bgr=rgb2bgr)
            # Reset per image so a sample without GT never reuses (or saves)
            # the previous sample's GT image.
            gt_img = None
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']], rgb2bgr=rgb2bgr)
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             img_name,
                                             f'{img_name}_{current_iter}.png')
                    save_gt_img_path = osp.join(
                        self.opt['path']['visualization'], img_name,
                        f'{img_name}_{current_iter}_gt.png')
                else:
                    save_img_path = osp.join(self.opt['path']['visualization'],
                                             dataset_name, f'{img_name}.png')
                    save_gt_img_path = osp.join(
                        self.opt['path']['visualization'], dataset_name,
                        f'{img_name}_gt.png')

                imwrite(sr_img, save_img_path)
                if gt_img is not None:
                    imwrite(gt_img, save_gt_img_path)

            if with_metrics:
                # calculate metrics
                opt_metric = deepcopy(self.opt['val']['metrics'])
                for name, opt_ in opt_metric.items():
                    metric_type = opt_.pop('type')
                    metric_fn = getattr(metric_module, metric_type)
                    if use_image:
                        self.metric_results[name] += metric_fn(
                            sr_img, gt_img, **opt_)
                    else:
                        self.metric_results[name] += metric_fn(
                            visuals['result'], visuals['gt'], **opt_)

            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
            cnt += 1
        pbar.close()

        current_metric = 0.
        if with_metrics:
            for metric in self.metric_results.keys():
                # max(..., 1) avoids a ZeroDivisionError on an empty dataloader
                self.metric_results[metric] /= max(cnt, 1)
                current_metric = self.metric_results[metric]

            self._log_validation_metric_values(current_iter, dataset_name,
                                               tb_logger)
        return current_metric
Example #19
0
    def nondist_validation(self, dataloader, current_iter, tb_logger,
                           save_img):
        """Validate video clips on a single process.

        Each dataloader item is one clip named ``val_data['clip_name'][0]``.
        ``tensor2img`` is applied to the whole result/GT visuals, and the
        returned sequences are treated as per-frame images. Frames can be
        saved with file names built from the indices in
        ``val_data['frame_list']``. Each metric is averaged over a clip's
        frames, then over all clips, and logged through the base class.

        Args:
            dataloader: Validation dataloader whose dataset exposes
                ``opt['name']``.
            current_iter (int): Current iteration, used for logging.
            tb_logger: Tensorboard logger.
            save_img (bool): Whether to write restored frames to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {
                metric: 0
                for metric in self.opt['val']['metrics'].keys()
            }

        pbar = tqdm(total=len(dataloader), unit='image', ascii=True)

        num_clips = 0
        for val_data in dataloader:
            clip_name = val_data['clip_name'][0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_imgs = tensor2img(visuals['result'])
            # Reset per clip so a clip without GT is never scored against the
            # previous clip's GT frames.
            gt_imgs = None
            if 'gt' in visuals:
                gt_imgs = tensor2img(visuals['gt'])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                vis_root = self.opt['path']['visualization']
                if self.opt['is_train']:
                    save_dir = osp.join(vis_root, f'{dataset_name}_train',
                                        clip_name)
                    name_tail = '.png'
                else:
                    save_dir = osp.join(vis_root, dataset_name, clip_name)
                    if self.opt['val']['suffix']:
                        name_tail = f'_{self.opt["val"]["suffix"]}.png'
                    else:
                        name_tail = '_.png'
                # Only the frame index is formatted into the file name, so
                # braces in clip/dataset names or in the suffix are harmless.
                for sr_img_idx, sr_img in zip(val_data['frame_list'], sr_imgs):
                    imwrite(
                        sr_img,
                        osp.join(save_dir,
                                 f'{sr_img_idx.item():08d}{name_tail}'))

            if with_metrics:
                # calculate metrics
                opt_metric = deepcopy(self.opt['val']['metrics'])
                for name, opt_ in opt_metric.items():
                    metric_type = opt_.pop('type')
                    metric_ = getattr(metric_module, metric_type)
                    metric_results_ = [
                        metric_(sr, gt, **opt_)
                        for sr, gt in zip(sr_imgs, gt_imgs)
                    ]
                    # per-clip mean over frames
                    self.metric_results[name] += torch.tensor(
                        sum(metric_results_) / len(metric_results_))
            num_clips += 1
            pbar.update(1)
            pbar.set_description(f'Test {clip_name}')
        pbar.close()

        if with_metrics:
            for metric in self.metric_results.keys():
                # max(..., 1) avoids a ZeroDivisionError on an empty dataloader
                self.metric_results[metric] /= max(num_clips, 1)

            super(VideoBaseModel,
                  self)._log_validation_metric_values(current_iter,
                                                      dataset_name, tb_logger)
Example #20
0
    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        dataset = dataloader.dataset
        dataset_name = dataset.opt['name']
        with_metrics = self.opt['val']['metrics'] is not None
        # initialize self.metric_results
        # It is a dict: {
        #    'folder1': tensor (num_frame x len(metrics)),
        #    'folder2': tensor (num_frame x len(metrics))
        # }
        if with_metrics and not hasattr(self, 'metric_results'):
            self.metric_results = {}
            num_frame_each_folder = Counter(dataset.data_info['folder'])
            for folder, num_frame in num_frame_each_folder.items():
                self.metric_results[folder] = torch.zeros(
                    num_frame,
                    len(self.opt['val']['metrics']),
                    dtype=torch.float32,
                    device='cuda')

        rank, world_size = get_dist_info()
        for _, tensor in self.metric_results.items():
            tensor.zero_()
        # record all frames (border and center frames)
        if rank == 0:
            pbar = ProgressBar(len(dataset))
        for idx in range(rank, len(dataset), world_size):
            val_data = dataset[idx]
            val_data['lq'].unsqueeze_(0)
            val_data['gt'].unsqueeze_(0)
            folder = val_data['folder']
            frame_idx, max_idx = val_data['idx'].split('/')
            lq_path = val_data['lq_path']

            self.feed_data(val_data)
            self.test()

            # torch.cuda.synchronize()
            # t0 = time.time()
            # nRound = 10
            # for i in range(nRound):
            #     self.test()
            # torch.cuda.synchronize()
            # print('Prediction time: ', (time.time() - t0) / nRound, '\n\n')

            visuals = self.get_current_visuals()
            result_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    raise NotImplementedError(
                        'saving image is not supported during training.')
                else:
                    if 'vimeo' in dataset_name.lower():  # vimeo90k dataset
                        split_result = lq_path.split('/')
                        img_name = (f'{split_result[-3]}_{split_result[-2]}_'
                                    f'{split_result[-1].split(".")[0]}')
                    else:  # other datasets, e.g., REDS, Vid4
                        img_name = osp.splitext(osp.basename(lq_path))[0]

                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], dataset_name,
                            folder,
                            f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(
                            self.opt['path']['visualization'], dataset_name,
                            folder, f'{img_name}_{self.opt["name"]}.png')
                mmcv.imwrite(result_img, save_img_path)

            if with_metrics:
                # calculate metrics
                opt_metric = deepcopy(self.opt['val']['metrics'])
                for metric_idx, opt_ in enumerate(opt_metric.values()):
                    metric_type = opt_.pop('type')
                    result = getattr(metric_module,
                                     metric_type)(result_img, gt_img, **opt_)
                    self.metric_results[folder][int(frame_idx),
                                                metric_idx] += result

            # progress bar
            if rank == 0:
                for _ in range(world_size):
                    pbar.update(f'Test {folder} - '
                                f'{int(frame_idx) + world_size}/{max_idx}')