Ejemplo n.º 1
0
def main(args, unknown_args):
    """Assemble a trainer from CLI arguments plus a config and run it.

    Args:
        args: command-line namespace handed to ``parse_args_uargs``; the
            namespace it returns is read for ``seed``, ``deterministic``,
            ``benchmark``, ``logdir``, ``expdir``, ``configs`` and ``resume``.
        unknown_args: extra command-line arguments, forwarded untouched to
            ``parse_args_uargs``.

    Raises:
        NotImplementedError: if ``config["algorithm"]["algorithm"]`` is in
            neither ``OFFPOLICY_ALGORITHMS_NAMES`` nor
            ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_environment(config, args.logdir, args.configs)

    if args.expdir is not None:
        # The returned module is intentionally unused; the call is kept for
        # its import-time effects (presumably registering user components --
        # confirm against import_module).
        module = import_module(expdir=args.expdir)  # noqa: F841
        if args.logdir is not None:
            dump_code(args.expdir, args.logdir)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # The name is popped, so config["algorithm"] no longer contains it when
    # the whole config is handed to prepare_for_trainer below.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            f"Unsupported algorithm {algorithm_name!r}: it is neither an "
            "off-policy nor an on-policy algorithm name"
        )

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        checkpoint = utils.load_checkpoint(filepath=args.resume)
        checkpoint = utils.any2device(checkpoint, utils.get_device())
        # Resuming is requested with with_optimizer=False.
        algorithm.unpack_checkpoint(
            checkpoint=checkpoint,
            with_optimizer=False
        )

    monitoring_params = config.get("monitoring_params", None)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        monitoring_params=monitoring_params,
        **config["trainer"],
    )

    trainer.run()
Ejemplo n.º 2
0
def run_sampler(
    *,
    config,
    logdir,
    algorithm_fn,
    environment_fn,
    sampler_fn,
    vis,
    infer,
    seed=42,
    id=None,
    resume=None,
    db=True,
    exploration_power=1.0,
    sync_epoch=False
):
    """Build one sampler from the config and run it.

    Args:
        config: experiment config; a deep copy is taken before any key is
            popped, so the caller's mapping is not mutated here.
        logdir: forwarded to ``sampler_fn``.
        algorithm_fn: object providing ``prepare_for_sampler``.
        environment_fn: callable building the environment from
            ``config["environment"]`` plus ``visualize``.
        sampler_fn: callable building the sampler.
        vis: forwarded to ``environment_fn`` as ``visualize``.
        infer: truthy selects mode ``"infer"`` and passes the configured
            ``valid_seeds``; otherwise mode is ``"train"`` with no seeds.
        seed: base random seed; the effective seed is ``seed + id``.
        id: sampler index, ``None`` is treated as ``0``.
        resume: optional checkpoint path loaded into the sampler.
        db: when falsy no database server is created.
        exploration_power: passed to ``ExplorationHandler.set_power``.
        sync_epoch: forwarded to the database factory.
    """
    config_ = copy.deepcopy(config)
    id = 0 if id is None else id
    set_global_seed(seed + id)

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    ) if db else None

    env = environment_fn(**config_["environment"], visualize=vis)
    agent = algorithm_fn.prepare_for_sampler(env_spec=env, config=config_)

    exploration_params = config_["sampler"].pop("exploration_params", None)
    exploration_handler = None
    if exploration_params is not None:
        # Positional unpacking is written before the keyword argument; the
        # call is equivalent to ``ExplorationHandler(env=env, *params)`` but
        # reads in the order Python actually binds the arguments.
        exploration_handler = ExplorationHandler(*exploration_params, env=env)
        exploration_handler.set_power(exploration_power)

    mode = "infer" if infer else "train"
    # "valid_seeds" is only consumed in infer mode, so a config without it
    # must not crash a training sampler; it is still popped so that it never
    # leaks into the sampler kwargs below.
    valid_seeds = config_["sampler"].pop("valid_seeds", None)
    seeds = valid_seeds if infer else None

    sampler = sampler_fn(
        agent=agent,
        env=env,
        db_server=db_server,
        exploration_handler=exploration_handler,
        **config_["sampler"],
        logdir=logdir,
        id=id,
        mode=mode,
        seeds=seeds
    )

    if resume is not None:
        sampler.load_checkpoint(filepath=resume)

    sampler.run()
Ejemplo n.º 3
0
def main(args, unknown_args):
    """Assemble a trainer from CLI arguments plus a config and run it.

    Args:
        args: command-line namespace handed to ``parse_args_uargs``; the
            namespace it returns is read for ``seed``, ``logdir``,
            ``expdir``, ``configs`` and ``resume``.
        unknown_args: extra command-line arguments, forwarded untouched to
            ``parse_args_uargs``.

    Raises:
        NotImplementedError: if ``config["algorithm"]["algorithm"]`` is in
            neither ``OFFPOLICY_ALGORITHMS_NAMES`` nor
            ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_config(config, args.logdir, args.configs)

    if args.expdir is not None:
        # The returned module is intentionally unused; the call is kept for
        # its import-time effects (presumably registering user components --
        # confirm against import_module).
        module = import_module(expdir=args.expdir)  # noqa: F841

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # The name is popped, so config["algorithm"] no longer contains it when
    # the whole config is handed to prepare_for_trainer below.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
        # Off-policy: which weights are synced depends on the action space.
        weights_sync_mode = "critic" if env.discrete_actions else "actor"
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
        weights_sync_mode = "actor"
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            f"Unsupported algorithm {algorithm_name!r}: it is neither an "
            "off-policy nor an on-policy algorithm name"
        )

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        algorithm.load_checkpoint(filepath=args.resume)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        weights_sync_mode=weights_sync_mode,
        **config["trainer"],
    )

    trainer.run()
Ejemplo n.º 4
0
def main(args, unknown_args):
    """Entry point: build the trainer described by the config and run it.

    Args:
        args: command-line namespace handed to ``parse_args_uargs``; the
            namespace it returns is read for ``seed``, ``logdir``,
            ``expdir``, ``configs`` and ``resume``.
        unknown_args: extra command-line arguments, forwarded untouched to
            ``parse_args_uargs``.

    An ``atexit`` hook is registered that terminates every process reported
    by ``trainer.get_processes()`` when the interpreter exits.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    logdir = args.logdir
    if logdir is not None:
        os.makedirs(logdir, exist_ok=True)
        dump_config(args.configs, logdir)

    if args.expdir is not None:
        # Result is not used afterwards; only the call itself matters here.
        _module = import_module(expdir=args.expdir)  # noqa: F841

    algorithm_name = config["algorithm"].pop("algorithm")
    # Any name that is not a known off-policy algorithm is handled as
    # on-policy; only the selected tuple is evaluated.
    registry, trainer_fn, sync_epoch = (
        (OFFPOLICY_ALGORITHMS, OffpolicyTrainer, False)
        if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES
        else (ONPOLICY_ALGORITHMS, OnpolicyTrainer, True)
    )

    db_params = config.get("db", {})
    db_server = DATABASES.get_from_params(**db_params, sync_epoch=sync_epoch)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    algorithm = registry.get(algorithm_name).prepare_for_trainer(
        env_spec=env, config=config
    )

    resume_path = args.resume
    if resume_path is not None:
        algorithm.load_checkpoint(filepath=resume_path)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        **config["trainer"],
        logdir=logdir,
    )

    def _terminate_processes():
        # Stop every process the trainer reports at interpreter shutdown.
        for process in trainer.get_processes():
            process.terminate()

    atexit.register(_terminate_processes)

    trainer.run()
Ejemplo n.º 5
0
def run_sampler(*,
                config,
                logdir,
                algorithm_fn,
                environment_fn,
                visualize,
                mode,
                seed=42,
                id=None,
                resume=None,
                db=True,
                exploration_power=1.0,
                sync_epoch=False):
    """Build one sampler for the given mode from the config and run it.

    Args:
        config: experiment config; a deep copy is taken before any key is
            popped, so the caller's mapping is not mutated here.
        logdir: forwarded to the sampler.
        algorithm_fn: object providing ``prepare_for_sampler``; it must be
            a value of ``OFFPOLICY_ALGORITHMS`` or ``ONPOLICY_ALGORITHMS``.
        environment_fn: callable building the environment.
        visualize: forwarded to ``environment_fn``.
        mode: one of ``"train"``, ``"valid"`` or ``"infer"``; selects the
            ``<mode>_seeds`` config entry, and ``"valid"`` selects
            ``ValidSampler`` instead of ``Sampler``.
        seed: base random seed; the effective seed is ``seed + id``.
        id: sampler index, ``None`` is treated as ``0``.
        resume: optional checkpoint path loaded into the sampler.
        db: when falsy no database server is created.
        exploration_power: passed to ``ExplorationHandler.set_power``.
        sync_epoch: forwarded to the database factory.

    Raises:
        KeyError: if ``mode`` is not ``"train"``, ``"valid"`` or ``"infer"``.
        NotImplementedError: if ``algorithm_fn`` is registered in neither
            the off-policy nor the on-policy algorithms.
    """
    config_ = copy.deepcopy(config)
    id = 0 if id is None else id
    seed = seed + id
    set_global_seed(seed)

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch) if db else None

    env = environment_fn(
        **config_["environment"],
        visualize=visualize,
        mode=mode,
        sampler_id=id,
    )
    agent = algorithm_fn.prepare_for_sampler(env_spec=env, config=config_)

    exploration_params = config_["sampler"].pop("exploration_params", None)
    exploration_handler = None
    if exploration_params is not None:
        # Positional unpacking is written before the keyword argument; the
        # call is equivalent to ``ExplorationHandler(env=env, *params)`` but
        # reads in the order Python actually binds the arguments.
        exploration_handler = ExplorationHandler(*exploration_params, env=env)
        exploration_handler.set_power(exploration_power)

    # All three "<mode>_seeds" entries are popped, not just the one in use,
    # so that none of them leaks into the sampler kwargs below.
    seeds_by_mode = {
        key: config_["sampler"].pop(f"{key}_seeds", None)
        for key in ("train", "valid", "infer")
    }
    seeds = seeds_by_mode[mode]

    if algorithm_fn in OFFPOLICY_ALGORITHMS.values():
        weights_sync_mode = "critic" if env.discrete_actions else "actor"
    elif algorithm_fn in ONPOLICY_ALGORITHMS.values():
        weights_sync_mode = "actor"
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            f"Unsupported algorithm {algorithm_fn!r}: it is registered in "
            "neither the off-policy nor the on-policy algorithms"
        )

    sampler_fn = ValidSampler if mode == "valid" else Sampler

    monitoring_params = config.get("monitoring_params", None)

    sampler = sampler_fn(
        agent=agent,
        env=env,
        db_server=db_server,
        exploration_handler=exploration_handler,
        logdir=logdir,
        id=id,
        mode=mode,
        weights_sync_mode=weights_sync_mode,
        seeds=seeds,
        monitoring_params=monitoring_params,
        **config_["sampler"],
    )

    if resume is not None:
        sampler.load_checkpoint(filepath=resume)

    sampler.run()