def test_policy_and_vf_weight_syncing(self):
    """
    Tests weight synchronization with a local agent and a remote worker.
    """
    # First, create a local agent.
    env_spec = dict(type="openai", gym_env="CartPole-v0")
    env = Environment.from_spec(env_spec)
    agent_config = config_from_path("configs/sync_batch_ppo_cartpole.json")
    ray_spec = agent_config["execution_spec"].pop("ray_spec")
    local_agent = Agent.from_spec(agent_config, state_space=env.state_space, action_space=env.action_space)
    ray_spec["worker_spec"]["worker_sample_size"] = 50

    # Create a remote worker with the same agent config.
    worker = RayPolicyWorker.as_remote().remote(agent_config, ray_spec["worker_spec"], self.env_spec, auto_build=True)

    # This imitates the initial executor sync without ray.put.
    weights = RayWeight(local_agent.get_weights())
    print('Weight type in init sync = {}'.format(type(weights)))
    print("Weights = ", weights)
    worker.set_weights.remote(weights)
    print('Init weight sync successful.')

    # Replicate worker syncing steps as done in e.g. the Ape-X executor:
    weights = RayWeight(local_agent.get_weights())
    print('Weight type returned by ray put = {}'.format(type(weights)))
    print(weights)
    ret = worker.set_weights.remote(weights)
    ray.wait([ret])
    print('Object store weight sync successful.')
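# `RayWeight` above serves as a container that makes the dict returned by
# agent.get_weights() cheap to serialize and hand to remote actors. The sketch below is
# an assumption about its shape for illustration only (hypothetical name and fields),
# not RLgraph's actual definition:
class RayWeightSketch(object):
    """Hypothetical, picklable wrapper around an agent's weight dicts."""
    def __init__(self, weights):
        # Assumed layout: policy weights plus optional value-function weights
        # (the test name suggests both are synced for PPO).
        self.policy_weights = weights.get("policy_weights", weights)
        self.value_function_weights = weights.get("value_function_weights", None)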
def test_worker_weight_syncing(self):
    """
    Tests weight synchronization with a local agent and a remote worker.
    """
    # First, create a local agent.
    env_spec = dict(
        type="openai",
        gym_env="PongNoFrameskip-v4",
        # The frameskip in the agent config will trigger worker frame-skips;
        # this setting is for the internal env.
        frameskip=4,
        max_num_noops=30,
        episodic_life=True
    )
    env = Environment.from_spec(env_spec)
    agent_config = config_from_path("configs/ray_apex_for_pong.json")

    # Remove unneeded Ape-X params.
    if "apex_replay_spec" in agent_config:
        agent_config.pop("apex_replay_spec")

    ray_spec = agent_config["execution_spec"].pop("ray_spec")
    local_agent = Agent.from_spec(agent_config, state_space=env.state_space, action_space=env.action_space)
    ray_spec["worker_spec"]["worker_sample_size"] = 50

    # Create a remote worker with the same agent config.
    worker = RayValueWorker.as_remote().remote(agent_config, ray_spec["worker_spec"], env_spec, auto_build=True)

    # This imitates the initial executor sync without ray.put.
    weights = RayWeight(local_agent.get_weights())
    print('Weight type in init sync = {}'.format(type(weights)))
    ret = worker.set_weights.remote(weights)
    ray.wait([ret])
    print('Init weight sync successful.')

    # Replicate worker syncing steps as done in e.g. the Ape-X executor:
    weights = RayWeight(local_agent.get_weights())
    print('Weight type returned by ray put = {}'.format(type(weights)))
    ret = worker.set_weights.remote(weights)
    ray.wait([ret])
    print('Object store weight sync successful.')
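# The two tests above exercise both synchronization paths an executor uses: passing the
# weights object directly as a task argument (initial sync) versus broadcasting a single
# ray.put handle to many workers. A minimal, self-contained sketch of the difference with
# a toy actor; every name here is a hypothetical stand-in, not an RLgraph class.
import ray

@ray.remote
class _ToyWorker(object):
    """Hypothetical worker actor holding a local copy of the weights."""
    def __init__(self):
        self.weights = None

    def set_weights(self, weights):
        self.weights = weights
        return True

def _demo_weight_sync_styles():
    ray.init(ignore_reinit_error=True)
    workers = [_ToyWorker.remote() for _ in range(2)]
    weights = {"layer0": [0.1, 0.2]}

    # Initial sync: pass the object directly; Ray serializes it once per task.
    ray.get([w.set_weights.remote(weights) for w in workers])

    # Steady-state sync: put once into the object store, share one handle.
    weights_id = ray.put(weights)
    ray.get([w.set_weights.remote(weights_id) for w in workers])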
def _execute_step(self):
    """
    Executes a workload on Ray. The main loop performs the following steps
    until the specified number of steps or episodes is finished:

    - Sync weights to policy workers.
    - Schedule a set of samples.
    - Wait until enough sample tasks have completed to form an update batch.
    - Merge samples.
    - Perform local update(s).
    """
    # Env steps done during this rollout.
    env_steps = 0

    # 1. Sync the local learner's weights to remote workers.
    weights = ray.put(RayWeight(self.local_agent.get_weights()))
    for ray_worker in self.ray_env_sample_workers:
        ray_worker.set_weights.remote(weights)

    # 2. Schedule samples and fetch results from RayWorkers.
    sample_batches = []
    num_samples = 0
    while num_samples < self.update_batch_size:
        batches = ray.get([
            worker.execute_and_get_timesteps.remote(self.worker_sample_size)
            for worker in self.ray_env_sample_workers
        ])
        # Each batch has exactly worker_sample_size length.
        num_samples += len(batches) * self.worker_sample_size
        sample_batches.extend(batches)
    env_steps += num_samples

    # 3. Merge samples.
    rewards = []
    for sample in sample_batches:
        if len(sample.metrics["last_rewards"]) > 0:
            rewards.extend(sample.metrics["last_rewards"])
    batch = merge_samples(sample_batches, decompress=self.compress_states)

    # 4. Update from merged batch.
    self.local_agent.update(batch, apply_postprocessing=False)
    return env_steps, 1, {
        "discarded": 0,
        "queue_inserts": 0,
        "rewards": rewards
    }
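# `merge_samples` above combines the per-worker sample batches into one update batch.
# Assuming each batch exposes parallel arrays of transitions keyed by field name, the
# merge reduces to a per-key concatenation. A sketch under that assumption (hypothetical
# helper; it ignores the `decompress` handling of compressed states):
import numpy as np

def merge_samples_sketch(sample_batches):
    """Hypothetical merge: concatenate each field across all sample batches."""
    merged = {}
    for key in sample_batches[0]:
        merged[key] = np.concatenate([batch[key] for batch in sample_batches])
    return merged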
def init_tasks(self):
    # Start learner thread.
    self.update_worker.start()

    # Prioritized replay sampling tasks via RayAgents.
    for ray_memory in self.ray_local_replay_memories:
        for _ in range(self.replay_sampling_task_depth):
            # Initialize remote tasks to sample from the prioritized replay memories of each worker.
            self.prioritized_replay_tasks.add_task(ray_memory, ray_memory.get_batch.remote())

    # Env interaction tasks via RayWorkers, which each have a local agent.
    weights = RayWeight(self.local_agent.get_weights())
    for ray_worker in self.ray_env_sample_workers:
        ray_worker.set_weights.remote(weights)
        self.steps_since_weights_synced[ray_worker] = 0
        self.logger.info("Synced worker {} weights, initializing sample tasks.".format(
            self.worker_ids[ray_worker]))
        for _ in range(self.env_interaction_task_depth):
            self.env_sample_tasks.add_task(ray_worker, ray_worker.execute_and_get_with_count.remote())
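# The add_task/get_completed pairing used above implies a small bookkeeping structure that
# maps in-flight Ray object ids back to the actor that produced them. A minimal sketch of
# such a pool built on ray.wait (hypothetical; not necessarily how RLgraph's task pool is
# implemented):
import ray

class TaskPoolSketch(object):
    """Hypothetical task pool tracking pending object ids per actor."""
    def __init__(self):
        self.pending = {}  # Maps object id -> actor handle that produced it.

    def add_task(self, actor, obj_id):
        # Remember which actor this in-flight result belongs to.
        self.pending[obj_id] = actor

    def get_completed(self):
        # Non-blocking poll: yield (actor, obj_id) for every finished task.
        if not self.pending:
            return
        ready, _ = ray.wait(list(self.pending), num_returns=len(self.pending), timeout=0)
        for obj_id in ready:
            yield self.pending.pop(obj_id), obj_id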
def _execute_step(self):
    """
    Executes a workload on Ray. The main loop performs the following steps
    until the specified number of steps or episodes is finished:

    - Retrieve sample batches via Ray from remote workers.
    - Insert these into the local memory.
    - Have a separate learner thread sample batches from the memory and compute updates.
    - Sync weights to the shared model so remote workers can update their weights.
    """
    # Env steps done during this rollout.
    env_steps = 0
    update_steps = 0
    discarded = 0
    queue_inserts = 0
    weights = None

    # 1. Fetch results from RayWorkers.
    completed_sample_tasks = list(self.env_sample_tasks.get_completed())
    sample_batch_sizes = ray.get([task[1][1] for task in completed_sample_tasks])
    for i, (ray_worker, (env_sample_obj_id, sample_size)) in enumerate(completed_sample_tasks):
        # Randomly add env sample to a local replay actor.
        random.choice(self.ray_local_replay_memories).observe.remote(env_sample_obj_id)
        sample_steps = sample_batch_sizes[i]
        env_steps += sample_steps

        self.steps_since_weights_synced[ray_worker] += sample_steps
        if self.steps_since_weights_synced[ray_worker] >= self.weight_sync_steps:
            # Only put fresh weights into the object store once per finished update.
            if weights is None or self.update_worker.update_done:
                self.update_worker.update_done = False
                weights = ray.put(RayWeight(self.local_agent.get_weights()))
            ray_worker.set_weights.remote(weights)
            self.weight_syncs_executed += 1
            self.steps_since_weights_synced[ray_worker] = 0

        # Reschedule environment samples.
        self.env_sample_tasks.add_task(ray_worker, ray_worker.execute_and_get_with_count.remote())

    # 2. Fetch completed replay priority sampling tasks, move them to the update worker, reschedule.
    for ray_memory, replay_remote_task in self.prioritized_replay_tasks.get_completed():
        # Immediately schedule new batch sampling tasks on these workers.
        self.prioritized_replay_tasks.add_task(ray_memory, ray_memory.get_batch.remote())

        # Retrieve results via id.
        if self.discard_queued_samples and self.update_worker.input_queue.full():
            discarded += 1
        else:
            sampled_batch = ray.get(object_ids=replay_remote_task)
            # Pass to the agent doing the actual updates.
            # The replay memory actor is passed along because we need to update its priorities
            # in the subsequent task (see loop below).
            # Copy due to memory leaks in Ray, see https://github.com/ray-project/ray/pull/3484/
            self.update_worker.input_queue.put((ray_memory, sampled_batch and sampled_batch.copy()))
            queue_inserts += 1

    # 3. Update priorities on priority sampling workers using loss values produced by the update worker.
    while not self.update_worker.output_queue.empty():
        ray_memory, indices, loss_per_item = self.update_worker.output_queue.get()
        ray_memory.update_priorities.remote(indices, loss_per_item)
        # The length of loss_per_item is the update count.
        update_steps += len(indices)

    return env_steps, update_steps, discarded, queue_inserts
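# Steps 2 and 3 above communicate with the update worker purely through its input_queue
# and output_queue. A minimal sketch of that learner-thread pattern (hypothetical class;
# it assumes agent.update(batch) returns per-item losses, which is an assumption here,
# not a guaranteed API):
import threading
import queue

class UpdateWorkerSketch(threading.Thread):
    """Hypothetical learner thread consuming batches and emitting per-item losses."""
    def __init__(self, agent, queue_size=16):
        super(UpdateWorkerSketch, self).__init__()
        self.daemon = True
        self.agent = agent
        self.input_queue = queue.Queue(maxsize=queue_size)
        self.output_queue = queue.Queue()
        self.update_done = False

    def run(self):
        while True:
            # Blocks until the executor inserts (replay memory actor, sampled batch).
            ray_memory, batch = self.input_queue.get()
            loss_per_item = self.agent.update(batch)
            # Hand indices and losses back so the executor can update priorities.
            self.output_queue.put((ray_memory, batch["indices"], loss_per_item))
            self.update_done = True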