def main(entry: str, ddf_file):
    """Print a sanity report for one dataset loaded from a DDF file.

    The dataset is looked up by its upper-cased name in the mapping returned
    by ``load_datasets(ddf_file)``. Its ``train``, ``val`` and ``test`` sets
    are then inspected in turn, and findings are printed to stdout.

    Args:
        entry: Name of the dataset to check (case-insensitive; it is
            upper-cased before the lookup).
        ddf_file: Passed through unchanged to ``load_datasets``.

    Raises:
        KeyError: If ``entry`` is not among the loaded datasets.
    """
    entry = entry.upper()
    all_data = load_datasets(ddf_file)
    if entry not in all_data:
        raise KeyError(f"The dataset `{entry}` not found in the DDF")

    data = all_data[entry]
    print(f"Dataset: {data.name}")

    def _check(name: str):
        """Report on the ``name`` set ('train', 'val' or 'test') of ``data``."""
        print(f"\n=========  CHECKING  {name}  =========\n")
        if name not in data or data[name] is None:
            print(f"{data.name} doesn't contain any {name} data.")
            return

        print(f"Found `{name}` set in \"{data.name}\":")
        _hr = data[name].hr
        _lr = data[name].lr
        # Only consult `as_video` when ground-truth data exists; a set that
        # carries degraded data only must not crash the checker.
        video_type = _hr is not None and _hr.as_video
        if video_type:
            print(f"\"{data.name}\" is video data")
        if _hr is not None:
            _hr = _hr.compile()
            print(f"Found {len(_hr)} ground-truth {name} data")
        if _lr is not None:
            _lr = _lr.compile()
            print(f"Found {len(_lr)} custom degraded {name} data")
        if _hr is None or _lr is None:
            # Nothing to cross-check unless both sides are present.
            return
        if len(_hr) != len(_lr):
            print(
                " [E] Ground-truth data and degraded data quantity not matched!!"
            )
        elif video_type:
            # Same number of clips: compare each pair's `frames` value.
            for x, y in zip(_hr, _lr):
                if x.frames != y.frames:
                    print(
                        f" [E] Video clip {x.name}|{y.name} quantity not matched!!"
                    )

    _check('train')
    _check('val')
    _check('test')
# Example #2
# 0
def main():
    """Run inference or benchmarking for a model on the requested test data.

    Command-line flags are parsed with the module-level ``parser`` and copied
    into a ``Config``. Model parameters are then merged in from a parameter
    file, the model is built, and every requested test dataset is fed to the
    model executor: ``benchmark`` when ground-truth (``hr``) data is
    available, ``infer`` otherwise.

    Raises:
        FileNotFoundError: If the path in ``flags.data_config`` does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    # copy every parsed flag into opt without clobbering existing keys
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    overwrite_from_env(opt)
    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")
    # An explicit parameter file is loaded once; otherwise every existing
    # `par/<BACKEND>/<model>.<ext>` file is merged, in extension order.
    if opt.parameter:
        config_candidates = [Path(opt.parameter)]
    else:
        config_candidates = [
            Path(f'par/{BACKEND}/{opt.model}.{_ext}')
            for _ext in ('json', 'yaml', 'yml')  # for compat
        ]
    for model_config_file in config_candidates:
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))
    # get model parameters from pre-defined YAML file
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)
    # results go to <save_dir>/<model>[_<comment>]
    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment
    root = Path(root)

    datasets = load_datasets(data_config_file)
    try:
        test_datas = [datasets[t.upper()]
                      for t in opt.test] if opt.test else []
    except KeyError:
        # The names are not known datasets: wrap the raw `opt.test` values
        # in an ad-hoc dataset named 'infer' that has only `lr` data.
        test_datas = [Config(test=Config(lr=Dataset(*opt.test)), name='infer')]
        if opt.video:
            test_datas[0].test.lr.use_like_video_()
    # enter model executor environment
    with model.get_executor(root) as t:
        for data in test_datas:
            # benchmarking needs ground truth to compare against
            run_benchmark = data.test.hr is not None
            if run_benchmark:
                ld = Loader(data.test.hr,
                            data.test.lr,
                            opt.scale,
                            threads=opt.threads)
            else:
                ld = Loader(data.test.hr, data.test.lr, threads=opt.threads)
            if opt.channel == 1:
                # convert data color space to grayscale
                ld.set_color_space('hr', 'L')
                ld.set_color_space('lr', 'L')
            config = t.query_config(opt)
            config.inference_results_hooks = [
                save_inference_images(root / data.name, opt.output_index,
                                      opt.auto_rename)
            ]
            if run_benchmark:
                t.benchmark(ld, config)
            else:
                t.infer(ld, config)
        if opt.export:
            t.export(opt.export)
# Example #3
# 0
def main():
    """Train a model on the dataset selected by the command-line options.

    Command-line flags are parsed with the module-level ``parser`` and copied
    into a ``Config``. Model parameters are then merged in from a parameter
    file, the model is built, training and (optional) validation loaders are
    prepared, and the model executor's ``fit`` is called with both loaders.

    Raises:
        FileNotFoundError: If the path in ``opt.data_config`` does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()  # An EasyDict object
    # overwrite flag values into opt object
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    # fetch dataset descriptions
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")
    # An explicit parameter file is loaded once; otherwise every existing
    # `par/<BACKEND>/<model>.<ext>` file is merged, in extension order.
    if opt.parameter:
        config_candidates = [Path(opt.parameter)]
    else:
        config_candidates = [
            Path(f'par/{BACKEND}/{opt.model}.{_ext}')
            for _ext in ('json', 'yaml', 'yml')  # for compat
        ]
    for model_config_file in config_candidates:
        if model_config_file.exists():
            opt.update(compat_param(Config(str(model_config_file))))
    # get model parameters from pre-defined YAML file
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)
    # results go to <save_dir>/<model>[_<comment>]
    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment

    dataset = load_datasets(data_config_file, opt.dataset)
    # construct data loader for training
    lt = Loader(dataset.train.hr,
                dataset.train.lr,
                opt.scale,
                threads=opt.threads)
    lt.image_augmentation()
    # construct data loader for validating
    lv = None
    if dataset.val is not None:
        lv = Loader(dataset.val.hr,
                    dataset.val.lr,
                    opt.scale,
                    threads=opt.threads)
    lt.cropper(RandomCrop(opt.scale))
    if lv is not None:
        # `traced_val` selects a center crop for validation instead of a
        # random one
        if opt.traced_val:
            lv.cropper(CenterCrop(opt.scale))
        else:
            lv.cropper(RandomCrop(opt.scale))
    if opt.channel == 1:
        # convert data color space to grayscale
        lt.set_color_space('hr', 'L')
        lt.set_color_space('lr', 'L')
        if lv is not None:
            lv.set_color_space('hr', 'L')
            lv.set_color_space('lr', 'L')
    # enter model executor environment
    with model.get_executor(root) as t:
        config = t.query_config(opt)
        if opt.lr_decay:
            config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
        # `lv` may be None here when the dataset has no validation set
        t.fit([lt, lv], config)
        if opt.export:
            t.export(opt.export)