Example 1
 def __call__(self, batch):
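     # Ask every worker for its metrics, then wait up to
     # TIMEOUT_SECONDS for the remote calls to complete.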
     pending = [worker.get_metrics.remote() for worker in self.workers]
     collected, to_be_collected = ray.wait(pending,
                                           num_returns=len(pending),
                                           timeout=self.TIMEOUT_SECONDS)
     if pending and len(collected) == 0:
         logger.warning(
             "Collected no worker metrics in {} seconds".format(
                 self.TIMEOUT_SECONDS))
         return batch
     metrics = ray.get(collected)
     stats = {
         'mcts': {},
         'mem': {},
     }
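     # Seed the totals with the first worker's MCTS stats and infer
     # the action space size from the per-action count keys.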
     stats['mcts'].update(metrics[0]['mcts'])
     action_space_size = 0
     for key in stats['mcts']:
         if 'action_' in key and len(key) > 12:
             # Count the per-action 'action_{i}_count' keys; skip the
             # aggregate 'action_count' key (length 12).
             action_space_size += 1
     for metric in metrics[1:]:
         d = metric['mcts']
         for key in d:
             stats['mcts'][key] += d[key]
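     # Convert the aggregated per-action counts into percentages.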
     for i in range(action_space_size):
         stats['mcts'][f'action_{i}_count_pct'] = stats['mcts'][
             f'action_{i}_count'] / stats['mcts']['action_count']
     for i, metric in enumerate(metrics):
         for key in metric['mem']:
             stats['mem'][f'worker_{i}_{key}'] = metric['mem'][key]
     LocalIterator.get_metrics().info.update(stats)
     return batch
Example 2
    def __call__(self, item):
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(
                    item))
        gradients, count = item
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            self.workers.local_worker().apply_gradients(gradients)
            apply_timer.push_units_processed(count)

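        # Broadcast the new weights either to all remote workers
        # (update_all) or only to the actor recorded in the iterator
        # context (`cur_actor`).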
        if self.update_all:
            if self.workers.remote_workers():
                with metrics.timers[WORKER_UPDATE_TIMER]:
                    weights = ray.put(
                        self.workers.local_worker().get_weights())
                    for e in self.workers.remote_workers():
                        e.set_weights.remote(weights)
        else:
            if metrics.cur_actor is None:
                raise ValueError("Could not find actor to update. When "
                                 "update_all=False, `cur_actor` must be set "
                                 "in the iterator context.")
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = self.workers.local_worker().get_weights()
                metrics.cur_actor.set_weights.remote(weights)
Example 3
 def __call__(self, samples: SampleBatchType):
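     # Compute gradients on the local worker, stash the learner stats,
     # and pass the grads along with the sample count for downstream ops.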
     _check_sample_batch_type(samples)
     metrics = LocalIterator.get_metrics()
     with metrics.timers[COMPUTE_GRADS_TIMER]:
         grad, info = self.workers.local_worker().compute_gradients(samples)
     metrics.info[LEARNER_INFO] = get_learner_stats(info)
     return grad, samples.count
Example 4
 def __call__(self,
              batch: SampleBatchType) -> Tuple[SampleBatchType, List[dict]]:
     _check_sample_batch_type(batch)
     metrics = LocalIterator.get_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
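         # Run several epochs of minibatch SGD when configured,
         # otherwise a single learn_on_batch() pass.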
         if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
             info = do_minibatch_sgd(batch, self.policies,
                                     self.workers.local_worker(),
                                     self.num_sgd_iter,
                                     self.sgd_minibatch_size, [])
             # TODO(ekl) shouldn't be returning learner stats directly here
             metrics.info[LEARNER_INFO] = info
         else:
             info = self.workers.local_worker().learn_on_batch(batch)
             metrics.info[LEARNER_INFO] = get_learner_stats(info)
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights())
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return batch, info
Example 5
 def __call__(self, batch: SampleBatch):
     """
     Assumes that batch is ordered by step number, although
     different episodes can be intermixed.
     """
     next_batches = []
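     # One bounded queue of (batch, row index) pairs per open episode;
     # a full queue holds the n steps needed to emit one sample.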
     for i, eps_id in enumerate(batch[SampleBatch.EPS_ID]):
         if eps_id not in self.episodes:
             assert not batch[SampleBatch.DONES][i]
             self.episodes[eps_id] = queue.Queue(maxsize=self.n)
         q = self.episodes[eps_id]
         q.put_nowait((batch, i))
         if batch[SampleBatch.DONES][i]:
             # Episode done: pad the queue with the terminal step so the
             # n-step value can still be computed for the final (< n) steps.
             s = q.qsize()
             for _ in range(s):
                 while q.qsize() < self.n:
                     q.put((batch, i), timeout=0.1)
                 next_batches.append(self.get_next_batch(q))
             del self.episodes[eps_id]
         elif q.qsize() == self.n:
             next_batches.append(self.get_next_batch(q))
     if not next_batches:
         return None
     metrics = LocalIterator.get_metrics()
     metrics.info['calculate_priorities_queue_count'] = len(self.episodes)
     return SampleBatch.concat_samples(next_batches)
Example 6
    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in replay_buffers.items():
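                # With synchronize_sampling, one set of indices is drawn
                # and reused for every policy's buffer.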
                if synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(train_batch_size)

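                # Prioritized buffers return importance-sampling weights,
                # with beta annealed on the number of steps trained so far;
                # uniform buffers get unit weights and dummy batch indexes.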
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)
Example 7
 def __call__(self, item):
     actor, batch = item
     timestep = LocalIterator.get_metrics().counters[STEPS_SAMPLED_COUNTER]
     self.steps_since_broadcast[actor] += 1
     if self.steps_since_broadcast[actor] >= self.broadcast_interval:
         if self.learner_thread.weights_updated:
             self.weights = ray.put(
                 self.workers.local_worker().get_weights())
             self.steps_since_broadcast[actor] = 0
             self.learner_thread.weights_updated = False
         # Update metrics.
         metrics = LocalIterator.get_metrics()
         metrics.counters["num_weight_broadcasts"] += 1
         actor.set_weights.remote(self.weights, timestep)
     # Also push the latest timestep to the local worker.
     self.workers.local_worker().set_timestep(timestep)
Example 8
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     if metrics.counters["n"] > 2:
         assert "foo" in metrics.counters
         assert "bar" in metrics.counters
     return x
Example 9
    def __call__(self, _):
        # Collect worker metrics.
        episodes, self.to_be_collected = collect_episodes(
            self.workers.local_worker(),
            self.workers.remote_workers(),
            self.to_be_collected,
            timeout_seconds=self.timeout_seconds)
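        # Pad with recent history so summaries cover at least
        # `min_history` episodes.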
        orig_episodes = list(episodes)
        missing = self.min_history - len(episodes)
        if missing > 0:
            episodes.extend(self.episode_history[-missing:])
            assert len(episodes) <= self.min_history
        self.episode_history.extend(orig_episodes)
        self.episode_history = self.episode_history[-self.min_history:]
        res = summarize_episodes(episodes, orig_episodes)

        # Add in iterator metrics.
        metrics = LocalIterator.get_metrics()
        timers = {}
        counters = {}
        info = {}
        info.update(metrics.info)
        for k, counter in metrics.counters.items():
            counters[k] = counter
        for k, timer in metrics.timers.items():
            timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
            if timer.has_units_processed():
                timers["{}_throughput".format(k)] = round(
                    timer.mean_throughput, 3)
        res.update({
            "num_healthy_workers": len(self.workers.remote_workers()),
            "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
        })
        res["timers"] = timers
        res["info"] = info
        res["info"].update(counters)
        relevant = [
            "info", "custom_metrics", "sampler_perf", "timesteps_total",
            "policy_reward_mean", "episode_len_mean"
        ]

        d = {k: res[k] for k in relevant}
        d["evaluation"] = res.get("evaluation", {})

        if self.log_to_neptune:
            metrics_to_be_logged = ["info", "evaluation"]

            def log_metric(metrics, base_string=''):
                if isinstance(metrics, dict):
                    for k in metrics:
                        log_metric(metrics[k], base_string + '{}_'.format(k))
                else:
                    neptune.log_metric(base_string, metrics)

            for k in d:
                if k in metrics_to_be_logged:
                    log_metric(d[k], base_string='{}_'.format(k))

        return d
Example 10
 def update_prio_and_stats(item):
     actor, prio_dict, count = item
     actor.update_priorities.remote(prio_dict)
     metrics = LocalIterator.get_metrics()
     metrics.counters[STEPS_TRAINED_COUNTER] += count
     metrics.timers["learner_dequeue"] = learner_thread.queue_timer
     metrics.timers["learner_grad"] = learner_thread.grad_timer
     metrics.timers["learner_overall"] = learner_thread.overall_timer
Example 11
 def __call__(self, item):
     (grads, info), count = item
     metrics = LocalIterator.get_metrics()
     metrics.counters[STEPS_SAMPLED_COUNTER] += count
     metrics.info[LEARNER_INFO] = get_learner_stats(info)
     metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                          self.fetch_start_time)
     return grads, count
Example 12
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     # Check the metrics context is shared.
     if metrics.counters["n"] >= 2:
         assert "foo" in metrics.counters
         assert "bar" in metrics.counters
     return x
Example 13
 def __call__(self, item):
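     # Gate that fires at most once every `delay_steps` sampled steps
     # (and always fires when the delay is disabled).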
     if self.delay_steps <= 0:
         return True
     metrics = LocalIterator.get_metrics()
     now = metrics.counters[STEPS_SAMPLED_COUNTER]
     if now - self.last_called > self.delay_steps:
         self.last_called = now
         return True
     return False
Example 14
 def __call__(self, _):
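     # Update the target network once `target_update_freq` steps have
     # passed (as measured by `self.metric`) since the last update.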
     metrics = LocalIterator.get_metrics()
     cur_ts = metrics.counters[self.metric]
     last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
     if cur_ts - last_update > self.target_update_freq:
         self.workers.local_worker().foreach_trainable_policy(
             lambda p, _: p.update_target())
         metrics.counters[NUM_TARGET_UPDATES] += 1
         metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
Example 15
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     # Check the metrics context is shared recursively.
     print(metrics.counters)
     if metrics.counters["n"] >= 3:
         assert "foo" in metrics.counters
         assert "bar" in metrics.counters
         assert "baz" in metrics.counters
     return x
Example 16
 def update_prio_and_stats(item):
     actor, prio_dict, count = item
     actor.update_priorities.remote(prio_dict)
     metrics = LocalIterator.get_metrics()
     # Manually update the steps trained counter since the learner thread
     # is executing outside the pipeline.
     metrics.counters[STEPS_TRAINED_COUNTER] += count
     metrics.timers["learner_dequeue"] = learner_thread.queue_timer
     metrics.timers["learner_grad"] = learner_thread.grad_timer
     metrics.timers["learner_overall"] = learner_thread.overall_timer
Example 17
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     # Check the unioned iterator gets a new metric context.
     assert "foo" not in metrics.counters
     assert "bar" not in metrics.counters
     # Check parent metrics are accessible.
     if metrics.counters["n"] > 2:
         assert "foo" in metrics.parent_metrics[0].counters
         assert "bar" in metrics.parent_metrics[1].counters
     return x
Example 18
 def __call__(self, items):
     # Fetch the shared metrics context once, up front, so the timer
     # update below cannot hit a NameError when `items` is empty.
     metrics = LocalIterator.get_metrics()
     for item in items:
         info, count = item
         metrics.counters[STEPS_SAMPLED_COUNTER] += count
         metrics.counters[STEPS_TRAINED_COUNTER] += count
         metrics.info[LEARNER_INFO] = info
     # Since SGD happens remotely, the time delay between fetch and
     # completion is approximately the SGD step time.
     metrics.timers[LEARN_ON_BATCH_TIMER].push(time.perf_counter() -
                                               self.fetch_start_time)
Example 19
 def __call__(self, item):
     actor, batch = item
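     # Re-broadcast weights only after `broadcast_interval` items and
     # only if the learner thread has produced new weights since the
     # last broadcast.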
     self.steps_since_broadcast += 1
     if (self.steps_since_broadcast >= self.broadcast_interval
             and self.learner_thread.weights_updated):
         self.weights = ray.put(self.workers.local_worker().get_weights())
         self.steps_since_broadcast = 0
         self.learner_thread.weights_updated = False
         # Update metrics.
         metrics = LocalIterator.get_metrics()
         metrics.counters["num_weight_broadcasts"] += 1
     actor.set_weights.remote(self.weights, _get_global_vars())
Example 20
def test_metrics(ray_start_regular_shared):
    it = from_items([1, 2, 3, 4], num_shards=1)
    it2 = from_items([1, 2, 3, 4], num_shards=1)

    def f(x):
        metrics = LocalIterator.get_metrics()
        metrics.counters["foo"] += x
        return metrics.counters["foo"]

    it = it.gather_sync().for_each(f)
    it2 = it2.gather_sync().for_each(f)

    # Context cannot be accessed outside the iterator.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()

    # Test that the two iterators have isolated metrics contexts.
    assert it.take(4) == [1, 3, 6, 10]
    assert it2.take(4) == [1, 3, 6, 10]

    # Context cannot be accessed outside the iterator.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()
Example 21
 def __call__(self, batch: SampleBatch) -> List[dict]:
     metrics = LocalIterator.get_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         info = self.workers.local_worker().learn_on_batch(batch)
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     metrics.info[LEARNER_INFO] = info[LEARNER_STATS_KEY]
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights())
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights)
     return info
Example 22
 def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
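     # Buffer incoming batches until at least `min_batch_size` steps
     # have accumulated, then emit them as one concatenated batch and
     # record sampling time and throughput.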
     _check_sample_batch_type(batch)
     self.buffer.append(batch)
     self.count += batch.count
     if self.count >= self.min_batch_size:
         out = SampleBatch.concat_samples(self.buffer)
         timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
         timer.push(time.perf_counter() - self.batch_start_time)
         timer.push_units_processed(self.count)
         self.batch_start_time = None
         self.buffer = []
         self.count = 0
         return [out]
     return []
Example 23
 def __call__(self, item):
     actor, batch = item
     self.steps_since_update[actor] += batch.count
     if self.steps_since_update[actor] >= self.max_weight_sync_delay:
         # Note that it's important to pull new weights once
         # updated to avoid excessive correlation between actors.
         if self.weights is None or self.learner_thread.weights_updated:
             self.learner_thread.weights_updated = False
             self.weights = ray.put(
                 self.workers.local_worker().get_weights())
         actor.set_weights.remote(self.weights)
         self.steps_since_update[actor] = 0
         # Update metrics.
         metrics = LocalIterator.get_metrics()
         metrics.counters["num_weight_syncs"] += 1
Example 24
 def __call__(self, batch: SampleBatch) -> List[SampleBatch]:
     if not isinstance(batch, SampleBatch):
         raise ValueError("Expected type SampleBatch, got {}: {}".format(
             type(batch), batch))
     self.buffer.append(batch)
     self.count += batch.count
     if self.count >= self.min_batch_size:
         out = SampleBatch.concat_samples(self.buffer)
         timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
         timer.push(time.perf_counter() - self.batch_start_time)
         timer.push_units_processed(self.count)
         self.batch_start_time = None
         self.buffer = []
         self.count = 0
         return [out]
     return []
Example 25
 def __call__(self, batch: SampleBatchType) -> List[dict]:
     _check_sample_batch_type(batch)
     metrics = LocalIterator.get_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         info = self.workers.local_worker().learn_on_batch(batch)
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     metrics.info[LEARNER_INFO] = get_learner_stats(info)
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights())
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return info
Example 26
    def __call__(self, _):
        # Collect worker metrics.
        episodes, self.to_be_collected = collect_episodes(
            self.workers.local_worker(),
            self.workers.remote_workers(),
            self.to_be_collected,
            timeout_seconds=self.timeout_seconds)
        orig_episodes = list(episodes)
        missing = self.min_history - len(episodes)
        if missing > 0:
            episodes.extend(self.episode_history[-missing:])
            assert len(episodes) <= self.min_history
        self.episode_history.extend(orig_episodes)
        self.episode_history = self.episode_history[-self.min_history:]
        res = summarize_episodes(episodes, orig_episodes)

        # Add in iterator metrics.
        metrics = LocalIterator.get_metrics()
        if metrics.parent_metrics:
            logger.warning("TODO: support nested metrics better")
        all_metrics = [metrics] + metrics.parent_metrics
        timers = {}
        counters = {}
        info = {}
        for metrics in all_metrics:
            info.update(metrics.info)
            for k, counter in metrics.counters.items():
                counters[k] = counter
            for k, timer in metrics.timers.items():
                timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
                if timer.has_units_processed():
                    timers["{}_throughput".format(k)] = round(
                        timer.mean_throughput, 3)
            res.update({
                "num_healthy_workers": len(self.workers.remote_workers()),
                "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
            })
        res["timers"] = timers
        res["info"] = info
        res["info"].update(counters)
        return res
Example 27
 def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
     _check_sample_batch_type(batch)
     self.buffer.append(batch)
     self.count += batch.count
     if self.count >= self.min_batch_size:
         if self.count > self.min_batch_size * 2:
             logger.info("Collected more training samples than expected "
                         "(actual={}, expected={}). ".format(
                             self.count, self.min_batch_size) +
                         "This may be because you have many workers or "
                         "long episodes in 'complete_episodes' batch mode.")
         out = SampleBatch.concat_samples(self.buffer)
         timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
         timer.push(time.perf_counter() - self.batch_start_time)
         timer.push_units_processed(self.count)
         self.batch_start_time = None
         self.buffer = []
         self.count = 0
         return [out]
     return []
Example 28
def _get_shared_metrics():
    """Return shared metrics for the training workflow.

    This only applies if this trainer has an execution plan."""
    return LocalIterator.get_metrics()
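For context, a minimal runnable sketch (not from the source above) of how this shared metrics context behaves. It assumes a Ray release that ships ray.util.iter, and the counter name "items_seen" is invented for illustration:

import ray
from ray.util.iter import from_items, LocalIterator

def count_items(x):
    # Inside an op, get_metrics() returns the iterator's shared
    # MetricsContext; calling it outside iteration raises ValueError.
    metrics = LocalIterator.get_metrics()
    metrics.counters["items_seen"] += 1  # counters is a defaultdict(int)
    return x

ray.init()
it = from_items([1, 2, 3], num_shards=1).gather_sync().for_each(count_items)
print(it.take(3))  # -> [1, 2, 3]; counters["items_seen"] ends at 3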
Example 29
def _get_global_vars():
    metrics = LocalIterator.get_metrics()
    return {"timestep": metrics.counters[STEPS_SAMPLED_COUNTER]}
Example 30
 def report_timesteps(batch):
     metrics = LocalIterator.get_metrics()
     metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
     return batch