Ejemplo n.º 1
0
def main(args, unknown_args):
    """Build a Redis-backed ``Trainer`` from the config and run it.

    Parses arguments and config (with ``dump_config=True``), creates
    ``args.logdir`` and saves the config there, and calls ``prepare_modules``
    when ``args.expdir`` is set. It then resolves the algorithm named by
    ``args.algorithm`` from the ``Registry`` and constructs ``Trainer`` from
    these parts:

    - the ``trainer`` config section;
    - the algorithm's trainer kwargs;
    - the log dir;
    - a Redis connection.

    Finally it runs the trainer.

    Args:
        args: parsed command-line arguments; replaced by the re-parsed
            ``args`` returned from ``parse_args_uargs``.
        unknown_args: leftover command-line arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)
    if args.expdir is not None:
        # Return value is unused; the call is made for its side effects.
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)
    algorithm_kwargs = algorithm.prepare_for_trainer(config)

    # The ``redis`` section is optional. If the key exists but holds None
    # (e.g. an empty YAML section), fall back to {} instead of crashing on
    # ``None.get``. Defaults: port 12000, empty prefix.
    redis_params = config.get("redis", {}) or {}
    redis_server = StrictRedis(port=redis_params.get("port", 12000))
    redis_prefix = redis_params.get("prefix", "")

    pprint(config["trainer"])
    pprint(algorithm_kwargs)

    trainer = Trainer(**config["trainer"],
                      **algorithm_kwargs,
                      logdir=args.logdir,
                      redis_server=redis_server,
                      redis_prefix=redis_prefix)

    pprint(trainer)

    def on_exit():
        # Terminate every process the trainer reports when the interpreter exits.
        for p in trainer.get_processes():
            p.terminate()

    atexit.register(on_exit)

    trainer.run()
Ejemplo n.º 2
0
    def parse_args_uargs(self):
        """Parse ``self.args`` and layer this object's overrides onto the config.

        Calls the module-level ``parse_args_uargs`` with no extra arguments.
        It then merges the resulting config with ``self.grid_config`` and,
        after that, with ``self.params``, both via ``merge_dicts_smart``.
        When ``self.distr_info`` is truthy, ``self.set_dist_env`` is called
        with the merged config.

        Returns:
            tuple: ``(args, config)`` where ``config`` is the merged config.
        """
        parsed, merged = parse_args_uargs(self.args, [])
        for overlay in (self.grid_config, self.params):
            merged = merge_dicts_smart(merged, overlay)

        if self.distr_info:
            self.set_dist_env(merged)
        return parsed, merged
Ejemplo n.º 3
0
    def parse_args_uargs(self):
        """Parse ``self.args``, merge the grid config, and pin visible GPUs.

        The config from the module-level ``parse_args_uargs`` is merged with
        ``self.grid_config``. ``CUDA_VISIBLE_DEVICES`` is set from
        ``self.task.gpu_assigned``; a falsy value becomes the empty string.
        When ``self.distr_info`` is truthy, ``self.set_dist_env`` is called
        with the merged config.

        Returns:
            tuple: ``(args, config)`` where ``config`` is the merged config.
        """
        parsed, cfg = parse_args_uargs(self.args, [])
        cfg = merge_dicts_smart(cfg, self.grid_config)

        # NOTE(review): if gpu_assigned can be the int 0, the falsy check maps
        # it to '' (no GPU listed), and a non-zero int would be rejected by
        # os.environ. Confirm gpu_assigned is always a str or None.
        gpus = self.task.gpu_assigned
        os.environ['CUDA_VISIBLE_DEVICES'] = gpus if gpus else ''

        if self.distr_info:
            self.set_dist_env(cfg)
        return parsed, cfg
Ejemplo n.º 4
0
def main(args, unknown_args):
    """Run the experiment defined in ``args.expdir``.

    The global seed is taken from the config's ``seed`` entry (default 42).
    The experiment's code is dumped into ``experiment.logdir`` before the
    runner starts. ``args.check`` is forwarded to ``run_experiment``.
    """
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)
    seed = config.get("seed", 42)
    set_global_seeds(seed)

    experiment_cls, runner_cls = import_experiment_and_runner(Path(args.expdir))
    experiment = experiment_cls(config)
    runner = runner_cls()

    dump_code(args.expdir, experiment.logdir)
    runner.run_experiment(experiment, check=args.check)
Ejemplo n.º 5
0
def main(args, unknown_args):
    """Run the experiment defined in ``args.expdir``, seeded by ``args.seed``.

    When the experiment has a log dir, the config and the experiment code are
    dumped there before running. ``args.check`` is forwarded to
    ``run_experiment``.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    experiment_cls, runner_cls = import_experiment_and_runner(Path(args.expdir))
    experiment = experiment_cls(config)
    runner = runner_cls()

    logdir = experiment.logdir
    if logdir is not None:
        dump_config(config, logdir, args.configs)
        dump_code(args.expdir, logdir)

    runner.run_experiment(experiment, check=args.check)
Ejemplo n.º 6
0
def main(args, unknown_args):
    """Run the experiment from ``args.expdir`` with seed and cuDNN flags applied.

    ``prepare_cudnn`` receives ``args.deterministic`` and ``args.benchmark``.
    When the experiment has a log dir, the environment/config and the
    experiment code are dumped there before running.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    experiment_cls, runner_cls = import_experiment_and_runner(Path(args.expdir))
    experiment, runner = experiment_cls(config), runner_cls()

    logdir = experiment.logdir
    if logdir is not None:
        dump_environment(config, logdir, args.configs)
        dump_code(args.expdir, logdir)

    runner.run_experiment(experiment, check=args.check)
Ejemplo n.º 7
0
def main(args, unknown_args):
    """Build an off-policy or on-policy RL trainer from the config and run it.

    The algorithm name is popped from ``config["algorithm"]["algorithm"]`` and
    selects the algorithm registry, trainer class, ``sync_epoch`` flag and
    weight-sync mode:

    - off-policy: syncs ``"critic"`` weights for discrete-action
      environments, otherwise ``"actor"``;
    - on-policy: always syncs ``"actor"`` weights.

    Raises:
        NotImplementedError: if the algorithm name is in neither the
            off-policy nor the on-policy name set.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_config(config, args.logdir, args.configs)

    if args.expdir is not None:
        # Return value is unused; imported for its side effects.
        module = import_module(expdir=args.expdir)  # noqa: F841

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
        weights_sync_mode = "critic" if env.discrete_actions else "actor"
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True
        weights_sync_mode = "actor"
    else:
        # @TODO: add registry for algorithms, trainers, samplers
        raise NotImplementedError(
            "unknown algorithm: {!r}".format(algorithm_name))

    # ``db`` is optional; a present-but-None section must not break ``**``.
    db_params = config.get("db", {}) or {}
    db_server = DATABASES.get_from_params(**db_params, sync_epoch=sync_epoch)

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        algorithm.load_checkpoint(filepath=args.resume)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        logdir=args.logdir,
        weights_sync_mode=weights_sync_mode,
        **config["trainer"],
    )

    trainer.run()
Ejemplo n.º 8
0
def main(args, unknown_args):
    """Build an RL trainer (off-policy or on-policy) from the config and run it.

    The algorithm name is popped from ``config["algorithm"]["algorithm"]``.
    Names in ``OFFPOLICY_ALGORITHMS_NAMES`` use the off-policy registry and
    trainer with ``sync_epoch=False``. Every other name uses the on-policy
    registry and trainer with ``sync_epoch=True``. Processes reported by the
    trainer are terminated when the interpreter exits.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)

    if args.logdir is not None:
        os.makedirs(args.logdir, exist_ok=True)
        dump_config(args.configs, args.logdir)

    if args.expdir is not None:
        # Return value is unused; imported for its side effects.
        module = import_module(expdir=args.expdir)  # noqa: F841

    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        ALGORITHMS = OFFPOLICY_ALGORITHMS
        trainer_fn = OffpolicyTrainer
        sync_epoch = False
    else:
        # NOTE(review): any name outside the off-policy set is treated as
        # on-policy here. Confirm ONPOLICY_ALGORITHMS rejects unknown names
        # rather than returning None.
        ALGORITHMS = ONPOLICY_ALGORITHMS
        trainer_fn = OnpolicyTrainer
        sync_epoch = True

    # ``db`` is optional; a present-but-None section must not break ``**``.
    db_params = config.get("db", {}) or {}
    db_server = DATABASES.get_from_params(**db_params, sync_epoch=sync_epoch)

    env = ENVIRONMENTS.get_from_params(**config["environment"])

    algorithm_fn = ALGORITHMS.get(algorithm_name)
    algorithm = algorithm_fn.prepare_for_trainer(env_spec=env, config=config)

    if args.resume is not None:
        algorithm.load_checkpoint(filepath=args.resume)

    trainer = trainer_fn(
        algorithm=algorithm,
        env_spec=env,
        db_server=db_server,
        **config["trainer"],
        logdir=args.logdir,
    )

    def on_exit():
        # Terminate every process the trainer reports when the interpreter exits.
        for p in trainer.get_processes():
            p.terminate()

    atexit.register(on_exit)

    trainer.run()
Ejemplo n.º 9
0
def main(args, unknown_args):
    """Run the ``catalyst-dl run`` script.

    Seeds globally, applies cuDNN flags and builds the experiment and runner
    from ``args.expdir``. ``runner_params`` is removed from the config and
    used as runner kwargs; a missing or None value means no kwargs. When the
    experiment has a log dir, the environment and code are dumped there. The
    ``check`` flag is read from ``config["args"]["check"]`` (default False).
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seed(args.seed)
    prepare_cudnn(args.deterministic, args.benchmark)

    experiment_cls, runner_cls = import_experiment_and_runner(Path(args.expdir))

    runner_kwargs = config.pop("runner_params", {}) or {}
    experiment = experiment_cls(config)
    runner = runner_cls(**runner_kwargs)

    logdir = experiment.logdir
    if logdir is not None:
        dump_environment(config, logdir, args.configs)
        dump_code(args.expdir, logdir)

    runner.run_experiment(
        experiment,
        check=safitty.get(config, "args", "check", default=False),
    )
Ejemplo n.º 10
0
def main(args, unknown_args):
    """Build a Redis-backed ``Trainer`` from the config and run it.

    Behaves like the plain trainer launcher, with one addition. When
    ``args.environment`` is set, the named environment is instantiated once
    from ``config["env"]``. Its observation and action sizes are written into
    ``config["shared"]`` before the algorithm prepares its trainer kwargs.

    Args:
        args: parsed command-line arguments; replaced by the re-parsed
            ``args`` returned from ``parse_args_uargs``.
        unknown_args: leftover command-line arguments, forwarded to
            ``parse_args_uargs``.
    """
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)
    if args.expdir is not None:
        # Return value is unused; the call is made for its side effects.
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)
    if args.environment is not None:
        # @TODO: remove this hack
        # come on, just refactor whole rl
        # The environment is created only to read its observation/action
        # sizes into the shared config, then discarded.
        environment_fn = Registry.get_fn("environment", args.environment)
        env = environment_fn(**config["env"])
        config["shared"]["observation_size"] = env.observation_shape[0]
        config["shared"]["action_size"] = env.action_shape[0]
        del env
    algorithm_kwargs = algorithm.prepare_for_trainer(config)

    # The ``redis`` section is optional. If the key exists but holds None
    # (e.g. an empty YAML section), fall back to {} instead of crashing on
    # ``None.get``. Defaults: port 12000, empty prefix.
    redis_params = config.get("redis", {}) or {}
    redis_server = StrictRedis(port=redis_params.get("port", 12000))
    redis_prefix = redis_params.get("prefix", "")

    pprint(config["trainer"])
    pprint(algorithm_kwargs)

    trainer = Trainer(**config["trainer"],
                      **algorithm_kwargs,
                      logdir=args.logdir,
                      redis_server=redis_server,
                      redis_prefix=redis_prefix)

    pprint(trainer)

    def on_exit():
        # Terminate every process the trainer reports when the interpreter exits.
        for p in trainer.get_processes():
            p.terminate()

    atexit.register(on_exit)

    trainer.run()
Ejemplo n.º 11
0
def main(args, unknown_args):
    """Run inference with the model described by ``config["model_params"]``.

    Loaders come from the experiment's ``DataSource.prepare_loaders`` in
    ``"infer"`` mode. Callbacks come from ``ModelRunner.prepare_callbacks``
    in ``"infer"`` mode. The optional ``data_params`` and
    ``callbacks_params`` sections fall back to {} when missing or None.
    """
    args, config = parse_args_uargs(args, unknown_args)
    set_global_seeds(args.seed)

    modules = prepare_modules(expdir=args.expdir)

    model = Registry.get_model(**config["model_params"])
    datasource = modules["data"].DataSource()
    loaders = datasource.prepare_loaders(
        mode="infer",
        n_workers=args.workers,
        batch_size=args.batch_size,
        **(config.get("data_params", {}) or {})
    )

    runner = modules["model"].ModelRunner(model=model)
    callbacks = runner.prepare_callbacks(
        mode="infer",
        resume=args.resume,
        out_prefix=args.out_prefix,
        **(config.get("callbacks_params", {}) or {})
    )
    runner.infer(loaders=loaders, callbacks=callbacks, verbose=args.verbose)
Ejemplo n.º 12
0
def main(args, unknown_args):
    """Train a model through the stages listed in ``config["stages"]``.

    If ``args.logdir`` is not given, it is built as
    ``args.baselogdir / <name>``. Here ``<name>`` is returned by the
    experiment model module's ``prepare_logdir(config=config)``. The config
    is saved into the log dir, and the experiment modules are prepared with
    ``dump_dir=args.logdir``. ``ModelRunner.train_stages`` then runs with the
    data source and stage configs.

    Raises:
        ValueError: if neither ``args.logdir`` nor ``args.baselogdir`` is set.
    """
    args, config = parse_args_uargs(args, unknown_args, dump_config=True)
    set_global_seeds(args.seed)

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O`` and would then let a None path reach pathlib/os below.
    if args.baselogdir is None and args.logdir is None:
        raise ValueError(
            "either `logdir` or `baselogdir` must be specified")

    if args.logdir is None:
        modules_ = prepare_modules(expdir=args.expdir)
        logdir = modules_["model"].prepare_logdir(config=config)
        args.logdir = str(pathlib.Path(args.baselogdir).joinpath(logdir))

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)
    modules = prepare_modules(expdir=args.expdir, dump_dir=args.logdir)

    model = Registry.get_model(**config["model_params"])
    datasource = modules["data"].DataSource()

    runner = modules["model"].ModelRunner(model=model)
    runner.train_stages(datasource=datasource,
                        args=args,
                        stages_config=config["stages"],
                        verbose=args.verbose)
Ejemplo n.º 13
0
def main(args, unknown_args):
    """Start RL sampler processes, optionally running one sampler in-process.

    The ``vis``/``infer``/``valid``/``train`` counts are normalised so that a
    falsy value becomes 0. The environment and algorithm are resolved from
    their registries; the algorithm must be in the off-policy or on-policy
    name set.

    If ``args.check`` is set, one sampler first runs in the current process.
    Then processes are started in this order:

    - ``vis`` visualising inference samplers;
    - ``infer`` inference samplers;
    - ``valid`` validation samplers;
    - ``train`` training samplers whose ``exploration_power`` goes
      ``1/train, 2/train, ..., 1``.

    Ids are consecutive from ``args.sampler_id``, and the function sleeps
    ``args.run_delay`` seconds after each start. It then joins all processes;
    they are terminated at interpreter exit.

    Raises:
        NotImplementedError: for an algorithm name outside both name sets.
    """
    args, config = parse_args_uargs(args, unknown_args)

    # Treat unset (falsy) counts as zero.
    for count_name in ("vis", "infer", "valid", "train"):
        setattr(args, count_name, getattr(args, count_name) or 0)

    if args.expdir is not None:
        # Return value is unused; imported for its side effects.
        module = import_module(expdir=args.expdir)  # noqa: F841

    environment_fn = ENVIRONMENTS.get(
        config["environment"].pop("environment"))

    algorithm_name = config["algorithm"].pop("algorithm")
    if algorithm_name in OFFPOLICY_ALGORITHMS_NAMES:
        registry, sync_epoch = OFFPOLICY_ALGORITHMS, False
    elif algorithm_name in ONPOLICY_ALGORITHMS_NAMES:
        registry, sync_epoch = ONPOLICY_ALGORITHMS, True
    else:
        raise NotImplementedError()
    algorithm_fn = registry.get(algorithm_name)

    processes = []
    sampler_id = args.sampler_id

    def on_exit():
        # Terminate every started sampler when the interpreter exits.
        for proc in processes:
            proc.terminate()

    atexit.register(on_exit)

    params = dict(
        seed=args.seed,
        logdir=args.logdir,
        algorithm_fn=algorithm_fn,
        environment_fn=environment_fn,
        config=config,
        resume=args.resume,
        db=args.db,
        sync_epoch=sync_epoch,
    )

    if args.check:
        # "infer" takes precedence over "valid", which beats "train".
        if args.infer > 0:
            mode = "infer"
        elif args.valid > 0:
            mode = "valid"
        else:
            mode = "train"
        check_params = dict(visualize=args.vis > 0, mode=mode, id=sampler_id)
        run_sampler(**params, **check_params)

    def launch(mode, visualize, exploration_power):
        # Start one sampler process under the next id, then pause.
        nonlocal sampler_id
        sampler_params = dict(
            visualize=visualize,
            mode=mode,
            id=sampler_id,
            exploration_power=exploration_power,
        )
        proc = mp.Process(
            target=run_sampler,
            kwargs=dict(**params, **sampler_params),
            daemon=args.daemon,
        )
        proc.start()
        processes.append(proc)
        sampler_id += 1
        time.sleep(args.run_delay)

    for _ in range(args.vis):
        launch("infer", visualize=True, exploration_power=0.0)
    for _ in range(args.infer):
        launch("infer", visualize=False, exploration_power=0.0)
    for _ in range(args.valid):
        launch("valid", visualize=False, exploration_power=0.0)
    for step in range(1, args.train + 1):
        launch("train", visualize=False, exploration_power=step / args.train)

    for proc in processes:
        proc.join()
Ejemplo n.º 14
0
def main(args, unknown_args):
    """Start RL sampler processes, optionally running one sampler in-process.

    If ``args.check`` is set, one non-visual, non-inference sampler first
    runs in the current process. Then processes are started in this order:

    - ``args.vis`` visualising inference samplers;
    - ``args.infer`` inference samplers;
    - ``args.train`` training samplers whose ``exploration_power`` goes
      ``1/train, ..., 1``.

    Ids are consecutive from 0, and the function sleeps ``STEP_DELAY``
    seconds after each start. It then joins all processes; they are
    terminated at interpreter exit.
    """
    args, config = parse_args_uargs(args, unknown_args)

    if args.expdir is not None:
        # Return value is unused; imported for its side effects.
        module = import_module(expdir=args.expdir)  # noqa: F841

    environment_fn = ENVIRONMENTS.get(
        config["environment"].pop("environment"))
    algorithm_fn = ALGORITHMS.get(config["algorithm"].pop("algorithm"))

    processes = []
    sampler_id = 0

    def on_exit():
        # Terminate every started sampler when the interpreter exits.
        for proc in processes:
            proc.terminate()

    atexit.register(on_exit)

    params = dict(
        seed=args.seed,
        logdir=args.logdir,
        algorithm_fn=algorithm_fn,
        environment_fn=environment_fn,
        config=config,
        resume=args.resume,
        db=args.db,
    )

    if args.check:
        check_params = dict(vis=False, infer=False, id=sampler_id)
        run_sampler(**params, **check_params)

    def launch(vis, infer, exploration_power):
        # Start one sampler process under the next id, then pause.
        nonlocal sampler_id
        sampler_params = dict(
            vis=vis,
            infer=infer,
            id=sampler_id,
            exploration_power=exploration_power,
        )
        proc = mp.Process(target=run_sampler,
                          kwargs=dict(**params, **sampler_params))
        proc.start()
        processes.append(proc)
        sampler_id += 1
        time.sleep(STEP_DELAY)

    for _ in range(args.vis):
        launch(vis=True, infer=True, exploration_power=0.0)
    for _ in range(args.infer):
        launch(vis=False, infer=True, exploration_power=0.0)
    for step in range(1, args.train + 1):
        launch(vis=False, infer=False, exploration_power=step / args.train)

    for proc in processes:
        proc.join()
Ejemplo n.º 15
0
def main(args, unknown_args):
    """Start Redis-backed sampler processes for the configured algorithm.

    The config is saved into ``args.logdir`` (created if needed). The
    algorithm and environment are resolved from the ``Registry`` by name. If
    ``args.debug`` is set, one sampler with fixed noise (0.5) first runs in
    the current process.

    Then processes are started in this order, with consecutive ids from 0:

    - ``args.vis`` visualising samplers with noise disabled;
    - ``args.infer`` inference samplers with noise disabled;
    - ``args.train`` training samplers whose action/parameter noise scales
      linearly up to ``args.max_action_noise`` / ``args.max_param_noise``
      (or ``None`` when that maximum is not set).

    The function joins all processes; they are terminated at interpreter
    exit.
    """
    args, config = parse_args_uargs(args, unknown_args)

    os.makedirs(args.logdir, exist_ok=True)
    save_config(config=config, logdir=args.logdir)
    if args.expdir is not None:
        # Return value is unused; the call is made for its side effects.
        modules = prepare_modules(  # noqa: F841
            expdir=args.expdir, dump_dir=args.logdir)

    algorithm = Registry.get_fn("algorithm", args.algorithm)
    environment = Registry.get_fn("environment", args.environment)

    processes = []
    sampler_id = 0

    def on_exit():
        # Terminate every started sampler when the interpreter exits.
        for p in processes:
            p.terminate()

    atexit.register(on_exit)

    params = dict(logdir=args.logdir,
                  algorithm=algorithm,
                  environment=environment,
                  config=config,
                  resume=args.resume,
                  redis=args.redis)

    def spawn(**sampler_params):
        # Start one sampler process under the next id.
        nonlocal sampler_id
        p = mp.Process(target=run_sampler,
                       kwargs=dict(**params, **sampler_params, id=sampler_id))
        p.start()
        processes.append(p)
        sampler_id += 1

    if args.debug:
        params_ = dict(
            vis=False,
            infer=False,
            action_noise=0.5,
            param_noise=0.5,
            action_noise_prob=args.action_noise_prob,
            param_noise_prob=args.param_noise_prob,
            id=sampler_id,
        )
        run_sampler(**params, **params_)

    for _ in range(args.vis):
        # vis=True so the ``--vis`` samplers actually visualise.
        # NOTE(review): confirm whether these should also run with infer=True.
        spawn(vis=True, infer=False,
              action_noise_prob=0, param_noise_prob=0)

    for _ in range(args.infer):
        spawn(vis=False, infer=True,
              action_noise_prob=0, param_noise_prob=0)

    for i in range(1, args.train + 1):
        # Noise grows linearly with the sampler index; None disables it.
        action_noise = args.max_action_noise * i / args.train \
            if args.max_action_noise is not None \
            else None
        param_noise = args.max_param_noise * i / args.train \
            if args.max_param_noise is not None \
            else None
        spawn(vis=False,
              infer=False,
              action_noise=action_noise,
              param_noise=param_noise,
              action_noise_prob=args.action_noise_prob,
              param_noise_prob=args.param_noise_prob)

    for p in processes:
        p.join()