def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    if self.done:
        # Warmup phase done, simply return batch.
        return [batch]
    metrics = _get_shared_metrics()
    timesteps_total = metrics.counters[STEPS_SAMPLED_COUNTER]
    self.buffer.append(batch)
    self.count += batch.count
    assert self.count == timesteps_total
    if timesteps_total < self.learning_starts:
        # Return empty if still in warmup.
        return []

    # Warmup just done.
    if self.count > self.learning_starts * 2:
        logger.info(  # pylint:disable=logging-fstring-interpolation
            "Collected more training samples than expected "
            f"(actual={self.count}, expected={self.learning_starts}). "
            "This may be because you have many workers or "
            "long episodes in 'complete_episodes' batch mode.")
    out = SampleBatch.concat_samples(self.buffer)
    self.buffer = []
    self.count = 0
    self.done = True
    return [out]
def record_steps_trained(item):
    count, fetches = item
    metrics = _get_shared_metrics()
    # Manually update the steps trained counter since the learner thread
    # is executing outside the pipeline.
    metrics.counters[STEPS_TRAINED_COUNTER] += count
    return item
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    self.buffer.append(batch)
    if self.count_steps_by == "env_steps":
        self.count += batch.count
    else:
        assert isinstance(batch, MultiAgentBatch), \
            "`count_steps_by=agent_steps` only allowed in multi-agent " \
            "environments!"
        self.count += batch.agent_steps()

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info("Collected more training samples than expected "
                        "(actual={}, expected={}). ".format(
                            self.count, self.min_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []
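# A minimal usage sketch for a batch-concatenating op like the one above,
# using the ray.rllib.execution API of Ray 1.x. `workers` (a WorkerSet) and
# `config` (a trainer config dict) are assumed to exist; ConcatBatches and
# TrainOneStep are the library ops this __call__ resembles.
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep

rollouts = ParallelRollouts(workers, mode="bulk_sync")
# `.combine(...)` keeps pulling until the op returns a non-empty list, which
# is why the op above returns [] while still below min_batch_size.
train_op = rollouts \
    .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
    .for_each(TrainOneStep(workers))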
def __call__(self, data_tuple):
    # Metaupdate step.
    samples = data_tuple[0]
    adapt_metrics_dict = data_tuple[1]
    for i in range(self.maml_optimizer_steps):
        fetches = self.workers.local_worker().learn_on_batch(samples)
    fetches = get_learner_stats(fetches)

    # Sync workers with meta policy.
    self.workers.sync_weights()

    # Set worker tasks.
    set_worker_tasks(self.workers)

    # Update KLs.
    def update(pi, pi_id):
        assert "inner_kl" not in fetches, (
            "inner_kl should be nested under policy id key", fetches)
        if pi_id in fetches:
            assert "inner_kl" in fetches[pi_id], (fetches, pi_id)
            pi.update_kls(fetches[pi_id]["inner_kl"])
        else:
            logger.warning("No data for {}, not updating kl".format(pi_id))

    self.workers.local_worker().foreach_trainable_policy(update)

    # Modify reporting metrics.
    metrics = _get_shared_metrics()
    metrics.info[LEARNER_INFO] = fetches
    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

    res = self.metric_gen.__call__(None)
    res.update(adapt_metrics_dict)
    return res
def __call__(self,
             batch: SampleBatchType) -> Tuple[SampleBatchType, List[dict]]:
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            w = self.workers.local_worker()
            info = do_minibatch_sgd(
                batch, {p: w.get_policy(p) for p in self.policies}, w,
                self.num_sgd_iter, self.sgd_minibatch_size, [])
            # TODO(ekl) shouldn't be returning learner stats directly here
            metrics.info[LEARNER_INFO] = info
        else:
            info = self.workers.local_worker().learn_on_batch(batch)
            metrics.info[LEARNER_INFO] = get_learner_stats(info)
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return batch, info
def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]:
    _check_sample_batch_type(samples)
    metrics = _get_shared_metrics()
    with metrics.timers[COMPUTE_GRADS_TIMER]:
        grad, info = self.workers.local_worker().compute_gradients(samples)
    metrics.info[LEARNER_INFO] = get_learner_stats(info)
    return grad, samples.count
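# Sketch of the synchronous-gradients pattern this op belongs to: compute
# gradients on the local worker, then apply them. Names are from
# ray.rllib.execution.train_ops in Ray 1.x; `rollouts` and `workers` are
# assumed from a surrounding execution plan.
from ray.rllib.execution.train_ops import ComputeGradients, ApplyGradients

train_op = rollouts \
    .for_each(ComputeGradients(workers)) \
    .for_each(ApplyGradients(workers))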
def __call__(self, batch: MultiAgentBatch) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    batch_count = batch.policy_batches[self.policy_id_to_count_for].count
    if self.drop_samples_for_other_agents:
        batch = MultiAgentBatch(
            policy_batches={
                self.policy_id_to_count_for: batch.policy_batches[
                    self.policy_id_to_count_for]
            },
            env_steps=batch.policy_batches[
                self.policy_id_to_count_for].count)

    self.buffer.append(batch)
    self.count += batch_count

    if self.count >= self.min_batch_size:
        if self.count > self.min_batch_size * 2:
            logger.info("Collected more training samples than expected "
                        "(actual={}, expected={}). ".format(
                            self.count, self.min_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []
def add_ppo_metrics(batch):
    print("PPO policy learning on samples from",
          batch.policy_batches.keys(), "env steps", batch.count,
          "agent steps", batch.total())
    metrics = _get_shared_metrics()
    metrics.counters["agent_steps_trained_PPO"] += batch.total()
    return batch
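# Sketch: debug hooks like add_ppo_metrics are plain callables inserted with
# `.for_each(...)` ahead of the train step. `ppo_batches` (a
# LocalIterator[MultiAgentBatch] carrying PPO-only samples) and the
# "ppo_policy" id are assumptions from a two-trainer style workflow.
from ray.rllib.execution.train_ops import TrainOneStep

ppo_train_op = ppo_batches \
    .for_each(add_ppo_metrics) \
    .for_each(TrainOneStep(workers, policies=["ppo_policy"]))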
def __call__(self, item: Tuple[ModelGradients, int]) -> None:
    if not isinstance(item, tuple) or len(item) != 2:
        raise ValueError(
            "Input must be a tuple of (grad_dict, count), got {}".format(
                item))
    gradients, count = item
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_TRAINED_COUNTER] += count

    apply_timer = metrics.timers[APPLY_GRADS_TIMER]
    with apply_timer:
        self.workers.local_worker().apply_gradients(gradients)
        apply_timer.push_units_processed(count)

    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())

    if self.update_all:
        if self.workers.remote_workers():
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = ray.put(self.workers.local_worker().get_weights(
                    self.policies))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights, _get_global_vars())
    else:
        if metrics.current_actor is None:
            raise ValueError(
                "Could not find actor to update. When "
                "update_all=False, `current_actor` must be set "
                "in the iterator context.")
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = self.workers.local_worker().get_weights(
                self.policies)
            metrics.current_actor.set_weights.remote(
                weights, _get_global_vars())
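# Sketch: the update_all=False branch above is the asynchronous (A3C-style)
# path, where the iterator context tracks which remote actor produced the
# gradient. AsyncGradients is the matching source op in
# ray.rllib.execution.rollout_ops (Ray 1.x); `workers` is assumed.
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients

async_train_op = AsyncGradients(workers).for_each(
    ApplyGradients(workers, update_all=False))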
def __call__(self,
             batch: SampleBatchType) -> Tuple[SampleBatchType, List[dict]]:
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            lw = self.workers.local_worker()
            info = do_minibatch_sgd(
                batch, {pid: lw.get_policy(pid) for pid in self.policies},
                lw, self.num_sgd_iter, self.sgd_minibatch_size, [])
            # TODO(ekl) shouldn't be returning learner stats directly here
            # TODO(sven): Skips `custom_metrics` key from on_learn_on_batch
            #  callback (shouldn't).
            metrics.info[LEARNER_INFO] = info
        else:
            info = self.workers.local_worker().learn_on_batch(batch)
            metrics.info[LEARNER_INFO] = extract_stats(
                info, LEARNER_STATS_KEY)
            metrics.info["custom_metrics"] = extract_stats(
                info, "custom_metrics")
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count

    # Update weights - after learning on the local worker - on all remote
    # workers.
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights(
                self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return batch, info
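# Sketch: the minibatch-SGD branch above is selected purely by constructor
# arguments. With PPO-style config keys (assumed names: train_batch_size,
# num_sgd_iter, sgd_minibatch_size) it would be wired as follows
# (`rollouts`, `workers`, and `config` assumed; ops from
# ray.rllib.execution, Ray 1.x).
from ray.rllib.execution.rollout_ops import ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep

train_op = rollouts \
    .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
    .for_each(TrainOneStep(
        workers,
        num_sgd_iter=config["num_sgd_iter"],
        sgd_minibatch_size=config["sgd_minibatch_size"]))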
def __call__(self, samples):
    # Dreamer training loop.
    for n in range(self.dreamer_train_iters):
        print(f"sub-iteration={n}/{self.dreamer_train_iters}")
        batch = self.episode_buffer.sample(self.batch_size)
        # if n == self.dreamer_train_iters - 1:
        #     batch["log_gif"] = True
        fetches = self.worker.learn_on_batch(batch)

    # Custom logging.
    policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
    if "log_gif" in policy_fetches:
        gif = policy_fetches["log_gif"]
        policy_fetches["log_gif"] = self.postprocess_gif(gif)

    # Metrics calculation.
    metrics = _get_shared_metrics()
    metrics.info[LEARNER_INFO] = fetches
    metrics.counters[STEPS_SAMPLED_COUNTER] = self.episode_buffer.timesteps
    metrics.counters[STEPS_SAMPLED_COUNTER] *= self.repeat
    res = collect_metrics(local_worker=self.worker)
    res["info"] = metrics.info
    res["info"].update(metrics.counters)
    res["timesteps_total"] = metrics.counters[STEPS_SAMPLED_COUNTER]

    self.episode_buffer.add(samples)
    return res
def __call__(self, item):
    actor, (info, samples, training_steps) = item
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_TRAINED_COUNTER] += training_steps
    metrics.counters[STEPS_SAMPLED_COUNTER] += samples
    self.counters[actor] += 1
    metrics.counters[
        f"WorkerIteration/Worker{self.worker_idx[actor]}"] += 1

    global_vars = _get_global_vars()
    self.workers.local_worker().set_global_vars(global_vars)
    actor.set_global_vars.remote(global_vars)

    if self.counters[actor] % self.broadcast_interval == 0:
        metrics.counters["num_weight_broadcasts"] += 1
        with metrics.timers[WORKER_UPDATE_TIMER]:
            for pid, gw in self.global_weights.items():

                def update_worker(w, alpha):
                    return w.policy_map[pid].easgd_update(gw, alpha)

                diff = ray.get(
                    actor.apply.remote(update_worker, self.alpha))
                self.global_weights[pid] = EASGDUpdate.easgd_add(
                    gw, diff, self.alpha)
    return info
def add_dqn_metrics(batch):
    print("DQN policy learning on samples from",
          batch.policy_batches.keys(), "env steps", batch.env_steps(),
          "agent steps", batch.agent_steps())
    metrics = _get_shared_metrics()
    metrics.counters["agent_steps_trained_DQN"] += batch.agent_steps()
    return batch
def record_steps_trained(item):
    env_steps, agent_steps, fetches = item
    metrics = _get_shared_metrics()
    # Manually update the steps trained counter since the learner
    # thread is executing outside the pipeline.
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_steps
    metrics.counters[STEPS_TRAINED_COUNTER] += env_steps
    return item
def report_timesteps(batch):
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
    if isinstance(batch, MultiAgentBatch):
        metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.agent_steps()
    else:
        metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
    return batch
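# Sketch: step-counting callables like report_timesteps attach directly to
# the rollout stream, so every sampled batch bumps both counters before
# flowing on to training (`rollouts` is assumed to be a
# LocalIterator[SampleBatchType] from ParallelRollouts).
rollouts = rollouts.for_each(report_timesteps)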
def __call__(self, item):
    (grads, info), count = item
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_SAMPLED_COUNTER] += count
    metrics.info[LEARNER_INFO] = get_learner_stats(info)
    metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                         self.fetch_start_time)
    return grads, count
def __call__(self, _):
    metrics = _get_shared_metrics()
    cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
    if cur_ts - self.prev_update > self.do_every:
        self.prev_update = cur_ts
        filepath = os.path.join(self.f_path, "{}_steps".format(cur_ts))
        if not os.path.exists(filepath):
            os.makedirs(filepath)
        _dump_buffer_content(self.buffer, filepath)
def __call__(self, item: Any) -> bool:
    if self.delay_steps <= 0:
        return True
    metrics = _get_shared_metrics()
    now = metrics.counters[STEPS_SAMPLED_COUNTER]
    if now - self.last_called >= self.delay_steps:
        self.last_called = now
        return True
    return False
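# Sketch: a boolean gate like the one above is intended for
# LocalIterator.filter(...), which drops items while the gate returns False.
# The call matches OncePerTimestepsElapsed from
# ray.rllib.execution.metric_ops (Ray 1.x); `train_op` is assumed.
from ray.rllib.execution.metric_ops import OncePerTimestepsElapsed

gated_op = train_op.filter(OncePerTimestepsElapsed(1000))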
def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]:
    _check_sample_batch_type(samples)
    metrics = _get_shared_metrics()
    with metrics.timers[COMPUTE_GRADS_TIMER]:
        grad, info = self.workers.local_worker().compute_gradients(
            samples, single_agent=True)
    # RolloutWorker.compute_gradients returned single-agent stats.
    metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: info}
    return grad, samples.count
def update_prio_and_stats(item: Tuple["ActorHandle", dict, int]) -> None:
    actor, prio_dict, count = item
    actor.update_priorities.remote(prio_dict)
    metrics = _get_shared_metrics()
    # Manually update the steps trained counter since the learner thread
    # is executing outside the pipeline.
    metrics.counters[STEPS_TRAINED_COUNTER] += count
    metrics.timers["learner_dequeue"] = learner_thread.queue_timer
    metrics.timers["learner_grad"] = learner_thread.grad_timer
    metrics.timers["learner_overall"] = learner_thread.overall_timer
def __call__(self, _: Any) -> None:
    metrics = _get_shared_metrics()
    cur_ts = metrics.counters[self.metric]
    last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
    if cur_ts - last_update > self.target_update_freq:
        to_update = self.policies
        self.workers.local_worker().foreach_trainable_policy(
            lambda p, p_id: p_id in to_update and p.update_target())
        metrics.counters[NUM_TARGET_UPDATES] += 1
        metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
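# Sketch: in a DQN-style plan, a target-update op like the one above runs
# after the train step (UpdateTargetNetwork from
# ray.rllib.execution.train_ops, Ray 1.x; `train_op`, `workers`, and the
# config key are assumed).
from ray.rllib.execution.train_ops import UpdateTargetNetwork

train_op = train_op.for_each(
    UpdateTargetNetwork(workers, config["target_network_update_freq"]))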
def __call__(self, item):
    (grads, info), count = item
    metrics = _get_shared_metrics()
    metrics.counters[STEPS_SAMPLED_COUNTER] += count
    metrics.info[LEARNER_INFO] = ({
        DEFAULT_POLICY_ID: info
    } if LEARNER_STATS_KEY in info else info)
    metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                         self.fetch_start_time)
    return grads, count
def __call__(self, _: Any) -> Dict:
    # Collect worker metrics.
    episodes, self.to_be_collected = collect_episodes(
        self.workers.local_worker(),
        self.selected_workers or self.workers.remote_workers(),
        self.to_be_collected,
        timeout_seconds=self.timeout_seconds,
    )
    orig_episodes = list(episodes)
    missing = self.min_history - len(episodes)
    if missing > 0:
        episodes = self.episode_history[-missing:] + episodes
        assert len(episodes) <= self.min_history
    self.episode_history.extend(orig_episodes)
    self.episode_history = self.episode_history[-self.min_history:]
    res = summarize_episodes(episodes, orig_episodes,
                             self.keep_custom_metrics)

    # Add in iterator metrics.
    metrics = _get_shared_metrics()
    custom_metrics_from_info = metrics.info.pop("custom_metrics", {})
    timers = {}
    counters = {}
    info = {}
    info.update(metrics.info)
    for k, counter in metrics.counters.items():
        counters[k] = counter
    for k, timer in metrics.timers.items():
        timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
        if timer.has_units_processed():
            timers["{}_throughput".format(k)] = round(
                timer.mean_throughput, 3)
    res.update({
        "num_healthy_workers": len(self.workers.remote_workers()),
        "timesteps_total": (
            metrics.counters[STEPS_TRAINED_COUNTER]
            if self.by_steps_trained
            else metrics.counters[STEPS_SAMPLED_COUNTER]),
        # tune.Trainable uses timesteps_this_iter for tracking
        # total timesteps.
        "timesteps_this_iter": metrics.counters[
            STEPS_TRAINED_THIS_ITER_COUNTER],
        "agent_timesteps_total": metrics.counters.get(
            AGENT_STEPS_SAMPLED_COUNTER, 0),
    })
    res["timers"] = timers
    res["info"] = info
    res["info"].update(counters)
    res["custom_metrics"] = res.get("custom_metrics", {})
    res["episode_media"] = res.get("episode_media", {})
    res["custom_metrics"].update(custom_metrics_from_info)
    return res
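# Sketch: a metrics collector like the one above is normally not called
# directly; StandardMetricsReporting wraps the train op with it plus the
# report-interval filters (ray.rllib.execution.metric_ops, Ray 1.x;
# `train_op`, `workers`, and `config` are assumed).
from ray.rllib.execution.metric_ops import StandardMetricsReporting

result_op = StandardMetricsReporting(train_op, workers, config)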
def __call__(self, items):
    for item in items:
        info, count = item
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += count
        metrics.counters[STEPS_TRAINED_COUNTER] += count
        metrics.info[LEARNER_INFO] = info
    # Since SGD happens remotely, the time delay between fetch and
    # completion is approximately the SGD step time.
    metrics.timers[LEARN_ON_BATCH_TIMER].push(time.perf_counter() -
                                              self.fetch_start_time)
def __call__(self, samples: SampleBatchType) -> Tuple[ModelGradients, int]:
    _check_sample_batch_type(samples)
    metrics = _get_shared_metrics()
    with metrics.timers[COMPUTE_GRADS_TIMER]:
        grad, info = self.workers.local_worker().compute_gradients(samples)
    # RolloutWorker.compute_gradients returns pure single agent stats
    # in a non-multi agent setup.
    if isinstance(samples, MultiAgentBatch):
        metrics.info[LEARNER_INFO] = info
    else:
        metrics.info[LEARNER_INFO] = {DEFAULT_POLICY_ID: info}
    return grad, samples.count
def __call__(self, item):
    actor, batch = item
    self.steps_since_broadcast += 1
    if (self.steps_since_broadcast >= self.broadcast_interval
            and self.learner_thread.weights_updated):
        self.weights = ray.put(self.workers.local_worker().get_weights())
        self.steps_since_broadcast = 0
        self.learner_thread.weights_updated = False
        # Update metrics.
        metrics = _get_shared_metrics()
        metrics.counters["num_weight_broadcasts"] += 1
    actor.set_weights.remote(self.weights, _get_global_vars())
def __call__(self, _: Any) -> Dict:
    # Collect worker metrics.
    episodes, self.to_be_collected = collect_episodes(
        self.workers.local_worker(),
        self.selected_workers or self.workers.remote_workers(),
        self.to_be_collected,
        timeout_seconds=self.timeout_seconds)
    orig_episodes = list(episodes)
    missing = self.min_history - len(episodes)
    if missing > 0:
        episodes = self.episode_history[-missing:] + episodes
        assert len(episodes) <= self.min_history
    self.episode_history.extend(orig_episodes)
    self.episode_history = self.episode_history[-self.min_history:]
    res = summarize_episodes(episodes, orig_episodes)

    # Add in iterator metrics.
    metrics = _get_shared_metrics()
    custom_metrics_from_info = metrics.info.pop("custom_metrics", {})
    timers = {}
    counters = {}
    info = {}
    info.update(metrics.info)
    for k, counter in metrics.counters.items():
        counters[k] = counter
    for k, timer in metrics.timers.items():
        timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
        if timer.has_units_processed():
            timers["{}_throughput".format(k)] = round(
                timer.mean_throughput, 3)
            throughput = timer.mean_throughput
            with Log.timer(log=True, logger=self.logger,
                           info="THROUGHPUT") as logging_metrics:
                logging_metrics.append(throughput)
    res.update({
        "num_healthy_workers": len(self.workers.remote_workers()),
        "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
        "agent_timesteps_total": metrics.counters.get(
            AGENT_STEPS_SAMPLED_COUNTER, 0),
    })
    res["timers"] = timers
    res["info"] = info
    res["info"].update(counters)
    res["custom_metrics"] = res.get("custom_metrics", {})
    res["episode_media"] = res.get("episode_media", {})
    res["custom_metrics"].update(custom_metrics_from_info)
    return res
def __call__(self, fetches):
    metrics = _get_shared_metrics()
    cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
    last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
    if cur_ts - last_update > self.target_update_freq:
        metrics.counters[NUM_TARGET_UPDATES] += 1
        metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
        # Update target network.
        self.workers.local_worker().foreach_trainable_policy(
            lambda p, _: p.update_target())
        # Also update KL coeff.
        if self.config["use_kl_loss"]:
            self.update_kl(fetches)
def update_prio_and_stats(
        item: Tuple[ActorHandle, dict, int, int]) -> None:
    actor, prio_dict, env_count, agent_count = item
    if config.get("prioritized_replay"):
        actor.update_priorities.remote(prio_dict)
    metrics = _get_shared_metrics()
    # Manually update the steps trained counter since the learner
    # thread is executing outside the pipeline.
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
    metrics.counters[STEPS_TRAINED_COUNTER] += env_count
    metrics.timers["learner_dequeue"] = learner_thread.queue_timer
    metrics.timers["learner_grad"] = learner_thread.grad_timer
    metrics.timers["learner_overall"] = learner_thread.overall_timer
def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
    _check_sample_batch_type(batch)
    self.buffer.append(batch)
    self.count += 1
    if self.count >= self.num_episodes:
        out = SampleBatch.concat_samples(self.buffer)
        timer = _get_shared_metrics().timers[SAMPLE_TIMER]
        timer.push(time.perf_counter() - self.batch_start_time)
        timer.push_units_processed(self.count)
        self.batch_start_time = None
        self.buffer = []
        self.count = 0
        return [out]
    return []
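# Sketch: an episode-count concat op like the one above pairs with the
# `complete_episodes` batch mode, so `self.count` counts whole episodes
# rather than timesteps. ConcatEpisodes is a hypothetical name for the class
# this __call__ belongs to; `rollouts` and `workers` are assumed.
from ray.rllib.execution.train_ops import TrainOneStep

train_op = rollouts \
    .combine(ConcatEpisodes(num_episodes=10)) \
    .for_each(TrainOneStep(workers))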