def __init__(self, data_opt, **kwargs):
        """ Folder dataset with paired data
            support both BI & BD degradation
        """
        super(ValidationDataset, self).__init__(data_opt, **kwargs)

        # get keys
        gt_keys = sorted(os.listdir(self.gt_seq_dir))
        self.keys = sorted(list(set(gt_keys)))
        if data_opt['name'].startswith('Actors'):
            for i, k in enumerate(self.keys):
                self.keys[i] = k + '/frames'

        self.kernel = create_kernel({
            'dataset':{
                'degradation': {
                    'sigma': self.sigma
                }
            },
            'device': 'cuda'
        })

        # filter keys
        if self.filter_file:
            with open(self.filter_file, 'r') as f:
                sel_keys = { line.strip() for line in f }
                self.keys = sorted(list(sel_keys & set(self.keys)))
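Note: the filter file read above is assumed to be a plain-text list with one
sequence key per line; keys not found in the GT directory are silently dropped
by the set intersection. A minimal sketch of such a file (hypothetical keys):

    000
    011
    015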
Example #2
    def prepare_training_data(self, data):
        """ prepare gt, lr data for training

            for BD degradation, generate lr data and remove the border of gt data
            for BI degradation, use input data directly
        """

        degradation_type = self.opt['dataset']['degradation']['type']

        if degradation_type == 'BI':
            self.gt_data = data['gt'].to(self.device)
            self.lr_data = data['lr'].to(self.device)

        elif degradation_type == 'BD':
            # generate lr data on the fly (on gpu)

            # set params
            scale = self.opt['scale']
            sigma = self.opt['dataset']['degradation'].get('sigma', 1.5)
            border_size = int(sigma * 3.0)

            gt_data = data['gt'].to(self.device)  # with border
            n, t, c, gt_h, gt_w = gt_data.size()
            lr_h = (gt_h - 2 * border_size) // scale
            lr_w = (gt_w - 2 * border_size) // scale

            # create blurring kernel
            if self.blur_kernel is None:
                self.blur_kernel = create_kernel(sigma).to(self.device)
            blur_kernel = self.blur_kernel

            # generate lr data
            gt_data = gt_data.view(n * t, c, gt_h, gt_w)
            lr_data = downsample_bd(gt_data,
                                    blur_kernel,
                                    scale,
                                    pad_data=False)
            lr_data = lr_data.view(n, t, c, lr_h, lr_w)

            # remove gt border
            gt_data = gt_data[..., border_size:border_size + scale * lr_h,
                              border_size:border_size + scale * lr_w]
            gt_data = gt_data.view(n, t, c, scale * lr_h, scale * lr_w)

            self.gt_data, self.lr_data = gt_data, lr_data  # tchw|float32
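As a quick sanity check of the shape bookkeeping above (illustrative numbers,
not taken from the source): with sigma = 1.5 the border is int(1.5 * 3.0) = 4
pixels, so a 4x-scale 128x128 GT patch yields a 30x30 LR patch and a 120x120
cropped GT patch.

    # illustrative shape arithmetic for the BD branch above
    scale, sigma = 4, 1.5                      # assumed values
    border_size = int(sigma * 3.0)             # -> 4
    gt_h = gt_w = 128                          # hypothetical GT patch size
    lr_h = (gt_h - 2 * border_size) // scale   # -> 30
    assert scale * lr_h == 120                 # GT is cropped to 120x120 to stay aligned with the LR patch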
Example #3
def downscale_data(opt):
    for dataset_idx in sorted(opt['dataset'].keys()):
        if not dataset_idx.startswith('all'):
            continue

        loader = create_dataloader(opt, dataset_idx=dataset_idx)
        degradation_type = opt['dataset']['degradation']['type']
        if degradation_type == 'BD':
            kernel = data_utils.create_kernel(opt)

        if degradation_type == 'Style':
            path = opt['cartoon_model']
            cartoonizer = SimpleGenerator().to(torch.device(opt['device']))
            cartoonizer.load_weights(path)
            cartoonizer.eval()

        for item in tqdm(loader, ascii=True):
            if degradation_type == 'BD':
                data = prepare_data(opt, item, kernel)
            elif degradation_type == 'BI':
                data = data_utils.BI_downsample(opt, item)
            elif degradation_type == 'Style':
                image = item['gt'][0]
                image = resize(image)  # resize() is assumed to be defined elsewhere (e.g. scaling to the generator's input size)
                image = image.to(torch.device(opt['device']))
                with torch.no_grad():
                    stylized_image = cartoonizer(image).unsqueeze(0)
                    stylized_image = (stylized_image + 1) * 0.5
                data = {'gt': image.unsqueeze(0), 'lr': stylized_image}
            lr_data = data['lr']
            gt_data = data['gt']
            img = lr_data.squeeze(0).squeeze(0).permute(1, 2, 0).cpu().numpy()
            path = osp.join(
                'data', opt['dataset']['common']['name'], opt['data_subset'],
                opt['dataset'][dataset_idx]['actor_name'],
                opt['data_type'] + '_' + opt['dataset']['degradation']['type'],
                opt['dataset'][dataset_idx]['segment'], 'frames')
            os.makedirs(path, exist_ok=True)
            path = osp.join(path, item['frame_key'][0])
            img = img * 255.0
            # cv2.imwrite expects BGR; this cvtColor call just swaps the first and last channels
            img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)
            cv2.imwrite(path, img)
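To make the output layout concrete, the nested path join above produces a
directory pattern like the following (option values are hypothetical):

    # data/<common name>/<data_subset>/<actor_name>/<data_type>_<degradation type>/<segment>/frames/<frame_key>
    # e.g. data/Actors/train/actor01/lr_BD/seg001/frames/0001.png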
Example #4
    def prepare_inference_data(self, data):
        """ Prepare lr data for training (w/o loading on device)
        """

        degradation_type = self.opt['dataset']['degradation']['type']

        if degradation_type == 'BI':
            self.lr_data = data['lr']

        elif degradation_type == 'BD':
            if 'lr' in data:
                self.lr_data = data['lr']
            else:
                # generate lr data on the fly (on cpu)
                # TODO: do frame-wise downsampling on gpu for acceleration?
                gt_data = data['gt']  # thwc|uint8

                # set params
                scale = self.opt['scale']
                sigma = self.opt['dataset']['degradation'].get('sigma', 1.5)

                # create blurring kernel
                if self.blur_kernel is None:
                    self.blur_kernel = create_kernel(sigma)
                blur_kernel = self.blur_kernel.cpu()

                # generate lr data
                gt_data = gt_data.permute(0, 3, 1, 2).float() / 255.0  # tchw|float32
                lr_data = downsample_bd(gt_data,
                                        blur_kernel,
                                        scale,
                                        pad_data=True)
                lr_data = lr_data.permute(0, 2, 3, 1)  # thwc|float32

                self.lr_data = lr_data

        # thwc to tchw
        self.lr_data = self.lr_data.permute(0, 3, 1, 2)  # tchw|float32
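downsample_bd itself is not shown in these snippets. A minimal sketch of a
blur-then-subsample ('BD') downsampler, assuming an isotropic Gaussian kernel
and reflection padding when pad_data=True (the real implementation may use a
different kernel size, padding scheme, or subsampling offset):

    import torch
    import torch.nn.functional as F

    def downsample_bd_sketch(data, kernel, scale, pad_data=True):
        """Blur with a Gaussian kernel, then subsample by `scale` (sketch only)."""
        # data: (t, c, h, w) float tensor; kernel: (k, k) float tensor
        t, c, h, w = data.size()
        k = kernel.size(-1)
        if pad_data:
            pad = k // 2
            data = F.pad(data, (pad, pad, pad, pad), mode='reflect')
        # depthwise convolution: the same Gaussian is applied to every channel
        weight = kernel.view(1, 1, k, k).repeat(c, 1, 1, 1).to(data)
        blurred = F.conv2d(data, weight, groups=c)
        # strided subsampling implements the 'D' (downsample) in BD
        return blurred[..., ::scale, ::scale]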
Example #5
def train(opt):
    # logging
    logger = base_utils.get_logger('base')
    logger.info('{} Options {}'.format('='*20, '='*20))
    base_utils.print_options(opt, logger)

    # create data loader
    train_loader = create_dataloader(opt, dataset_idx='train')

    # create downsampling kernels for BD degradation
    kernel = data_utils.create_kernel(opt)

    # create model
    model = define_model(opt)

    # training configs
    total_sample = len(train_loader.dataset)
    iter_per_epoch = len(train_loader)
    total_iter = opt['train']['total_iter']
    total_epoch = int(math.ceil(total_iter / iter_per_epoch))
    start_iter, iter = opt['train']['start_iter'], 0

    test_freq = opt['test']['test_freq']
    log_freq = opt['logger']['log_freq']
    ckpt_freq = opt['logger']['ckpt_freq']

    logger.info('Number of training samples: {}'.format(total_sample))
    logger.info('Total epochs needed: {} for {} iterations'.format(
        total_epoch, total_iter))

    # train
    for epoch in range(total_epoch):
        for data in train_loader:
            # update iter
            iter += 1
            curr_iter = start_iter + iter
            if iter > total_iter:
                logger.info('Finish training')
                break

            # update learning rate
            model.update_learning_rate()

            # prepare data
            data = prepare_data(opt, data, kernel)

            # train for a mini-batch
            model.train(data)

            # update running log
            model.update_running_log()

            # log
            if log_freq > 0 and iter % log_freq == 0:
                # basic info
                msg = '[epoch: {} | iter: {}'.format(epoch, curr_iter)
                for lr_type, lr in model.get_current_learning_rate().items():
                    msg += ' | {}: {:.2e}'.format(lr_type, lr)
                msg += '] '

                # loss info
                log_dict = model.get_running_log()
                msg += ', '.join([
                    '{}: {:.3e}'.format(k, v) for k, v in log_dict.items()])

                logger.info(msg)

            # save model
            if ckpt_freq > 0 and iter % ckpt_freq == 0:
                model.save(curr_iter)

            # evaluate performance
            if test_freq > 0 and iter % test_freq == 0:
                # setup model index
                model_idx = 'G_iter{}'.format(curr_iter)

                # for each testset
                for dataset_idx in sorted(opt['dataset'].keys()):
                    # use dataset with prefix `test`
                    if not dataset_idx.startswith('test'):
                        continue

                    ds_name = opt['dataset'][dataset_idx]['name']
                    logger.info(
                        'Testing on {}: {}'.format(dataset_idx, ds_name))

                    # create data loader
                    test_loader = create_dataloader(opt, dataset_idx=dataset_idx)

                    # define metric calculator
                    metric_calculator = MetricCalculator(opt)

                    # infer and compute metrics for each sequence
                    for data in test_loader:
                        # fetch data
                        lr_data = data['lr'][0]
                        seq_idx = data['seq_idx'][0]
                        frm_idx = [frm_idx[0] for frm_idx in data['frm_idx']]

                        # infer
                        hr_seq = model.infer(lr_data)  # thwc|rgb|uint8

                        # save results (optional)
                        if opt['test']['save_res']:
                            res_dir = osp.join(
                                opt['test']['res_dir'], ds_name, model_idx)
                            res_seq_dir = osp.join(res_dir, seq_idx)
                            data_utils.save_sequence(
                                res_seq_dir, hr_seq, frm_idx, to_bgr=True)

                        # compute metrics for the current sequence
                        true_seq_dir = osp.join(
                            opt['dataset'][dataset_idx]['gt_seq_dir'], seq_idx)
                        metric_calculator.compute_sequence_metrics(
                            seq_idx, true_seq_dir, '', pred_seq=hr_seq)

                    # save/print metrics
                    if opt['test'].get('save_json'):
                        # save results to json file
                        json_path = osp.join(
                            opt['test']['json_dir'], '{}_avg.json'.format(ds_name))
                        metric_calculator.save_results(
                            model_idx, json_path, override=True)
                    else:
                        # print directly
                        metric_calculator.display_results()
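With hypothetical values (the learning-rate and loss names depend on the model
being trained), the log message assembled above renders as:

    # [epoch: 3 | iter: 12000 | lr_G: 5.00e-05] l_pix: 2.345e-02, l_gan: 1.000e-03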
Example #6
def train(opt):
    # logging
    logger = base_utils.get_logger('base')
    logger.info('{} Options {}'.format('='*20, '='*20))
    base_utils.print_options(opt, logger)

    # create data loader
    train_loader = create_dataloader(opt, dataset_idx='train')

    # create downsampling kernels for BD degradation
    kernel = data_utils.create_kernel(opt)

    # create model
    model = define_model(opt)

    # training configs
    total_sample = len(train_loader.dataset)
    iter_per_epoch = len(train_loader)
    total_iter = opt['train']['total_iter']
    total_epoch = int(math.ceil(total_iter / iter_per_epoch))
    curr_iter = opt['train']['start_iter']

    test_freq = opt['test']['test_freq']
    log_freq = opt['logger']['log_freq']
    ckpt_freq = opt['logger']['ckpt_freq']
    sigma_freq = opt['dataset']['degradation'].get('sigma_freq', 0)
    sigma_inc = opt['dataset']['degradation'].get('sigma_inc', 0)
    sigma_max = opt['dataset']['degradation'].get('sigma_max', 10)

    logger.info('Number of training samples: {}'.format(total_sample))
    logger.info('Total epochs needed: {} for {} iterations'.format(
        total_epoch, total_iter))
    logger.info('Device count: {}'.format(torch.cuda.device_count()))
    # train
    for epoch in range(total_epoch):
        for data in tqdm(train_loader):
            # update iter
            curr_iter += 1
            if curr_iter > total_iter:
                logger.info('Finish training')
                break

            # update learning rate
            model.update_learning_rate()

            # prepare data
            data = prepare_data(opt, data, kernel)

            # train for a mini-batch
            model.train(data)

            # update running log
            model.update_running_log()

            # log
            if log_freq > 0 and curr_iter % log_freq == 0:
                # basic info
                msg = '[epoch: {} | iter: {}'.format(epoch, curr_iter)
                for lr_type, lr in model.get_current_learning_rate().items():
                    msg += ' | {}: {:.2e}'.format(lr_type, lr)
                msg += '] '

                # loss info
                log_dict = model.get_running_log()
                msg += ', '.join([
                    '{}: {:.3e}'.format(k, v) for k, v in log_dict.items()])
                if opt['dataset']['degradation']['type'] == 'BD':
                    msg += ' | Sigma: {}'.format(opt['dataset']['degradation']['sigma'])
                logger.info(msg)

            # save model
            if ckpt_freq > 0 and curr_iter % ckpt_freq == 0:
                model.save(curr_iter)

            # evaluate performance
            if test_freq > 0 and curr_iter % test_freq == 0:
                # setup model index
                model_idx = 'G_iter{}'.format(curr_iter)
                if opt['dataset']['degradation']['type'] == 'BD':
                    model_idx = model_idx + str(opt['dataset']['degradation']['sigma'])

                # for each testset
                for dataset_idx in sorted(opt['dataset'].keys()):
                    # use dataset with prefix `validate`
                    if not dataset_idx.startswith('validate'):
                        continue
                    validate(opt, model, logger, dataset_idx, model_idx)

        # schedule sigma
        if opt['dataset']['degradation']['type'] == 'BD':
            if sigma_freq > 0 and (epoch + 1) % sigma_freq == 0:
                current_sigma = opt['dataset']['degradation']['sigma']
                opt['dataset']['degradation']['sigma'] = min(current_sigma + sigma_inc, sigma_max)
                kernel = data_utils.create_kernel(opt)
                
                # __getitem__ in the custom dataset class crops with a size that depends on sigma,
                # so the crop size must be updated whenever sigma changes
                train_loader.dataset.change_cropsize(opt['dataset']['degradation']['sigma'])
                logger.info('Blur kernel updated for sigma = {}'.format(
                    opt['dataset']['degradation']['sigma']))
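create_kernel is only referenced in these snippets. A minimal sketch of an
isotropic Gaussian kernel builder, assuming a kernel size of 2*ceil(3*sigma)+1
(the real implementation may choose a different size rule or normalization):

    import math
    import torch

    def gaussian_kernel_sketch(sigma):
        """Return a normalized (k, k) isotropic Gaussian kernel (sketch only)."""
        k = 2 * math.ceil(3 * sigma) + 1               # assumed size rule
        ax = torch.arange(k, dtype=torch.float32) - (k - 1) / 2
        g = torch.exp(-ax ** 2 / (2 * sigma ** 2))     # 1D Gaussian profile
        kernel = torch.outer(g, g)                     # separable -> 2D kernel
        return kernel / kernel.sum()                   # sums to 1 to preserve brightness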