def test_writing():
    config = {'root_dir': '../data/wiki_test'}
    env = Writing(config)

    class TestAgent(object):
        def __init__(self):
            self.msg_queue = [1]

        def push_messages(self, msg):
            self.msg_queue.append(msg)

        def pull_messages(self):
            return self.msg_queue

    ray.init()
    agent = ray.remote(TestAgent).remote()
    agent_id = env.register_agent(agent)
    test_msg = 'Hi there'
    res = env.send_messages(agent_id, test_msg)
    ray_get_and_free(res)
    assert ray.get(agent.pull_messages.remote()).pop() == test_msg
    obs = env.reset()
    print(obs)
    results = env.step('I want to send a message')
    print(env.write)
    print(results)
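# Every snippet in this file goes through the ray_get_and_free() helper
# rather than calling ray.get() directly. For reference, a minimal sketch of
# such a helper follows: it fetches results with ray.get() and then queues
# the object IDs so they can be freed from the object store in batches.
# This sketch is an assumption about the surrounding codebase; the constants
# and the use of ray.internal.free are illustrative, not taken from the
# snippets themselves.
import time

import ray

FREE_DELAY_S = 10.0          # assumed: minimum delay between free() calls
MAX_FREE_QUEUE_SIZE = 100    # assumed: max object IDs queued before freeing
_last_free_time = 0.0
_to_free = []


def ray_get_and_free(object_ids):
    """Call ray.get and queue the object ids for deletion (sketch)."""
    global _last_free_time
    global _to_free

    result = ray.get(object_ids)
    if type(object_ids) is not list:
        object_ids = [object_ids]
    _to_free.extend(object_ids)

    # Batch calls to ray.internal.free to amortize the overhead.
    now = time.time()
    if (len(_to_free) > MAX_FREE_QUEUE_SIZE
            or now - _last_free_time > FREE_DELAY_S):
        ray.internal.free(_to_free)
        _to_free = []
        _last_free_time = now

    return result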
def _do_live_policy_checkpoint(trainer, training_iteration):
    local_train_policy = trainer.workers.local_worker().policy_map[TRAIN_POLICY]
    checkpoints_dir = os.path.join(experiment_save_dir, "policy_checkpoints")
    checkpoint_name = f"policy_{trainer.claimed_policy_num}_{datetime_str()}_iter_{training_iteration}.dill"
    checkpoint_save_path = os.path.join(checkpoints_dir, checkpoint_name)
    local_train_policy.save_model_weights(
        save_file_path=checkpoint_save_path,
        remove_scope_prefix=TRAIN_POLICY)

    policy_key = os.path.join(base_experiment_name, full_experiment_name,
                              "policy_checkpoints", checkpoint_name)
    storage_client = connect_storage_client()
    upload_file(storage_client=storage_client,
                bucket_name=BUCKET_NAME,
                object_key=policy_key,
                local_source_path=checkpoint_save_path)

    locks_checkpoint_name = f"dch_population_checkpoint_{datetime_str()}"
    ray_get_and_free(
        trainer.live_table_tracker.set_latest_key_for_claimed_policy.remote(
            new_key=policy_key,
            request_locks_checkpoint_with_name=locks_checkpoint_name))
def claim_new_active_policy_after_trainer_init_callback(trainer):
    def set_train_policy_warmup_target_entropy_proportion(worker):
        worker.policy_map[TRAIN_POLICY].set_target_entropy_proportion(
            PIPELINE_WARMUP_ENTROPY_TARGET_PROPORTION)

    trainer.workers.foreach_worker(
        set_train_policy_warmup_target_entropy_proportion)

    trainer.storage_client = connect_storage_client()

    logger.info("Initializing trainer manager interface")
    trainer.manager_interface = LearnerManagerInterface(
        server_host=MANAGER_SERVER_HOST,
        port=MANAGER_PORT,
        worker_id=full_experiment_name,
        storage_client=trainer.storage_client,
        minio_bucket_name=BUCKET_NAME)

    trainer.live_table_tracker = LivePolicyPayoffTracker.remote(
        minio_endpoint=MINIO_ENDPOINT,
        minio_access_key=MINIO_ACCESS_KEY,
        minio_secret_key=MINIO_SECRET_KEY,
        minio_bucket=BUCKET_NAME,
        manager_host=MANAGER_SERVER_HOST,
        manager_port=MANAGER_PORT,
        lock_server_host=LOCK_SERVER_HOST,
        lock_server_port=LOCK_SERVER_PORT,
        worker_id=full_experiment_name,
        policy_class_name=TRAIN_POLICY_CLASS.__name__,
        policy_config_key=TRAIN_POLICY_MODEL_CONFIG_KEY,
        provide_payoff_barrier_sync=not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS)
    trainer.claimed_policy_num = ray_get_and_free(
        trainer.live_table_tracker.get_claimed_policy_num.remote())
    trainer.are_all_lower_policies_finished = False
    trainer.payoff_table_needs_update_started = False
    trainer.payoff_table = None

    _do_live_policy_checkpoint(trainer=trainer, training_iteration=0)

    if not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS:
        # wait for all other learners to also reach this point before continuing
        ray_get_and_free(
            trainer.live_table_tracker.wait_at_barrier_for_other_learners.remote())

    trainer.new_payoff_table_promise = trainer.live_table_tracker.get_live_payoff_table_dill_pickled.remote(
        first_wait_for_n_seconds=2)
    _process_new_live_payoff_table_result_if_ready(
        trainer=trainer, block_until_result_is_ready=True)

    if INIT_FROM_POPULATION:
        init_train_policy_weights_from_static_policy_distribution_after_trainer_init_callback(
            trainer=trainer)
    else:
        print(
            colored(
                f"Policy {trainer.claimed_policy_num}: (Initializing train policy to random)",
                "white"))
def _step(self):
    sample_timesteps, train_timesteps = 0, 0
    weights = None

    with self.timers["sample_processing"]:
        completed = list(self.sample_tasks.completed())
        counts = ray_get_and_free([c[1][1] for c in completed])
        for i, (ev, (sample_batch, count)) in enumerate(completed):
            sample_timesteps += counts[i]

            # Send the data to the replay buffer
            random.choice(
                self.replay_actors).add_batch.remote(sample_batch)

            # Update weights if needed
            self.steps_since_update[ev] += counts[i]
            if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                # Note that it's important to pull new weights once
                # updated to avoid excessive correlation between actors
                if weights is None or self.learner.weights_updated:
                    self.learner.weights_updated = False
                    with self.timers["put_weights"]:
                        weights = ray.put(
                            self.workers.local_worker().get_weights())
                ev.set_weights.remote(weights)
                self.num_weight_syncs += 1
                self.steps_since_update[ev] = 0

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample_with_count.remote())

            # added for dynamic experience replay
            if self.dynamic_experience_replay and random.random() < 0.0001:
                random.choice(self.replay_actors).update_demos.remote()

    with self.timers["replay_processing"]:
        for ra, replay in self.replay_tasks.completed():
            self.replay_tasks.add(ra, ra.replay.remote())
            if self.learner.inqueue.full():
                self.num_samples_dropped += 1
            else:
                with self.timers["get_samples"]:
                    samples = ray_get_and_free(replay)
                # Defensive copy against plasma crashes, see #2610 #3452
                self.learner.inqueue.put((ra, samples and samples.copy()))

    with self.timers["update_priorities"]:
        while not self.learner.outqueue.empty():
            ra, prio_dict, count = self.learner.outqueue.get()
            ra.update_priorities.remote(prio_dict)
            train_timesteps += count

    return sample_timesteps, train_timesteps
def _try_recover(self):
    """Try to identify and blacklist any unhealthy workers.

    This method is called after an unexpected remote error is encountered
    from a worker. It issues check requests to all current workers and
    blacklists any that respond with error. If no healthy workers remain,
    an error is raised.
    """
    if (not self._has_policy_optimizer()
            and not hasattr(self, "execution_plan")):
        raise NotImplementedError(
            "Recovery is not supported for this algorithm")
    if self._has_policy_optimizer():
        workers = self.optimizer.workers
    else:
        assert hasattr(self, "execution_plan")
        workers = self.workers

    logger.info("Health checking all workers...")
    checks = []
    for ev in workers.remote_workers():
        _, obj_id = ev.sample_with_count.remote()
        checks.append(obj_id)

    healthy_workers = []
    for i, obj_id in enumerate(checks):
        w = workers.remote_workers()[i]
        try:
            ray_get_and_free(obj_id)
            healthy_workers.append(w)
            logger.info("Worker {} looks healthy".format(i + 1))
        except RayError:
            logger.exception("Blacklisting worker {}".format(i + 1))
            try:
                w.__ray_terminate__.remote()
            except Exception:
                logger.exception("Error terminating unhealthy worker")

    if len(healthy_workers) < 1:
        raise RuntimeError(
            "Not enough healthy workers remain to continue.")

    if self._has_policy_optimizer():
        self.optimizer.reset(healthy_workers)
    else:
        assert hasattr(self, "execution_plan")
        logger.warning("Recreating execution plan after failure")
        workers.reset(healthy_workers)
        self.train_exec_impl = self.execution_plan(workers, self.config)
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending
        pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray_get_and_free([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.learn_on_batch(samples)
            self.learner_stats = get_learner_stats(fetches)
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return self.learner_stats
def foreach_worker(self, func):
    """Apply the given function to each worker instance."""

    local_result = [func(self.local_worker())]
    remote_results = ray_get_and_free(
        [w.apply.remote(func) for w in self.remote_workers()])
    return local_result + remote_results
def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.workers.remote_workers():
                samples.extend(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                samples.append(self.workers.local_worker().sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                  samples.count)

    fetches = {}
    with self.grad_timer:
        for policy_id, policy in self.policies.items():
            if policy_id not in samples.policy_batches:
                continue

            batch = samples.policy_batches[policy_id]
            for field in self.standardize_fields:
                value = batch[field]
                standardized = (value - value.mean()) / max(
                    1e-4, value.std())
                batch[field] = standardized

            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                for minibatch in self._minibatches(batch):
                    batch_fetches = (
                        self.workers.local_worker().learn_on_batch(
                            MultiAgentBatch({policy_id: minibatch},
                                            minibatch.count)))[policy_id]
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i,
                                            _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.grad_timer.push_units_processed(samples.count)
    if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
        self.learner_stats = fetches[DEFAULT_POLICY_ID]
    else:
        self.learner_stats = fetches
    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return self.learner_stats
def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.workers.remote_workers():
                samples.extend(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                samples.append(self.workers.local_worker().sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    # Unfortunate to have to hack it like this, but not sure how else to do it.
    # Setting the phase to zeros results in policy optimization, and to ones
    # results in aux optimization. These have to be added prior to the policy sgd.
    samples["phase"] = np.zeros(samples.count)

    with self.grad_timer:
        fetches = do_minibatch_sgd(samples, self.policies,
                                   self.workers.local_worker(),
                                   self.num_sgd_iter,
                                   self.sgd_minibatch_size,
                                   self.standardize_fields)
    self.grad_timer.push_units_processed(samples.count)

    if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
        self.learner_stats = fetches[DEFAULT_POLICY_ID]
    else:
        self.learner_stats = fetches

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count

    if self.num_steps_sampled > self.aux_loss_start_after_num_steps:
        # Add samples to the memory to be provided to the aux loss.
        self._remove_unnecessary_data(samples)
        self.memory.append(samples)

        # Optionally run the aux optimization.
        if len(self.memory) >= self.aux_loss_every_k:
            samples = SampleBatch.concat_samples(self.memory)
            self._add_policy_logits(samples)
            # Ones indicate aux phase.
            samples["phase"] = np.ones_like(samples["phase"])
            do_minibatch_sgd(samples, self.policies,
                             self.workers.local_worker(),
                             self.aux_loss_num_sgd_iter,
                             self.sgd_minibatch_size, [])
            self.memory = []

    return self.learner_stats
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batches = ray_get_and_free(
                [e.sample.remote() for e in self.remote_evaluators])
        else:
            batches = [self.local_evaluator.sample()]

        # Handle everything as if multiagent
        tmp = []
        for batch in batches:
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)
            tmp.append(batch)
        batches = tmp

        for batch in batches:
            self.replay_buffer.append(batch)
            self.num_steps_sampled += batch.count
            self.buffer_size += batch.count
            while self.buffer_size > self.max_buffer_size:
                evicted = self.replay_buffer.pop(0)
                self.buffer_size -= evicted.count

    if self.num_steps_sampled >= self.replay_starts:
        return self._optimize()
    else:
        return {}
def collect_samples_straggler_mitigation(agents, train_batch_size):
    """Collects at least train_batch_size samples.

    This is the legacy behavior as of 0.6, and launches extra sample tasks
    to potentially improve performance but can result in many wasted samples.
    """

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < train_batch_size:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent

        next_sample = ray_get_and_free(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    logger.info("Discarding {} sample tasks".format(len(agent_dict)))
    return SampleBatch.concat_samples(trajectories)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batch = SampleBatch.concat_samples(
                ray_get_and_free(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            batch = self.local_evaluator.sample()

        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)

        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    pack_if_needed(row["obs"]),
                    row["actions"],
                    row["rewards"],
                    pack_if_needed(row["new_obs"]),
                    row["dones"],
                    weight=None)

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()

    self.num_steps_sampled += batch.count
def stats(self):
    replay_stats = ray_get_and_free(
        self.replay_actors[0].stats.remote(self.debug))
    timing = {
        "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
        for k in self.timers
    }
    timing["learner_grad_time_ms"] = round(
        1000 * self.learner.grad_timer.mean, 3)
    timing["learner_dequeue_time_ms"] = round(
        1000 * self.learner.queue_timer.mean, 3)
    stats = {
        "sample_throughput": round(self.timers["sample"].mean_throughput, 3),
        "train_throughput": round(self.timers["train"].mean_throughput, 3),
        "num_weight_syncs": self.num_weight_syncs,
        "num_samples_dropped": self.num_samples_dropped,
        "learner_queue": self.learner.learner_queue_size.stats(),
        "replay_shard_0": replay_stats,
    }
    debug_stats = {
        "timing_breakdown": timing,
        "pending_sample_tasks": self.sample_tasks.count,
        "pending_replay_tasks": self.replay_tasks.count,
    }
    if self.debug:
        stats.update(debug_stats)
    if self.learner.stats:
        stats["learner"] = self.learner.stats
    return dict(PolicyOptimizer.stats(self), **stats)
def foreach_evaluator(self, func):
    """Apply the given function to each evaluator instance."""

    local_result = [func(self.local_evaluator)]
    remote_results = ray_get_and_free(
        [ev.apply.remote(func) for ev in self.remote_evaluators])
    return local_result + remote_results
def collect_episodes(local_worker=None,
                     remote_workers=[],
                     timeout_seconds=180):
    """Gathers new episodes metrics tuples from the given evaluators."""

    if remote_workers:
        pending = [
            a.apply.remote(lambda ev: ev.get_metrics())
            for a in remote_workers
        ]
        collected, _ = ray.wait(
            pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
        num_metric_batches_dropped = len(pending) - len(collected)
        if pending and len(collected) == 0:
            raise ValueError(
                "Timed out waiting for metrics from workers. You can "
                "configure this timeout with `collect_metrics_timeout`.")
        metric_lists = ray_get_and_free(collected)
    else:
        metric_lists = []
        num_metric_batches_dropped = 0

    if local_worker:
        metric_lists.append(local_worker.get_metrics())
    episodes = []
    for metrics in metric_lists:
        episodes.extend(metrics)
    return episodes, num_metric_batches_dropped
def collect_episodes(local_worker=None,
                     remote_workers=[],
                     to_be_collected=[],
                     timeout_seconds=180):
    """Gathers new episodes metrics tuples from the given evaluators."""

    if remote_workers:
        pending = [
            a.apply.remote(lambda ev: ev.get_metrics())
            for a in remote_workers
        ] + to_be_collected
        collected, to_be_collected = ray.wait(
            pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
        if pending and len(collected) == 0:
            logger.warning(
                "WARNING: collected no metrics in {} seconds".format(
                    timeout_seconds))
        metric_lists = ray_get_and_free(collected)
    else:
        metric_lists = []

    if local_worker:
        metric_lists.append(local_worker.get_metrics())
    episodes = []
    for metrics in metric_lists:
        episodes.extend(metrics)
    return episodes, to_be_collected
def iter_train_batches(self):
    assert self.initialized, "Must call init() before using this class."
    for agg, batches in self.agg_tasks.completed_prefetch():
        for b in ray_get_and_free(batches):
            self.num_sent_since_broadcast += 1
            yield b
        agg.set_weights.remote(self.broadcasted_weights)
        self.agg_tasks.add(agg, agg.get_train_batches.remote())
        self.num_batches_processed += 1
def checkpoint_and_set_static_policy_distribution_on_train_result_callback(
        params):
    trainer = params['trainer']
    result = params['result']
    result['psro_policy_num'] = trainer.claimed_policy_num

    evo_update(params)

    if not hasattr(trainer, 'next_refresh_steps'):
        trainer.next_refresh_steps = CHECKPOINT_AND_REFRESH_LIVE_TABLE_EVERY_N_STEPS
    if not hasattr(trainer, 'new_payoff_table_promise'):
        trainer.new_payoff_table_promise = None

    if result['timesteps_total'] >= trainer.next_refresh_steps:
        trainer.next_refresh_steps = max(
            trainer.next_refresh_steps +
            CHECKPOINT_AND_REFRESH_LIVE_TABLE_EVERY_N_STEPS,
            result['timesteps_total'] + 1)

        # do checkpoint
        _do_live_policy_checkpoint(
            trainer=trainer,
            training_iteration=result['training_iteration'])

        if not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS:
            # wait for all other learners to also reach this point before continuing
            ray_get_and_free(
                trainer.live_table_tracker.
                wait_at_barrier_for_other_learners.remote())

        if not trainer.are_all_lower_policies_finished:
            trainer.payoff_table_needs_update_started = True

    # refresh payoff table/selection probs
    if trainer.payoff_table_needs_update_started and trainer.new_payoff_table_promise is None:
        trainer.new_payoff_table_promise = trainer.live_table_tracker.get_live_payoff_table_dill_pickled.remote(
            first_wait_for_n_seconds=2)
        trainer.payoff_table_needs_update_started = False

    if trainer.new_payoff_table_promise is not None:
        _process_new_live_payoff_table_result_if_ready(
            trainer=trainer,
            block_until_result_is_ready=not PIPELINE_LIVE_PAYOFF_TABLE_CALC_IS_ASYNCHRONOUS)
def main():
    ray.init()

    config = es.DEFAULT_CONFIG.copy()
    config['env_config'] = {'root_dir': './data/wiki_test'}
    config['num_workers'] = args.num_workers
    config['episodes_per_batch'] = args.num_episodes
    config['train_batch_size'] = args.num_rollouts
    env = Writing
    trainer = es.ESTrainer.remote(config, env)

    # Can optionally call trainer.restore(path) to load a checkpoint.
    for i in range(1000):
        # Perform one iteration of ES training.
        result = ray_get_and_free(trainer.train.remote())
        print(pretty_print(result))

        if i % 100 == 0:
            checkpoint = ray_get_and_free(trainer.save.remote())
            print("checkpoint saved at", checkpoint)
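# main() above reads args.num_workers, args.num_episodes, and
# args.num_rollouts from an argparse namespace that is not shown in these
# snippets. A minimal sketch of an argument parser consistent with those
# attribute names follows; the flag names and defaults are assumptions, not
# part of the original script.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-workers", type=int, default=4,
                    help="Number of rollout workers for the ES trainer.")
parser.add_argument("--num-episodes", type=int, default=100,
                    help="Episodes to collect per training batch.")
parser.add_argument("--num-rollouts", type=int, default=10000,
                    help="Minimum timesteps per training batch.")
args = parser.parse_args()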
def foreach_worker_with_index(self, func):
    """Apply the given function to each worker instance.

    The index will be passed as the second arg to the given function.
    """

    local_result = [func(self.local_worker(), 0)]
    remote_results = ray_get_and_free([
        w.apply.remote(func, i + 1)
        for i, w in enumerate(self.remote_workers())
    ])
    return local_result + remote_results
def next(self):
    """Return the next batch of experiences read.

    Returns:
        SampleBatch or MultiAgentBatch read.
    """
    batches = []
    for dp in self.data_processors:
        batches.append(ray_get_and_free(dp.next.remote()))
    batch = MultiAgentBatch.concat_samples(samples=batches)
    return batch
def foreach_evaluator_with_index(self, func):
    """Apply the given function to each evaluator instance.

    The index will be passed as the second arg to the given function.
    """

    local_result = [func(self.local_evaluator, 0)]
    remote_results = ray_get_and_free([
        ev.apply.remote(func, i + 1)
        for i, ev in enumerate(self.remote_evaluators)
    ])
    return local_result + remote_results
def _try_recover(self):
    """Try to identify and blacklist any unhealthy workers.

    This method is called after an unexpected remote error is encountered
    from a worker. It issues check requests to all current workers and
    blacklists any that respond with error. If no healthy workers remain,
    an error is raised.
    """

    if not self._has_policy_optimizer():
        raise NotImplementedError(
            "Recovery is not supported for this algorithm")

    logger.info("Health checking all workers...")
    checks = []
    for ev in self.optimizer.remote_evaluators:
        _, obj_id = ev.sample_with_count.remote()
        checks.append(obj_id)

    healthy_evaluators = []
    for i, obj_id in enumerate(checks):
        ev = self.optimizer.remote_evaluators[i]
        try:
            ray_get_and_free(obj_id)
            healthy_evaluators.append(ev)
            logger.info("Worker {} looks healthy".format(i + 1))
        except RayError:
            logger.exception("Blacklisting worker {}".format(i + 1))
            try:
                ev.__ray_terminate__.remote()
            except Exception:
                logger.exception("Error terminating unhealthy worker")

    if len(healthy_evaluators) < 1:
        raise RuntimeError(
            "Not enough healthy workers remain to continue.")

    self.optimizer.reset(healthy_evaluators)
def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.workers.remote_workers():
            batches = ray_get_and_free(
                [e.sample.remote() for e in self.workers.remote_workers()])
        else:
            batches = [self.workers.local_worker().sample()]

        # Handle everything as if multiagent
        tmp = []
        for batch in batches:
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)
            tmp.append(batch)
        batches = tmp

        for batch in batches:
            if batch.count > self.max_buffer_size:
                raise ValueError(
                    "The size of a single sample batch exceeds the replay "
                    "buffer size ({} > {})".format(batch.count,
                                                   self.max_buffer_size))
            self.replay_buffer.append(batch)
            self.num_steps_sampled += batch.count
            self.buffer_size += batch.count
            while self.buffer_size > self.max_buffer_size:
                evicted = self.replay_buffer.pop(0)
                self.buffer_size -= evicted.count

    if self.num_steps_sampled >= self.replay_starts and self.num_steps_sampled % self.train_every == 0:
        iter_extra_fetches = defaultdict(list)
        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                batch_fetches = self._sgd_step()
                for k, v in batch_fetches.items():
                    iter_extra_fetches[k].append(v)
            self.grad_timer.push_units_processed(self.train_batch_size *
                                                 self.num_sgd_iter)
        return averaged(iter_extra_fetches)
    else:
        self.grad_timer = TimerStat()
        self.learner_stats = {}
        return {}
def _augment_with_replay(self, sample_futures):
    def can_replay():
        num_needed = int(
            np.ceil(self.train_batch_size / self.sample_batch_size))
        return len(self.replay_batches) > num_needed

    for ev, sample_batch in sample_futures:
        sample_batch = ray_get_and_free(sample_batch)
        yield ev, sample_batch

        if can_replay():
            f = self.replay_proportion
            while random.random() < f:
                f -= 1
                replay_batch = random.choice(self.replay_batches)
                self.num_replayed += replay_batch.count
                yield None, replay_batch
def foreach_policy(self, func):
    """Apply the given function to each worker's (policy, policy_id) tuple.

    Args:
        func (callable): A function - taking a Policy and its ID - that is
            called on all workers' Policies.

    Returns:
        List[any]: The list of return values of func over all workers'
            policies.
    """
    local_results = self.local_worker().foreach_policy(func)
    remote_results = []
    for worker in self.remote_workers():
        res = ray_get_and_free(
            worker.apply.remote(lambda w: w.foreach_policy(func)))
        remote_results.extend(res)
    return local_results + remote_results
def synchronize(local_filters, remotes, update_remote=True):
    """Aggregates all filters from remote evaluators.

    Local copy is updated and then broadcasted to all remote evaluators.

    Args:
        local_filters (dict): Filters to be synchronized.
        remotes (list): Remote evaluators with filters.
        update_remote (bool): Whether to push updates to remote filters.
    """
    remote_filters = ray_get_and_free(
        [r.get_filters.remote(flush_after=True) for r in remotes])
    for rf in remote_filters:
        for k in local_filters:
            local_filters[k].apply_changes(rf[k], with_buffer=False)
    if update_remote:
        copies = {k: v.as_serializable() for k, v in local_filters.items()}
        remote_copy = ray.put(copies)
        [r.sync_filters.remote(remote_copy) for r in remotes]
def foreach_trainable_policy(self, func):
    """Apply `func` to all workers' Policies iff in `policies_to_train`.

    Args:
        func (callable): A function - taking a Policy and its ID - that is
            called on all workers' Policies in `worker.policies_to_train`.

    Returns:
        List[any]: The list of n return values of all
            `func([trainable policy], [ID])`-calls.
    """
    local_results = self.local_worker().foreach_trainable_policy(func)
    remote_results = []
    for worker in self.remote_workers():
        res = ray_get_and_free(
            worker.apply.remote(
                lambda w: w.foreach_trainable_policy(func)))
        remote_results.extend(res)
    return local_results + remote_results
def step(self):
    with self.update_weights_timer:
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.workers.remote_workers():
            batch = SampleBatch.concat_samples(
                ray_get_and_free([
                    e.sample.remote()
                    for e in self.workers.remote_workers()
                ]))
        else:
            batch = self.workers.local_worker().sample()

        # Handle everything as if multiagent
        if isinstance(batch, SampleBatch):
            batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)

        for policy_id, s in batch.policy_batches.items():
            for row in s.rows():
                self.replay_buffers[policy_id].add(
                    pack_if_needed(row["obs"]),
                    row["actions"],
                    row["rewards"],
                    pack_if_needed(row["new_obs"]),
                    row["dones"],
                    None,
                    row["action_logp"],
                    # row["diversity_advantages"],
                    row["diversity_rewards"],
                    # row["diversity_value_targets"],
                    # row["my_logits"],
                    row["prev_actions"],
                    row["prev_rewards"])

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()

    self.num_steps_sampled += batch.count