Exemplo n.º 1
0
def init_loader_config(opt):
    """Build the loader configurations for training, benchmark and inference.

    Args:
        opt: Option mapping that is unpacked into every ``Config``. The
            attributes read here are ``test_batch``, ``channel``,
            ``output_color``, ``root``, ``output_index``,
            ``add_custom_callbacks``, ``lr_decay`` and ``lr``.

    Returns:
        The tuple ``(train_config, benchmark_config, infer_config)``.
    """
    train_config = Config(
        **opt, crop='random', feature_callbacks=[], label_callbacks=[])
    benchmark_config = Config(
        **opt, crop=None, feature_callbacks=[], label_callbacks=[],
        output_callbacks=[])
    infer_config = Config(
        **opt, feature_callbacks=[], label_callbacks=[], output_callbacks=[])

    benchmark_config.batch = opt.test_batch or 1
    benchmark_config.steps_per_epoch = -1

    # Single-channel models work on gray images, everything else on RGB.
    color_space = 'gray' if opt.channel == 1 else 'rgb'
    train_config.convert_to = color_space
    benchmark_config.convert_to = color_space

    if color_space == 'gray' and opt.output_color == 'RGB':
        # Single-channel model with RGB output requested: load as YUV,
        # apply to_gray() to features and labels, and to_rgb() to outputs.
        benchmark_config.convert_to = 'yuv'
        benchmark_config.feature_callbacks = (
            train_config.feature_callbacks + [to_gray()])
        benchmark_config.label_callbacks = (
            train_config.label_callbacks + [to_gray()])
        benchmark_config.output_callbacks = [to_rgb()]

    # Both colour paths end by saving the outputs and copying the benchmark
    # settings into the inference config.
    benchmark_config.output_callbacks += [
        save_image(opt.root, opt.output_index)]
    infer_config.update(benchmark_config)

    custom_names = opt.add_custom_callbacks
    if custom_names is not None:
        # Custom callbacks are given by name and looked up in this module's
        # globals; an unknown name raises KeyError.
        for name in custom_names:
            for cfg in (train_config, benchmark_config, infer_config):
                cfg.feature_callbacks += [globals()[name]]

    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)

    # modcrop: A boolean to specify whether to crop the edge of images to be
    #          divisible by `scale`. It's useful when to provide batches with
    #          original shapes.
    infer_config.modcrop = False
    return train_config, benchmark_config, infer_config
Exemplo n.º 2
0
def main():
    """Parse the command line, build the model and run training.

    Flags parsed by the module-level ``parser`` seed the options. Parameter
    files found under ``Parameters/`` (or the file given by ``opt.p``) are
    merged on top, the model returned by ``get_model`` is instantiated, and
    its trainer is fitted on the ``'train'`` and ``'val'`` loaders of the
    selected dataset. The model is exported afterwards when ``opt.export``
    is set.

    Raises:
        RuntimeError: If the dataset config file does not exist.
    """
    flags, args = parser.parse_known_args()
    opt = Config()
    for key, value in flags._get_kwargs():
        opt.setdefault(key, value)

    data_config_file = Path(flags.data_config)
    if not data_config_file.exists():
        raise RuntimeError("dataset config file doesn't exist!")

    for _ext in ('json', 'yaml', 'yml'):  # for compat
        # apply a 2-stage (or master-slave) configuration, master can be
        # overridden by slave
        root_path = Path('Parameters/root.{}'.format(_ext))
        model_path = (
            Path(opt.p) if opt.p
            else Path('Parameters/{}.{}'.format(opt.model, _ext)))
        # Both paths are resolved before either file is merged, so the
        # model path is based on the options as they were before this pass.
        for candidate in (root_path, model_path):
            if candidate.exists():
                opt.update(Config(str(candidate)))

    model_params = opt.get(opt.model, {})
    opt.update(model_params)
    suppress_opt_by_args(model_params, *args)
    model = get_model(flags.model)(**model_params)
    if flags.cuda:
        model.cuda()

    save_root = f'{flags.save_dir}/{flags.model}'
    if flags.comment:
        save_root += '_' + flags.comment
    if flags.verbose:
        verbosity = logging.DEBUG
    else:
        verbosity = logging.INFO
    trainer_cls = model.trainer

    dataset = load_datasets(data_config_file)[flags.dataset.upper()]

    train_config = Config(crop=opt.train_data_crop,
                          feature_callbacks=[],
                          label_callbacks=[],
                          convert_to='rgb',
                          **opt)
    if opt.channel == 1:
        train_config.convert_to = 'gray'
    if opt.lr_decay:
        train_config.lr_schedule = lr_decay(lr=opt.lr, **opt.lr_decay)
    train_config.random_val = not opt.traced_val
    train_config.cuda = flags.cuda

    if opt.verbose:
        dump(opt)

    with trainer_cls(model, save_root, verbosity, opt.pth) as session:
        if opt.seed is not None:
            session.set_seed(opt.seed)
        train_loader = QuickLoader(
            dataset, 'train', train_config, True, flags.thread)
        # The validation loader reuses the training config but overrides
        # batch, crop and steps_per_epoch.
        val_loader = QuickLoader(
            dataset, 'val', train_config, False, flags.thread,
            batch=1,
            crop=opt.val_data_crop,
            steps_per_epoch=opt.val_num)
        session.fit([train_loader, val_loader], train_config)
        if opt.export:
            session.export(opt.export)