def test_infer_srcnn():
    """Smoke-test SRCNN inference (scale 2, 3 channels) on the Set5 images."""
    srcnn = get_model('srcnn')(scale=2, channel=3)
    set5_loader = Loader(Dataset('data').include_reg('set5'), scale=2)
    with srcnn.executor as executor:
        # An empty dict asks the executor for its default configuration.
        executor.infer(set5_loader, executor.query_config({}))
def test_train_vespcn():
    """Smoke-test a very short VESPCN training run (1 epoch, 10 steps)."""
    clips = Dataset('data/video').include_reg("xiuxian").use_like_video()
    clip_loader = Loader(clips, scale=2)
    vespcn = get_model('vespcn')(scale=2, channel=3)
    with vespcn.executor as executor:
        cfg = executor.query_config({})
        cfg.epochs = 1
        cfg.steps = 10
        # 5-D batch; only the position of the 3-channel axis depends on
        # DATA_FORMAT (second dim presumably = frames per clip — confirm).
        channels_first = DATA_FORMAT == 'channels_first'
        cfg.batch_shape = ([16, 3, 3, 16, 16] if channels_first
                           else [16, 3, 16, 16, 3])
        # No validation loader is supplied.
        executor.fit([clip_loader, None], cfg)
def test_train_srcnn():
    """Smoke-test a short single-channel SRCNN training run on Set5."""
    set5_loader = Loader(Dataset('data').include_reg('set5'), scale=2)
    # Train on grayscale ('L') images for both low- and high-res sides.
    for side in ('lr', 'hr'):
        set5_loader.set_color_space(side, 'L')
    srcnn = get_model('srcnn')(scale=2, channel=1)
    with srcnn.executor as executor:
        cfg = executor.query_config({})
        cfg.epochs, cfg.steps = 5, 10
        # 4-D batch of 16 single-channel 16x16 patches; the channel axis
        # position follows DATA_FORMAT.
        if DATA_FORMAT != 'channels_first':
            cfg.batch_shape = [16, 16, 16, 1]
        else:
            cfg.batch_shape = [16, 1, 16, 16]
        executor.fit([set5_loader, None], cfg)
# Example #4
0
def main():
    """Benchmark or run inference with a model from the command line.

    Parses known CLI flags into a ``Config`` (unrecognised arguments are
    forwarded to ``suppress_opt_by_args``), applies ``overwrite_from_env``,
    loads model hyper-parameters from a parameter file, builds the model and
    then, for every requested test set, either benchmarks it (when the set
    has HR images) or runs plain inference. Outputs are written through
    ``save_inference_images`` under ``<save_dir>/<model>[_<comment>]/<name>``.

    Raises:
        FileNotFoundError: if the dataset config file, or an explicitly
            requested ``parameter`` file, does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    overwrite_from_env(opt)
    # Read the path from ``opt`` rather than ``flags`` so that any value
    # rewritten by ``overwrite_from_env`` above is honoured.
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")
    if opt.parameter:
        # An explicitly requested parameter file must exist; load it once.
        model_config_file = Path(opt.parameter)
        if not model_config_file.exists():
            raise FileNotFoundError(
                f"model parameter file {model_config_file} doesn't exist!")
        opt.update(compat_param(Config(str(model_config_file))))
    else:
        # Default lookup: every existing par/<BACKEND>/<model>.<ext> is
        # applied in this order, so later extensions override earlier ones.
        for _ext in ('json', 'yaml', 'yml'):  # for compat
            model_config_file = Path(f'par/{BACKEND}/{opt.model}.{_ext}')
            if model_config_file.exists():
                opt.update(compat_param(Config(str(model_config_file))))
    # get model parameters from pre-defined YAML file
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)
    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment
    root = Path(root)

    datasets = load_datasets(data_config_file)
    try:
        test_datas = [datasets[t.upper()]
                      for t in opt.test] if opt.test else []
    except KeyError:
        # ``opt.test`` entries are not dataset names from the config; pass
        # them straight to ``Dataset`` as an ad-hoc LR-only set named
        # 'infer' (presumably file paths/patterns — confirm).
        test_datas = [Config(test=Config(lr=Dataset(*opt.test)), name='infer')]
        if opt.video:
            test_datas[0].test.lr.use_like_video_()
    # enter model executor environment
    with model.get_executor(root) as t:
        for data in test_datas:
            # Benchmark only when HR images are available; else just infer.
            run_benchmark = data.test.hr is not None
            if run_benchmark:
                ld = Loader(data.test.hr,
                            data.test.lr,
                            opt.scale,
                            threads=opt.threads)
            else:
                ld = Loader(data.test.hr, data.test.lr, threads=opt.threads)
            if opt.channel == 1:
                # convert data color space to grayscale
                ld.set_color_space('hr', 'L')
                ld.set_color_space('lr', 'L')
            config = t.query_config(opt)
            config.inference_results_hooks = [
                save_inference_images(root / data.name, opt.output_index,
                                      opt.auto_rename)
            ]
            if run_benchmark:
                t.benchmark(ld, config)
            else:
                t.infer(ld, config)
        if opt.export:
            t.export(opt.export)
# Example #5
0
def main():
    """Train a model from the command line.

    Parses known CLI flags into a ``Config`` (unrecognised arguments are
    forwarded to ``suppress_opt_by_args``), loads model hyper-parameters
    from a parameter file, builds the model (optionally moved to CUDA and/or
    loaded from ``pretrain``), prepares the training loader and an optional
    validation loader, then fits the model and optionally exports it.
    Executor output goes under ``<save_dir>/<model>[_<comment>]``.

    Raises:
        FileNotFoundError: if the dataset config file, or an explicitly
            requested ``parameter`` file, does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()  # An EasyDict object
    # overwrite flag values into opt object
    for pair in flags._get_kwargs():
        opt.setdefault(*pair)
    # fetch dataset descriptions
    data_config_file = Path(opt.data_config)
    if not data_config_file.exists():
        raise FileNotFoundError("dataset config file doesn't exist!")
    if opt.parameter:
        # An explicitly requested parameter file must exist; load it once.
        model_config_file = Path(opt.parameter)
        if not model_config_file.exists():
            raise FileNotFoundError(
                f"model parameter file {model_config_file} doesn't exist!")
        opt.update(compat_param(Config(str(model_config_file))))
    else:
        # Default lookup: every existing par/<BACKEND>/<model>.<ext> is
        # applied in this order, so later extensions override earlier ones.
        for _ext in ('json', 'yaml', 'yml'):  # for compat
            model_config_file = Path(f'par/{BACKEND}/{opt.model}.{_ext}')
            if model_config_file.exists():
                opt.update(compat_param(Config(str(model_config_file))))
    # get model parameters from pre-defined YAML file
    model_params = opt.get(opt.model, {})
    suppress_opt_by_args(model_params, *args)
    opt.update(model_params)
    # construct model
    model = get_model(opt.model)(**model_params)
    if opt.cuda:
        model.cuda()
    if opt.pretrain:
        model.load(opt.pretrain)
    root = f'{opt.save_dir}/{opt.model}'
    if opt.comment:
        root += '_' + opt.comment

    dataset = load_datasets(data_config_file, opt.dataset)
    # construct data loader for training
    lt = Loader(dataset.train.hr,
                dataset.train.lr,
                opt.scale,
                threads=opt.threads)
    lt.image_augmentation()
    # construct data loader for validating (only if a val split exists)
    lv = None
    if dataset.val is not None:
        lv = Loader(dataset.val.hr,
                    dataset.val.lr,
                    opt.scale,
                    threads=opt.threads)
    lt.cropper(RandomCrop(opt.scale))
    if lv is not None:
        # ``traced_val`` selects a center crop for validation instead of a
        # random one.
        val_crop = CenterCrop if opt.traced_val else RandomCrop
        lv.cropper(val_crop(opt.scale))
    if opt.channel == 1:
        # convert data color space to grayscale (training loader first)
        for loader in (lt, lv):
            if loader is not None:
                loader.set_color_space('hr', 'L')
                loader.set_color_space('lr', 'L')
    # enter model executor environment
    with model.get_executor(root) as t:
        config = t.query_config(opt)
        if opt.lr_decay:
            config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
        t.fit([lt, lv], config)
        if opt.export:
            t.export(opt.export)