Example #1
0
def run(args, parser):
    """Build experiment specs from a YAML file or CLI flags, then launch them.

    Args:
        args: Namespace produced by ``parser``. When ``args.config_file`` is
            set, experiments are read from that YAML file. Otherwise a single
            experiment named ``args.experiment_name`` is built from the flags.
        parser: The argument parser. ``parser.error`` is used to report
            missing ``--run`` / ``--env`` or a malformed config file.
            argparse's ``error`` exits the process, so ``ray.init`` is never
            reached on invalid input.
    """
    if args.config_file:
        with open(args.config_file) as f:
            # The file holds plain experiment specs, so safe_load is
            # sufficient. Bare yaml.load(f) without a Loader can construct
            # arbitrary Python objects and is rejected by PyYAML >= 6.
            experiments = yaml.safe_load(f)
        # An empty file loads as None and a list has no .values(); report it
        # as a usage error instead of crashing in the validation loop below.
        if not isinstance(experiments, dict):
            parser.error(
                "{}: expected a mapping of experiment names to specs".format(
                    args.config_file))
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                # Only convert when the flag was given; otherwise pass the
                # falsy value through unchanged.
                "trial_resources": (
                    args.trial_resources and
                    resources_to_json(args.trial_resources)),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "repeat": args.repeat,
                "upload_dir": args.upload_dir,
            }
        }

    for exp in experiments.values():
        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        # "env" may sit at the top level of a spec or inside its "config".
        # A YAML `config:` with no value loads as None, so treat it as empty.
        if not exp.get("env") and not (exp.get("config") or {}).get("env"):
            parser.error("the following arguments are required: --env")

    ray.init(redis_address=args.redis_address,
             num_cpus=args.ray_num_cpus,
             num_gpus=args.ray_num_gpus)
    run_experiments(experiments,
                    scheduler=_make_scheduler(args),
                    queue_trials=args.queue_trials)
Example #2
0
File: train.py  Project: wym42/ray
def run(args, parser):
    """Build experiment specs, start Ray (optionally a local Cluster), and run.

    Args:
        args: Namespace produced by ``parser``. When ``args.config_file`` is
            set, experiments are read from that YAML file. Otherwise a single
            experiment named ``args.experiment_name`` is built from the flags.
            ``args.ray_num_local_schedulers`` selects between starting a
            ``Cluster`` of that many nodes and a plain ``ray.init``.
        parser: The argument parser. ``parser.error`` is used to report
            missing ``--run`` / ``--env`` or a malformed config file.
            argparse's ``error`` exits the process, so Ray is never started
            on invalid input.
    """
    if args.config_file:
        with open(args.config_file) as f:
            # The file holds plain experiment specs, so safe_load is
            # sufficient. Bare yaml.load(f) without a Loader can construct
            # arbitrary Python objects and is rejected by PyYAML >= 6.
            experiments = yaml.safe_load(f)
        # An empty file loads as None and a list has no .values(); report it
        # as a usage error instead of crashing in the validation loop below.
        if not isinstance(experiments, dict):
            parser.error(
                "{}: expected a mapping of experiment names to specs".format(
                    args.config_file))
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                # Only convert when the flag was given; otherwise pass the
                # falsy value through unchanged.
                "resources_per_trial": (
                    args.resources_per_trial and
                    resources_to_json(args.resources_per_trial)),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "num_samples": args.num_samples,
                "upload_dir": args.upload_dir,
            }
        }

    for exp in experiments.values():
        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        # "env" may sit at the top level of a spec or inside its "config".
        # A YAML `config:` with no value loads as None, so treat it as empty.
        if not exp.get("env") and not (exp.get("config") or {}).get("env"):
            parser.error("the following arguments are required: --env")

    if args.ray_num_local_schedulers:
        # Start a Cluster with one node per requested local scheduler, each
        # given the same resources; unset counts fall back to 1 CPU and
        # 0 GPUs per node. Then connect this driver to that cluster.
        cluster = Cluster()
        for _ in range(args.ray_num_local_schedulers):
            cluster.add_node(
                resources={
                    "num_cpus": args.ray_num_cpus or 1,
                    "num_gpus": args.ray_num_gpus or 0,
                },
                object_store_memory=args.ray_object_store_memory,
                redis_max_memory=args.ray_redis_max_memory)
        ray.init(redis_address=cluster.redis_address)
    else:
        ray.init(redis_address=args.redis_address,
                 object_store_memory=args.ray_object_store_memory,
                 redis_max_memory=args.ray_redis_max_memory,
                 num_cpus=args.ray_num_cpus,
                 num_gpus=args.ray_num_gpus)
    run_experiments(experiments,
                    scheduler=_make_scheduler(args),
                    queue_trials=args.queue_trials)
Example #3
0
    help="If specified, use config options from this file. Note that this "
    "overrides any trial-specific options set via flags above.")

if __name__ == "__main__":
    # Script entry: build experiment specs from a YAML file or from the
    # CLI flags, and validate them before launching.
    args = parser.parse_args(sys.argv[1:])
    if args.config_file:
        with open(args.config_file) as f:
            # The file holds plain experiment specs, so safe_load is
            # sufficient. Bare yaml.load(f) without a Loader can construct
            # arbitrary Python objects and is rejected by PyYAML >= 6.
            experiments = yaml.safe_load(f)
        # An empty file loads as None and a list has no .values(); report it
        # as a usage error instead of crashing in the validation loop below.
        if not isinstance(experiments, dict):
            parser.error(
                "{}: expected a mapping of experiment names to specs".format(
                    args.config_file))
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to /tmp/ray/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                "resources": resources_to_json(args.resources),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "repeat": args.repeat,
                "upload_dir": args.upload_dir,
            }
        }

    # Every spec needs "run" and an "env" (top level or inside "config").
    # A YAML `config:` with no value loads as None, so treat it as empty.
    for exp in experiments.values():
        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not (exp.get("config") or {}).get("env"):
            parser.error("the following arguments are required: --env")

    run_experiments(experiments,
Example #4
0
    "--num-cpus", default=2, type=int,
    help="Number of CPUs to allocate to Ray.")
# Remaining command-line flags: GPU count, results subdirectory, and env.
parser.add_argument(
    "--num-gpus",
    type=int,
    default=1,
    help="Number of GPUs to allocate to Ray.",
)
parser.add_argument(
    "--experiment-name",
    type=str,
    default="default",
    help="Name of the subdirectory under `local_dir` to put results in.",
)
parser.add_argument(
    "--env",
    type=str,
    default=None,
    help="The gym environment to use.",
)


# Register the StarCraft preprocessor, env factory, and A3C trainable under
# string names so experiment specs can refer to them by name.
ModelCatalog.register_custom_preprocessor("sc_prep", StarCraftPreprocessor)
register_env("sc2", lambda env_config: StarCraft(env_config))
register_trainable("SC_A3C", A3CAgent)


if __name__ == "__main__":
    # Launch the StarCraft II A3C experiment using the parsed flags.
    args = parser.parse_args(sys.argv[1:])

    experiments = {
        # Key by --experiment-name so the flag actually selects the results
        # subdirectory (see its help text).
        args.experiment_name: {
            "run": "SC_A3C",
            "env": "sc2",
            # Only convert when the flag was given; otherwise pass the falsy
            # value through unchanged.
            "trial_resources": (
                args.trial_resources and
                resources_to_json(args.trial_resources)),
            # NOTE(review): --env defaults to None, so this writes env=None
            # into "config" alongside the top-level "env" above; confirm
            # which one Tune honors.
            "config": dict(args.config, env=args.env),
        }
    }
    # Use the parsed --num-gpus (default 1) rather than a hard-coded count.
    ray.init(redis_address=args.redis_address,
             num_gpus=args.num_gpus,
             num_cpus=args.num_cpus)
    run_experiments(experiments)
    
Example #5
0
File: train.py  Project: adgirish/ray
    "overrides any trial-specific options set via flags above.")


if __name__ == "__main__":
    # Script entry: build experiment specs from a YAML file or from the
    # CLI flags, and validate them before starting Ray.
    args = parser.parse_args(sys.argv[1:])
    if args.config_file:
        with open(args.config_file) as f:
            # The file holds plain experiment specs, so safe_load is
            # sufficient. Bare yaml.load(f) without a Loader can construct
            # arbitrary Python objects and is rejected by PyYAML >= 6.
            experiments = yaml.safe_load(f)
        # An empty file loads as None and a list has no .values(); report it
        # as a usage error instead of crashing in the validation loop below.
        if not isinstance(experiments, dict):
            parser.error(
                "{}: expected a mapping of experiment names to specs".format(
                    args.config_file))
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                "resources": resources_to_json(args.resources),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "repeat": args.repeat,
                "upload_dir": args.upload_dir,
            }
        }

    # Every spec needs "run" and an "env" (top level or inside "config").
    # A YAML `config:` with no value loads as None, so treat it as empty.
    for exp in experiments.values():
        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not (exp.get("config") or {}).get("env"):
            parser.error("the following arguments are required: --env")

    ray.init(
Example #6
0
File: train.py  Project: zxsimple/ray
if __name__ == "__main__":
    # Script entry: build experiment specs from a YAML file or from the
    # CLI flags, and validate them before starting Ray.
    args = parser.parse_args(sys.argv[1:])
    if args.config_file:
        with open(args.config_file) as f:
            # The file holds plain experiment specs, so safe_load is
            # sufficient. Bare yaml.load(f) without a Loader can construct
            # arbitrary Python objects and is rejected by PyYAML >= 6.
            experiments = yaml.safe_load(f)
        # An empty file loads as None and a list has no .values(); report it
        # as a usage error instead of crashing in the validation loop below.
        if not isinstance(experiments, dict):
            parser.error(
                "{}: expected a mapping of experiment names to specs".format(
                    args.config_file))
    else:
        # Note: keep this in sync with tune/config_parser.py
        experiments = {
            args.experiment_name: {  # i.e. log to ~/ray_results/default
                "run": args.run,
                "checkpoint_freq": args.checkpoint_freq,
                "local_dir": args.local_dir,
                # Only convert when the flag was given; otherwise pass the
                # falsy value through unchanged.
                "trial_resources": (
                    args.trial_resources and
                    resources_to_json(args.trial_resources)),
                "stop": args.stop,
                "config": dict(args.config, env=args.env),
                "restore": args.restore,
                "repeat": args.repeat,
                "upload_dir": args.upload_dir,
            }
        }

    # Every spec needs "run" and an "env" (top level or inside "config").
    # A YAML `config:` with no value loads as None, so treat it as empty.
    for exp in experiments.values():
        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not (exp.get("config") or {}).get("env"):
            parser.error("the following arguments are required: --env")

    ray.init(redis_address=args.redis_address,