def _input(ioctx):
    # We are remote worker or we are local worker with num_workers=0:
    # Create a PolicyServerInput.
    if ioctx.worker_index > 0 or ioctx.worker.num_workers == 0:
        return PolicyServerInput(
            ioctx, SERVER_ADDRESS,
            args.port + ioctx.worker_index -
            (1 if ioctx.worker_index > 0 else 0))
    # No InputReader (PolicyServerInput) needed.
    else:
        return None
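# The expression above gives each worker that actually creates a
# PolicyServerInput its own port: with num_workers=0 the local worker (index
# 0) listens on args.port, otherwise the first remote worker (index 1) takes
# args.port and every further remote worker is offset by one. A small
# illustration of the mapping (the base port 9900 is a placeholder for
# args.port):
base_port = 9900

def _port_for(worker_index):
    return base_port + worker_index - (1 if worker_index > 0 else 0)

assert _port_for(0) == 9900  # local worker (only used when num_workers=0)
assert _port_for(1) == 9900  # first remote worker
assert _port_for(2) == 9901  # second remote worker
assert _port_for(3) == 9902  # and so on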
def ray_server(run='PPO', address=ADDRESS, port=PORT):
    print(ray.init(log_to_driver=False))

    connector_config = {
        "input": (lambda ioctx: PolicyServerInput(ioctx, address, port)),
        "num_workers": 0,
        "input_evaluation": [],
        "create_env_on_driver": False,
        "num_gpus": FLAGS.num_gpus,
    }

    if run == "DQN":
        trainer = DQNTrainer(
            env=ExternalAtari, config=dict(connector_config, **CONFIG_DQN))
    elif run == "PPO":
        trainer = PPOTrainer(
            env=ExternalAtari, config=dict(connector_config, **CONFIG_PPO))
    else:
        raise ValueError("--run must be DQN or PPO")

    i = 0
    while i < FLAGS.iter:
        i += 1
        print(pretty_print(trainer.train()))

    checkpoint = trainer.save("{}/ckpts".format(FLAGS.train_url.rstrip('/')))
    print("checkpoint saved at", checkpoint)

    mox.file.copy(
        os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.json"),
        os.path.join(FLAGS.train_url, "config.json"))
    mox.file.copy(
        os.path.join(os.path.abspath(os.path.dirname(__file__)),
                     "customize_service.py"),
        os.path.join(FLAGS.train_url, "customize_service.py"))
    mox.file.copy(
        os.path.join(FLAGS.data_url, "rl_config.py"),
        os.path.join(FLAGS.train_url, "rl_config.py"))

    del trainer
    # Shut down Ray only after the checkpoint has been written and copied;
    # saving a checkpoint requires the trainer's Ray workers to still be alive.
    ray.shutdown()
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="DQN")
parser.add_argument(
    "--framework", type=str, choices=["tf", "torch"], default="tf")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    env = "CartPole-v0"
    connector_config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
        ),
        # Use a single worker process to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
    }

    if args.run == "DQN":
        # Example of using DQN (supports off-policy actions).
        trainer = DQNTrainer(
            env=env,
            config=dict(
                connector_config, **{
                    "exploration_config": {
                        "type": "EpsilonGreedy",
                        "initial_epsilon": 1.0,
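# A PolicyServerInput-based server like the one above only consumes
# experiences; episodes have to be driven by an external client process. A
# minimal client sketch, assuming RLlib's standard PolicyClient counterpart
# and that SERVER_ADDRESS/SERVER_PORT resolve to http://localhost:9900 (the
# address and the CartPole env here are placeholders):
import gym
from ray.rllib.env.policy_client import PolicyClient

client = PolicyClient("http://localhost:9900", inference_mode="remote")
env = gym.make("CartPole-v0")

obs = env.reset()
eid = client.start_episode(training_enabled=True)
done = False
while not done:
    action = client.get_action(eid, obs)
    obs, reward, done, info = env.step(action)
    client.log_returns(eid, reward, info=info)
client.end_episode(eid, obs)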
# # These params start off randomly drawn from a set.
# # "num_sgd_iter": sample_from(
# #     lambda spec: random.choice([5, 10, 15])),
# # "sgd_minibatch_size": sample_from(
# #     lambda spec: random.choice([10000, 25000, 50000])),
# # "train_batch_size": sample_from(
# #     lambda spec: random.choice([50000, 100000, 200000]))
# })

# Start training loop.
if True:
    trainer = CCTrainer(
        env='BlueSkySrv',
        config={
            "callbacks": MyCallbacks,
            "input": (lambda ioctx: PolicyServerInput(
                ioctx, SERVER_ADDRESS, server_port)),
            'model': {
                "custom_model": "Centralized",
                'custom_action_dist': 'CategoricalOrdinalTFP'
                # 'max_seq_len': 20,
                # 'lstm_use_prev_action_reward': True
            },
            "input_evaluation": [],
            "log_level": "DEBUG",
            'num_workers': 0,
            'num_sgd_iter': 10,
            'rollout_fragment_length': 5000,  # 600
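# The commented-out sample_from entries above only take effect when the
# config is handed to Ray Tune, which resolves them per trial; a plain
# Trainer constructor would receive the sample_from objects unresolved. A
# sketch of how those draws would be expressed with ray.tune (the value sets
# are the ones from the comments; passing them to tune.run is an assumption,
# not part of the original script):
import random
from ray import tune

tuned_params = {
    "num_sgd_iter": tune.sample_from(
        lambda spec: random.choice([5, 10, 15])),
    "sgd_minibatch_size": tune.sample_from(
        lambda spec: random.choice([10000, 25000, 50000])),
    "train_batch_size": tune.sample_from(
        lambda spec: random.choice([50000, 100000, 200000])),
}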
    # Create a fake-env for the server. This env will never be used (neither
    # for sampling, nor for evaluation) and its obs/action Spaces do not
    # matter either (multi-agent config below defines Spaces per Policy).
    register_env("fake_unity", lambda c: RandomMultiAgentEnv(c))

    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)

    # The entire config will be sent to connecting clients so they can
    # build their own samplers (and also Policy objects iff
    # `inference_mode=local` on clients' command line).
    config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)
        ),
        # Use a single worker process (w/ SyncSampler) to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],
        # Other settings.
        "train_batch_size": 256,
        "rollout_fragment_length": 20,
        # Multi-agent setup for the particular env.
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
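# The config above is truncated; the remaining server-side steps in this kind
# of setup are to build a trainer on the fake env and train in a plain loop.
# A sketch under that assumption (the PPOTrainer choice, stop condition, and
# checkpoint path are placeholders, not part of the original snippet):
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(env="fake_unity", config=config)

# Serve and train until interrupted.
while True:
    results = trainer.train()
    print("iter", results["training_iteration"],
          "reward", results["episode_reward_mean"])
    trainer.save("unity_ckpts")  # placeholder checkpoint dir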
def init_policy_server(config):
    '''
    Start the policy server that receives (state, action, reward) batches
    and computes gradients for RL.
    '''
    # By default, Ray will parallelize its workload. However, if you need to
    # debug your Ray program, it may be easier to do everything on a single
    # process. You can force all Ray functions to occur on a single process
    # with local_mode=True.
    memory_limit = 1024 * 1024 * 1024 * 10  # 10 GB
    ray.init(local_mode=config["debug_mode"], object_store_memory=memory_limit)

    trainer_config = config["trainer_config"]
    assert torch.cuda.device_count() >= trainer_config["num_gpus"]

    trainer_config["input"] = lambda ioctx: PolicyServerInput(
        ioctx, config["policy_server_host"], int(config["policy_server_port"]))

    custom_model = config.get("custom_model", None)
    if custom_model:
        register_custom_model(custom_model)
        trainer_config["model"]["custom_model"] = custom_model

    agent_obs_space, agent_action_space = \
        make_observation_space_and_action_space(config, True)

    if config["multiagent"]:
        user_obs_space, user_action_space = \
            make_observation_space_and_action_space(config, False)
        trainer_config["multiagent"] = {
            "policies": {
                # The first tuple value is None -> uses default policy.
                "agent": (None, agent_obs_space, agent_action_space, {}),
                "user": (None, user_obs_space, user_action_space, {})
            },
            "policy_mapping_fn": lambda agent_id: agent_id
        }

        from ray.rllib.examples.env.random_env import RandomMultiAgentEnv

        # Create a fake env for the server. This env will never be used
        # (neither for sampling, nor for evaluation) and its obs/action Spaces
        # do not matter either (multi-agent config above defines Spaces per
        # Policy).
        register_env("custom_env", lambda c: RandomMultiAgentEnv(c))
    else:
        def custom_env(env_config):
            '''
            Create an env stub for policy training. Only the action_space and
            observation_space attributes need to be defined, no other env
            functions are needed (e.g. step(), reset(), etc.).
            '''
            env = gym.Env()
            env.action_space = agent_action_space
            env.observation_space = agent_obs_space
            return env

        register_env("custom_env", custom_env)

    trainer = SUPPORTED_TRAINER_CLASSES[config["trainer_class"]](
        env="custom_env", config=trainer_config)

    os.makedirs(config["model_checkpoint_dir"], exist_ok=True)

    # Attempt to restore from checkpoint if possible. "latest_ckpt" is a
    # pointer file holding the path of the newest checkpoint; it is kept in a
    # separate variable so the pointer file is not overwritten by the
    # restored checkpoint path below.
    latest_checkpoint_file = os.path.join(
        config["model_checkpoint_dir"], "latest_ckpt")
    if os.path.exists(latest_checkpoint_file):
        latest_checkpoint = open(latest_checkpoint_file).read()
        print("Restoring from checkpoint", latest_checkpoint)
        trainer.restore(latest_checkpoint)

    # Serving and training loop.
    print("######## Training loop begins... ########")
    count = 0
    while True:
        train_log = trainer.train()
        # print(pretty_print(train_log))
        count += 1
        if count % config["checkpoint_freq"] == 0:
            checkpoint = trainer.save()
            print("#### Writing checkpoint for train iteration", count, "####")
            with open(latest_checkpoint_file, "w") as f:
                f.write(checkpoint)
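# SUPPORTED_TRAINER_CLASSES and register_custom_model are defined elsewhere
# in this codebase. A plausible shape for the trainer-class mapping, assuming
# the usual RLlib trainers; the actual keys and entries may differ:
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.ppo import PPOTrainer

SUPPORTED_TRAINER_CLASSES = {
    "dqn": DQNTrainer,  # assumed key
    "ppo": PPOTrainer,  # assumed key
}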