Example #1
def _input(ioctx):
    # We are a remote worker, or the local worker with num_workers=0:
    # Create a PolicyServerInput.
    if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
        return PolicyServerInput(
            ioctx, SERVER_ADDRESS, args.port + ioctx.worker_index -
            (1 if ioctx.worker_index > 0 else 0))
    else:
        # No InputReader (PolicyServerInput) needed.
        return None
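The factory above gives each rollout worker its own server: worker n (n >= 1) listens on args.port + n - 1, and with num_workers=0 the single local worker serves on args.port itself. A minimal sketch of wiring it into a server-side trainer config follows; this is an assumption for illustration, not part of the original script, and SERVER_ADDRESS / args.port are taken from the surrounding code.

# Sketch only: config keys as used by the snippets on this page.
config = {
    "env": None,              # no real env; observations come from connected clients
                              # (usually also requires "observation_space"/"action_space")
    "input": _input,          # every rollout worker opens its own PolicyServerInput
    "num_workers": 2,         # worker 1 -> args.port, worker 2 -> args.port + 1
    "input_evaluation": [],   # disable OPE; data comes from online clients
}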
def ray_server(run='PPO', address=ADDRESS, port=PORT):
    print(ray.init(log_to_driver=False))

    connector_config = {
        "input": (lambda ioctx: PolicyServerInput(ioctx, address, port)),
        "num_workers": 0,
        "input_evaluation": [],
        "create_env_on_driver": False,
        "num_gpus": FLAGS.num_gpus,
    }

    if run == "DQN":
        trainer = DQNTrainer(env=ExternalAtari,
                             config=dict(connector_config, **CONFIG_DQN))
    elif run == "PPO":
        trainer = PPOTrainer(env=ExternalAtari,
                             config=dict(connector_config, **CONFIG_PPO))
    else:
        raise ValueError("--run must be DQN or PPO")

    # Train for the requested number of iterations.
    for _ in range(FLAGS.iter):
        print(pretty_print(trainer.train()))

    checkpoint = trainer.save("{}/ckpts".format(FLAGS.train_url.rstrip('/')))
    print("checkpoint saved at", checkpoint)
    # Copy the serving artifacts next to the checkpoints (MoXing file API).
    mox.file.copy(
        os.path.join(os.path.abspath(os.path.dirname(__file__)),
                     "config.json"),
        os.path.join(FLAGS.train_url, "config.json"))
    mox.file.copy(
        os.path.join(os.path.abspath(os.path.dirname(__file__)),
                     "customize_service.py"),
        os.path.join(FLAGS.train_url, "customize_service.py"))
    mox.file.copy(os.path.join(FLAGS.data_url, "rl_config.py"),
                  os.path.join(FLAGS.train_url, "rl_config.py"))

    del trainer
    ray.shutdown()
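All of the examples on this page are the server side; the experiences are produced by external clients that talk to the PolicyServerInput over HTTP. For context, here is a minimal client sketch using RLlib's PolicyClient. The address, port, and env name are placeholders, not values from the snippets above.

from ray.rllib.env.policy_client import PolicyClient
import gym

# Placeholder address/port; must match the server's SERVER_ADDRESS and port.
client = PolicyClient("http://localhost:9900", inference_mode="remote")
env = gym.make("CartPole-v0")  # any env matching the server's policy spaces

obs = env.reset()
episode_id = client.start_episode(training_enabled=True)
done = False
while not done:
    action = client.get_action(episode_id, obs)
    obs, reward, done, info = env.step(action)
    client.log_returns(episode_id, reward, info=info)
client.end_episode(episode_id, obs)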
Example #3
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="DQN")
parser.add_argument("--framework",
                    type=str,
                    choices=["tf", "torch"],
                    default="tf")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    env = "CartPole-v0"
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
    }

    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(env=env,
                             config=dict(
                                 connector_config, **{
                                     "exploration_config": {
                                         "type": "EpsilonGreedy",
                                         "initial_epsilon": 1.0,
Example #4
#             # These params start off randomly drawn from a set.
#             # "num_sgd_iter": sample_from(
#             #     lambda spec: random.choice([5, 10, 15])),
#             # "sgd_minibatch_size": sample_from(
#             #     lambda spec: random.choice([10000, 25000, 50000])),
#             # "train_batch_size": sample_from(
#             #     lambda spec: random.choice([50000, 100000, 200000]))
#         })
# Start training loop.
if True:
    trainer = CCTrainer(
        env='BlueSkySrv',
        config={
            "callbacks": MyCallbacks,
            "input": (lambda ioctx: PolicyServerInput(
                ioctx, SERVER_ADDRESS, server_port)),
            'model': {
                "custom_model": "Centralized",
                'custom_action_dist': 'CategoricalOrdinalTFP'
                # 'max_seq_len': 20,
                # 'lstm_use_prev_action_reward': True
            },
            "input_evaluation": [],
            "log_level": "DEBUG",
            'num_workers': 0,
            'num_sgd_iter': 10,
            'rollout_fragment_length': 5000,  # 600
Example #5
    # Create a fake-env for the server. This env will never be used (neither
    # for sampling, nor for evaluation) and its obs/action Spaces do not
    # matter either (multi-agent config below defines Spaces per Policy).
    register_env("fake_unity", lambda c: RandomMultiAgentEnv(c))

    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)

    # The entire config will be sent to connecting clients so they can
    # build their own samplers (and also Policy objects iff
    # `inference_mode=local` on clients' command line).
    config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)),
        # Use a single worker process (w/ SyncSampler) to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],

        # Other settings.
        "train_batch_size": 256,
        "rollout_fragment_length": 20,
        # Multi-agent setup for the particular env.
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
Example #6
def init_policy_server(config):
    ''' Start the policy server that receives (state, action, reward) batches
        and computes gradients for RL. '''
    # By default, Ray will parallelize its workload. However, if you need to debug your Ray program,
    # it may be easier to do everything on a single process. You can force all Ray functions to occur
    # on a single process with local_mode=True.
    memory_limit = 1024 * 1024 * 1024 * 10  # 10 GB (in bytes) for the object store
    ray.init(local_mode=config["debug_mode"], object_store_memory=memory_limit)

    trainer_config = config["trainer_config"]
    assert torch.cuda.device_count() >= trainer_config["num_gpus"]
    trainer_config["input"] = lambda ioctx: PolicyServerInput(
        ioctx, config["policy_server_host"], int(config["policy_server_port"]))

    custom_model = config.get("custom_model", None)
    if custom_model:
        register_custom_model(custom_model)
        trainer_config["model"]["custom_model"] = custom_model

    agent_obs_space, agent_action_space = \
        make_observation_space_and_action_space(config, True)
    if config["multiagent"]:
        user_obs_space, user_action_space = \
            make_observation_space_and_action_space(config, False)
        trainer_config["multiagent"] = {
            "policies": {
                # the first tuple value is None -> uses default policy
                "agent": (None, agent_obs_space, agent_action_space, {}),
                "user": (None, user_obs_space, user_action_space, {})
            },
            "policy_mapping_fn": lambda agent_id: agent_id
        }
        from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
        # Create a fake env for the server. This env will never be used (neither
        # for sampling, nor for evaluation) and its obs/action Spaces do not
        # matter either (multi-agent config above defines Spaces per Policy).
        register_env("custom_env", lambda c: RandomMultiAgentEnv(c))
    else:

        def custom_env(env_config):
            ''' Create an env stub for policy training. Only the
                action_space and observation_space attributes need to be defined;
                no other env functions (e.g. step(), reset()) are needed. '''
            env = gym.Env()
            env.action_space = agent_action_space
            env.observation_space = agent_obs_space
            return env

        register_env("custom_env", custom_env)

    trainer = SUPPORTED_TRAINER_CLASSES[config["trainer_class"]](
        env="custom_env", config=trainer_config)

    os.makedirs(config["model_checkpoint_dir"], exist_ok=True)
    # Attempt to restore from checkpoint if possible.
    latest_checkpoint = os.path.join(config["model_checkpoint_dir"],
                                     "latest_ckpt")
    if os.path.exists(latest_checkpoint):
        # Read the checkpoint path from the pointer file; keep the pointer
        # file path itself in `latest_checkpoint` so it can be updated below.
        with open(latest_checkpoint) as f:
            restore_path = f.read().strip()
        print("Restoring from checkpoint", restore_path)
        trainer.restore(restore_path)

    # Serving and training loop
    print("######## Training loop begins... ########")
    count = 0
    while True:
        train_log = trainer.train()
        # print(pretty_print(train_log))
        count += 1
        if count % config["checkpoint_freq"] == 0:
            checkpoint = trainer.save()
            print("#### Writing checkpoint for train iteration", count, "####")
            with open(latest_checkpoint, "w") as f:
                f.write(checkpoint)
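init_policy_server() only reads a handful of keys from its config argument. Below is a hypothetical config illustrating those keys; all values are placeholders, and make_observation_space_and_action_space() / SUPPORTED_TRAINER_CLASSES may expect additional entries.

example_config = {
    "debug_mode": False,                  # passed to ray.init(local_mode=...)
    "policy_server_host": "0.0.0.0",
    "policy_server_port": 9900,
    "trainer_class": "PPO",               # must be a key of SUPPORTED_TRAINER_CLASSES
    "trainer_config": {"num_gpus": 0, "model": {}},
    "custom_model": None,                 # optional: name of a registered custom model
    "multiagent": False,
    "model_checkpoint_dir": "/tmp/policy_server_ckpts",
    "checkpoint_freq": 10,                # save a checkpoint every N train() iterations
}
# init_policy_server(example_config)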