예제 #1
0
def main(args, unknown_args):
    """Run inference for the experiment described by the CLI args and config.

    Args:
        args: parsed command-line namespace; after ``parse_args_uargs`` it is
            read for ``seed``, ``expdir``, ``workers``, ``batch_size``,
            ``resume``, ``out_prefix`` and ``verbose``.
        unknown_args: extra command-line arguments, forwarded untouched to
            ``parse_args_uargs``.

    The config must contain ``"model_params"``; ``"data_params"`` and
    ``"callbacks_params"`` are optional.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seeds(args.seed)

    experiment = prepare_modules(expdir=args.expdir)

    net = Registry.get_model(**config["model_params"])

    # Data: loaders are requested in "infer" mode.
    source = experiment["data"].DataSource()
    # `or {}` also covers a key that is present but holds None / an empty value.
    loader_kwargs = config.get("data_params", {}) or {}
    infer_loaders = source.prepare_loaders(
        mode="infer",
        n_workers=args.workers,
        batch_size=args.batch_size,
        **loader_kwargs
    )

    # Runner and its callbacks, also in "infer" mode.
    model_runner = experiment["model"].ModelRunner(model=net)
    callback_kwargs = config.get("callbacks_params", {}) or {}
    infer_callbacks = model_runner.prepare_callbacks(
        mode="infer",
        resume=args.resume,
        out_prefix=args.out_prefix,
        **callback_kwargs
    )

    model_runner.infer(
        loaders=infer_loaders,
        callbacks=infer_callbacks,
        verbose=args.verbose,
    )
예제 #2
0
def main(args, unknown_args):
    """Train the experiment described by the CLI args and config.

    Resolves the log directory, creates it, saves the config into it and then
    runs every stage from ``config["stages"]`` through the experiment's
    ``ModelRunner``.

    Args:
        args: parsed command-line namespace; after ``parse_args_uargs`` it is
            read for ``seed``, ``logdir``, ``baselogdir``, ``expdir`` and
            ``verbose``. ``args.logdir`` is filled in when it was not given.
        unknown_args: extra command-line arguments, forwarded untouched to
            ``parse_args_uargs``.

    Raises:
        ValueError: if neither ``args.logdir`` nor ``args.baselogdir`` is set.
    """
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)
    set_global_seeds(args.seed)

    if args.logdir is None:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn a missing option into an opaque
        # TypeError from pathlib.Path(None) below.
        if args.baselogdir is None:
            raise ValueError(
                "either args.logdir or args.baselogdir must be set")

        # The experiment's model module chooses the run sub-directory from
        # the config, so the modules are loaded once here (without dumping)
        # just to ask for it.
        modules_ = prepare_modules(expdir=args.expdir)
        logdir = modules_["model"].prepare_logdir(config=config)
        args.logdir = str(pathlib.Path(args.baselogdir).joinpath(logdir))

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)
    # Second load, now that the final logdir is known, with dump_dir set.
    modules = prepare_modules(expdir=args.expdir, dump_dir=args.logdir)

    model = Registry.get_model(**config["model_params"])
    datasource = modules["data"].DataSource()

    runner = modules["model"].ModelRunner(model=model)
    runner.train_stages(datasource=datasource,
                        args=args,
                        stages_config=config["stages"],
                        verbose=args.verbose)
예제 #3
0
 def get_model(self, stage: str) -> _Model:
     """Build the model from the config and run it through the stage hooks.

     The model is created by ``Registry.get_model`` from
     ``self._config["model_params"]``, then passed through
     ``_preprocess_model_for_stage`` and ``_postprocess_model_for_stage``,
     in that order.

     Args:
         stage: stage name handed to both hooks.

     Returns:
         Whatever ``_postprocess_model_for_stage`` returns.
     """
     built = Registry.get_model(**self._config["model_params"])
     prepared = self._preprocess_model_for_stage(stage, built)
     return self._postprocess_model_for_stage(stage, prepared)