예제 #1
0
    def parse_args_uargs(self):
        """Parse the stored arguments and fold the extra configs into the result.

        The module-level ``parse_args_uargs`` is called with ``self.args`` and
        an empty list of unknown arguments. ``self.grid_config`` and then
        ``self.params`` are merged into the resulting config with
        ``merge_dicts_smart``. When ``self.distr_info`` is truthy the merged
        config is also passed to ``self.set_dist_env``.

        Returns:
            The ``(args, config)`` pair: the parsed arguments and the merged
            config.
        """
        parsed_args, merged = parse_args_uargs(self.args, [])
        # Grid config is merged first, then the instance params on top of it.
        merged = merge_dicts_smart(
            merge_dicts_smart(merged, self.grid_config),
            self.params,
        )

        if self.distr_info:
            self.set_dist_env(merged)
        return parsed_args, merged
예제 #2
0
def main(args, unknown_args):
    """Assemble the environment, algorithm, database and trainer, then train.

    Args:
        args: argument namespace handed to ``parse_args_uargs``; the parsed
            result is read for ``seed``, ``deterministic``, ``benchmark``,
            ``logdir``, ``configs``, ``expdir`` and ``resume``.
        unknown_args: extra arguments handed to ``parse_args_uargs``.

    Raises:
        NotImplementedError: if ``config["algorithm"]["algorithm"]`` is listed
            in neither ``OFFPOLICY_ALGORITHMS_NAMES`` nor
            ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_environment(config, args.logdir, args.configs)

    if args.expdir is not None:
        # The returned module is not used; the call is kept for whatever the
        # import itself does (presumably registering experiment code — confirm).
        import_module(expdir=args.expdir)
        if args.logdir is not None:
            dump_code(args.expdir, args.logdir)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # The name is removed from the config before the whole config is handed
    # to ``prepare_for_trainer`` below.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            "Unknown algorithm {!r}: expected a name from "
            "OFFPOLICY_ALGORITHMS_NAMES or "
            "ONPOLICY_ALGORITHMS_NAMES".format(algorithm_name)
        )

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        # Load the checkpoint, move it to the current device and unpack it
        # into the algorithm with ``with_optimizer=False``.
        checkpoint = utils.load_checkpoint(filepath=args.resume)
        checkpoint = utils.any2device(checkpoint, utils.get_device())
        algorithm.unpack_checkpoint(
            checkpoint=checkpoint,
            with_optimizer=False
        )

    # Optional section: ``None`` is passed through when it is absent.
    monitoring_params = config.get("monitoring_params", None)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        monitoring_params=monitoring_params,
        **config["trainer"],
    )

    trainer.run()
예제 #3
0
def main(args, unknown_args):
    """Assemble the environment, algorithm, database and trainer, then train.

    Args:
        args: argument namespace handed to ``parse_args_uargs``; the parsed
            result is read for ``seed``, ``logdir``, ``configs`` and
            ``expdir``.
        unknown_args: extra arguments handed to ``parse_args_uargs``.

    Raises:
        NotImplementedError: if ``config["algorithm"]["algorithm"]`` is listed
            in neither ``OFFPOLICY_ALGORITHMS_NAMES`` nor
            ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_config(config, args.logdir, args.configs)

    if args.expdir is not None:
        # The returned module is not used; the call is kept for whatever the
        # import itself does (presumably registering experiment code — confirm).
        import_module(expdir=args.expdir)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    # The name is removed from the config before the whole config is handed
    # to ``prepare_for_trainer`` below.
    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
        # Off-policy: the synced weights depend on the environment's
        # ``discrete_actions`` flag.
        weights_sync_mode = "critic" if env.discrete_actions else "actor"
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
        weights_sync_mode = "actor"
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            "Unknown algorithm {!r}: expected a name from "
            "OFFPOLICY_ALGORITHMS_NAMES or "
            "ONPOLICY_ALGORITHMS_NAMES".format(algorithm_name)
        )

    db_server = DATABASES.get_from_params(
        **config.get("db", {}), sync_epoch=sync_epoch
    )

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    # TODO: resuming from a checkpoint is currently disabled in this entry
    # point; the intended call is kept below for reference.
    # if args.resume is not None:
    #     algorithm.load_checkpoint(filepath=args.resume)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        weights_sync_mode=weights_sync_mode,
        **config["trainer"],
    )

    trainer.run()
예제 #4
0
def _sampler_specs(args):
    """Yield ``(visualize, mode, exploration_power)`` per sampler, in launch order.

    The order is: ``args.vis`` visualized ``"infer"`` samplers, ``args.infer``
    plain ``"infer"`` samplers, ``args.valid`` ``"valid"`` samplers, and
    finally ``args.train`` ``"train"`` samplers whose exploration power is
    ``rank / args.train`` for ``rank`` in ``1..args.train``.
    """
    for _ in range(args.vis):
        yield True, "infer", 0.0
    for _ in range(args.infer):
        yield False, "infer", 0.0
    for _ in range(args.valid):
        yield False, "valid", 0.0
    for rank in range(1, args.train + 1):
        yield False, "train", rank / args.train


def main(args, unknown_args):
    """Launch sampler processes (or one in-process sampler in check mode).

    Args:
        args: argument namespace handed to ``parse_args_uargs``; the parsed
            result is read for ``seed``, ``deterministic``, ``benchmark``,
            ``vis``, ``infer``, ``valid``, ``train``, ``expdir``,
            ``sampler_id``, ``logdir``, ``resume``, ``db``, ``check``,
            ``daemon`` and ``run_delay``.
        unknown_args: extra arguments handed to ``parse_args_uargs``.

    Raises:
        NotImplementedError: if ``config["algorithm"]["algorithm"]`` is listed
            in neither ``OFFPOLICY_ALGORITHMS_NAMES`` nor
            ``ONPOLICY_ALGORITHMS_NAMES``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    # Normalize unset (falsy) counters to 0 so they can be used as ints below.
    args.vis = args.vis or 0
    args.infer = args.infer or 0
    args.valid = args.valid or 0
    args.train = args.train or 0

    if args.expdir is not None:
        # The returned module is not used; the call is kept for whatever the
        # import itself does (presumably registering experiment code — confirm).
        import_module(expdir=args.expdir)

    environment_name = config["environment"].pop("environment")
    environment_fn = ENVIRONMENTS.get(environment_name)

    algorithm_name = config["algorithm"].pop("algorithm")

    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        sync_epoch = False
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        sync_epoch = True
    else:
        raise NotImplementedError(
            "Unknown algorithm {!r}: expected a name from "
            "OFFPOLICY_ALGORITHMS_NAMES or "
            "ONPOLICY_ALGORITHMS_NAMES".format(algorithm_name)
        )

    algorithm_fn = ALGORITHMS.get(algorithm_name)

    processes = []

    def on_exit():
        # Terminate every started sampler process when the interpreter exits.
        for p in processes:
            p.terminate()

    atexit.register(on_exit)

    # Keyword arguments shared by every sampler invocation.
    params = dict(
        seed=args.seed,
        logdir=args.logdir,
        algorithm_fn=algorithm_fn,
        environment_fn=environment_fn,
        config=config,
        resume=args.resume,
        db=args.db,
        sync_epoch=sync_epoch
    )

    if args.check:
        # Check mode: run one sampler in this process and return.
        # "infer" takes precedence over "valid"; "train" is the fallback.
        if args.infer > 0:
            mode = "infer"
        elif args.valid > 0:
            mode = "valid"
        else:
            mode = "train"
        run_sampler(
            **params,
            visualize=args.vis > 0,
            mode=mode,
            id=args.sampler_id
        )
        return

    # Sampler ids are consecutive, starting at ``args.sampler_id``.
    for offset, spec in enumerate(_sampler_specs(args)):
        visualize, mode, exploration_power = spec
        p = mp.Process(
            target=run_sampler,
            kwargs=dict(
                params,
                visualize=visualize,
                mode=mode,
                id=args.sampler_id + offset,
                exploration_power=exploration_power,
            ),
            daemon=args.daemon,
        )
        p.start()
        processes.append(p)
        # Wait ``args.run_delay`` seconds after each start, as before.
        time.sleep(args.run_delay)

    for p in processes:
        p.join()