예제 #1
0
def main(args, unknown_args):
    """Parse the config, prepare a log directory and train a model.

    When ``args.logdir`` is not given, a directory name is obtained from the
    model module's ``prepare_logdir(config=...)`` and joined onto
    ``args.baselogdir``; the result is written back to ``args.logdir``.
    The config is then saved into that directory, the modules are prepared
    again with ``dump_dir`` pointing at it, and ``ModelRunner.train`` is
    called with the ``"stages"`` section of the config.

    Args:
        args: parsed arguments, accessed by attribute (``seed``,
            ``baselogdir``, ``logdir``, ``model_dir``, ``verbose``).
        unknown_args: leftover CLI arguments, forwarded to
            ``parse_args_uargs``.

    Raises:
        ValueError: if neither ``baselogdir`` nor ``logdir`` is set.
    """
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)
    pprint(args)
    pprint(config)
    set_global_seeds(args.seed)

    # Validate explicitly: an `assert` would be stripped under `python -O`,
    # and a missing directory would then surface as an obscure error below.
    if args.baselogdir is None and args.logdir is None:
        raise ValueError(
            "either `baselogdir` or `logdir` must be specified")

    if args.logdir is None:
        # A first module load is needed only to ask the model module for a
        # run-specific sub-directory name.
        modules_ = prepare_modules(model_dir=args.model_dir)
        logdir = modules_["model"].prepare_logdir(config=config)
        args.logdir = str(pathlib2.Path(args.baselogdir).joinpath(logdir))

    create_if_need(args.logdir)
    save_config(config=config, logdir=args.logdir)
    modules = prepare_modules(model_dir=args.model_dir, dump_dir=args.logdir)

    datasource = modules["data"].DataSource()
    model = modules["model"].prepare_model(config)

    runner = modules["model"].ModelRunner(model=model)
    runner.train(datasource=datasource,
                 args=args,
                 stages_config=config["stages"],
                 verbose=args.verbose)
예제 #2
0
def main(args, unknown_args):
    """Run inference with the data and model modules found in ``model_dir``.

    The ``"data_params"`` config section is expanded into
    ``DataSource.prepare_loaders`` and ``"callbacks_params"`` is handed to
    ``ModelRunner.prepare_callbacks`` in ``"infer"`` mode; both sections
    must be present in the config.

    Args:
        args: parsed arguments, accessed by attribute (``seed``,
            ``model_dir``, ``verbose``).
        unknown_args: leftover CLI arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    for item in (args, config):
        pprint(item)
    set_global_seeds(args.seed)

    modules = prepare_modules(model_dir=args.model_dir)

    data_module = modules["data"]
    source = data_module.DataSource()
    loaders = source.prepare_loaders(args, **config["data_params"])

    model_module = modules["model"]
    model = model_module.prepare_model(config)
    runner = model_module.ModelRunner(model=model)

    callbacks = runner.prepare_callbacks(
        callbacks_params=config["callbacks_params"],
        args=args,
        mode="infer",
    )
    runner.infer(loaders=loaders, callbacks=callbacks, verbose=args.verbose)
예제 #3
0
def main(args, unknown_args):
    """Run inference with the data and model modules found in ``model_dir``.

    The optional ``"data_params"`` and ``"callbacks_params"`` config sections
    are expanded as keyword arguments into ``prepare_loaders`` and
    ``prepare_callbacks``; a missing, ``None`` or empty section is treated as
    an empty dict.

    Args:
        args: parsed arguments, accessed by attribute (``seed``,
            ``model_dir``, ``workers``, ``batch_size``, ``resume``,
            ``out_prefix``, ``verbose``).
        unknown_args: leftover CLI arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    for item in (args, config):
        pprint(item)
    set_global_seeds(args.seed)

    modules = prepare_modules(model_dir=args.model_dir)

    source = modules["data"].DataSource()
    loader_kwargs = config.get("data_params") or {}
    loaders = source.prepare_loaders(
        mode="infer",
        n_workers=args.workers,
        batch_size=args.batch_size,
        **loader_kwargs
    )

    model_module = modules["model"]
    runner = model_module.ModelRunner(model=model_module.prepare_model(config))

    callback_kwargs = config.get("callbacks_params") or {}
    callbacks = runner.prepare_callbacks(
        mode="infer",
        resume=args.resume,
        out_prefix=args.out_prefix,
        **callback_kwargs
    )
    runner.infer(loaders=loaders, callbacks=callbacks, verbose=args.verbose)
예제 #4
0
def algos_by_dir(dir):
    """Rebuild one wrapped algorithm per checkpoint found under ``dir``.

    Every entry of ``dir`` is treated as a log directory. Its
    ``config.json`` is used to build the algorithm, and each ``*.pth.tar``
    checkpoint in it is loaded without optimizer state. The actor and
    critics, switched to eval mode, are wrapped in an ``AlgoWrapper``.

    Args:
        dir: directory whose immediate entries are log directories.

    Returns:
        list of ``AlgoWrapper`` objects, one per checkpoint found.

    Raises:
        NotImplementedError: if the algorithm's class name contains neither
            ``"ensemblecritic"`` nor ``"td3"``.
    """
    algos = []
    dirs = path.Path(dir).listdir()
    for logpath in dirs:
        config_path = logpath + "/config.json"
        checkpoints = path.Path(logpath).glob("*.pth.tar")
        for checkpoint_path in checkpoints:
            # The config is re-parsed for every checkpoint; the pops below
            # mutate it, so each algorithm starts from a freshly parsed dict.
            args = argparse.Namespace(config=config_path)
            args, config = parse_args_uargs(args, [])
            # Drop resume / optimizer-loading settings; the checkpoint is
            # loaded explicitly (without optimizer) further down.
            config.get("algorithm", {}).pop("resume", None)
            config.get("algorithm", {}).pop("load_optimizer", None)

            algo_module = import_module("algo_module", args.algorithm)
            trainer_kwargs = algo_module.prepare_for_trainer(config)

            algorithm = trainer_kwargs["algorithm"]
            algorithm.load_checkpoint(checkpoint_path, load_optimizer=False)

            actor_ = algorithm.actor.eval()

            # The critic layout is chosen from the algorithm's class name.
            name = str(algorithm.__class__).lower()
            if "ensemblecritic" in name:
                critics_ = [x.eval() for x in algorithm.critics]
            elif "td3" in name:
                critics_ = [algorithm.critic.eval(), algorithm.critic2.eval()]
            else:
                # `raise NotImplemented` would itself fail with TypeError,
                # because NotImplemented is a constant, not an exception.
                raise NotImplementedError(
                    "unsupported algorithm class: {}".format(
                        algorithm.__class__.__name__))

            history_len = trainer_kwargs["history_len"]

            algos.append(
                AlgoWrapper(actor=actor_,
                            critics=critics_,
                            history_len=history_len,
                            consensus=IN_CONSENSUS))
    return algos
예제 #5
0
# Command-line interface: a mandatory config path plus optional
# --algorithm / --logdir values.
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--algorithm", type=str, default=None)
parser.add_argument("--logdir", type=str, default=None)

# Arguments argparse does not recognise are passed to parse_args_uargs
# alongside the parsed namespace (with dump_config=True).
args, unknown_args = parser.parse_known_args()
args, config = parse_args_uargs(args, unknown_args, dump_config=True)

# Resolve the algorithm module and let it build its trainer kwargs.
algo_module = import_module("algo_module", args.algorithm)
algo_kwargs = algo_module.prepare_for_trainer(config)

# Redis settings come from the optional "redis" config section
# (port defaults to 12000, key prefix to the empty string).
redis_server = StrictRedis(
    port=config.get("redis", {}).get("port", 12000))
redis_prefix = config.get("redis", {}).get("prefix", "")

pprint(config["trainer"])
pprint(algo_kwargs)


trainer = Trainer(
    **config["trainer"],
    **algo_kwargs,
    logdir=args.logdir,
예제 #6
0
# Noise-related float options; each defaults to None when not given.
parser.add_argument("--max-noise-power", type=float, default=None)
parser.add_argument("--max-action-noise", type=float, default=None)
parser.add_argument("--max-param-noise", type=float, default=None)
boolean_flag(parser, "debug", default=False)

# Strict parse (argparse rejects unknown flags), so no extra arguments are
# forwarded to parse_args_uargs.
args, config = parse_args_uargs(parser.parse_args(), [])

# Resolve the modules selected by args.environment and args.algorithm.
env_module = import_module("env_module", args.environment)
algo_module = import_module("algo_module", args.algorithm)


def run_sampler(
        *,
        config, vis, infer,
        action_noise_prob, param_noise_prob,
        action_noise=None, param_noise=None,
        noise_power=None,  # @TODO: remove
        id=None, resume=None, debug=False):
    config_ = copy.deepcopy(config)

    if debug: