Example #1
 def __call__(self, batch):
     pending = [worker.get_metrics.remote() for worker in self.workers]
     collected, to_be_collected = ray.wait(pending,
                                           num_returns=len(pending),
                                           timeout=self.TIMEOUT_SECONDS)
     if pending and len(collected) == 0:
         logger.warning(
             "WARNING: collected no worker metrics in {} seconds".format(
                 self.TIMEOUT_SECONDS))
         return batch
     metrics = ray.get(collected)
     stats = {
         'mcts': {},
         'mem': {},
     }
     stats['mcts'].update(metrics[0]['mcts'])
     action_space_size = 0
     for key in stats['mcts']:
         if 'action_' in key and len(key) > 12:
             # Count action_{i}_count keys and skip action_count key
             action_space_size += 1
     for metric in metrics[1:]:
         d = metric['mcts']
         for key in d:
             stats['mcts'][key] += d[key]
     for i in range(action_space_size):
         stats['mcts'][f'action_{i}_count_pct'] = stats['mcts'][
             f'action_{i}_count'] / stats['mcts']['action_count']
     for i, metric in enumerate(metrics):
         for key in metric['mem']:
             stats['mem'][f'worker_{i}_{key}'] = metric['mem'][key]
     LocalIterator.get_metrics().info.update(stats)
     return batch
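Note that LocalIterator.get_metrics() only resolves while an iterator pipeline is being consumed, so operators like the one above must be attached via for_each(). A minimal, self-contained sketch of that wiring (the counter name and item values are illustrative, not taken from the example above):

import ray
from ray.util.iter import from_items, LocalIterator

ray.init(ignore_reinit_error=True)

def count_items(x):
    # The shared metrics context is only available while the pipeline runs;
    # calling LocalIterator.get_metrics() outside of it raises ValueError
    # (see Example #28 below).
    metrics = LocalIterator.get_metrics()
    metrics.counters["num_items"] += 1
    return x

it = from_items([1, 2, 3], num_shards=1).gather_sync().for_each(count_items)
print(it.take(3))  # -> [1, 2, 3]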
Example #2
    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in replay_buffers.items():
                if synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)
Example #3
 def __call__(self, item):
     actor, batch = item
     timestep = LocalIterator.get_metrics().counters[STEPS_SAMPLED_COUNTER]
     self.steps_since_broadcast[actor] += 1
     if self.steps_since_broadcast[actor] >= self.broadcast_interval:
         if self.learner_thread.weights_updated:
             self.weights = ray.put(
                 self.workers.local_worker().get_weights())
             self.steps_since_broadcast[actor] = 0
             self.learner_thread.weights_updated = False
         # Update metrics.
         metrics = LocalIterator.get_metrics()
         metrics.counters["num_weight_broadcasts"] += 1
         actor.set_weights.remote(self.weights, timestep)
     # Also update global vars of the local worker.
     self.workers.local_worker().set_timestep(timestep)
Example #4
def Dequeue(input_queue: queue.Queue, check=lambda: True):
    """Dequeue data items from a queue.Queue instance.

    The dequeue is non-blocking, so Dequeue operations can be executed with
    Enqueue via the Concurrently() operator.

    Arguments:
        input_queue (Queue): queue to pull items from.
        check (fn): liveness check. When this function returns false,
            Dequeue() will raise an error to halt execution.

    Examples:
        >>> q = queue.Queue(100)
        >>> write_op = ParallelRollouts(...).for_each(Enqueue(q))
        >>> read_op = Dequeue(q)
        >>> combined_op = Concurrently([write_op, read_op], mode="async")
        >>> next(combined_op)
        SampleBatch(...)
    """
    if not isinstance(input_queue, queue.Queue):
        raise ValueError("Expected queue.Queue, got {}".format(
            type(input_queue)))

    def base_iterator(timeout=None):
        while check():
            try:
                item = input_queue.get_nowait()
                yield item
            except queue.Empty:
                yield _NextValueNotReady()
        raise RuntimeError("Error raised reading from queue")

    return LocalIterator(base_iterator, MetricsContext())
Example #5
    def shuffle(self,
                local_it: LocalIterator[T]) -> LocalIterator[pd.DataFrame]:
        shuffle_random = random.Random(self._seed)

        def apply_shuffle(it):
            buffer = []
            for item in it:
                if isinstance(item, _NextValueNotReady):
                    yield item
                else:
                    buffer.append(item)
                    if len(buffer) >= self._shuffle_buffer_size:
                        item = buffer.pop(
                            shuffle_random.randint(0,
                                                   len(buffer) - 1))
                        item = item.sample(frac=1, random_state=self._seed)
                        yield item
            while len(buffer) > 0:
                item = buffer.pop(shuffle_random.randint(0, len(buffer) - 1))
                item = item.sample(frac=1, random_state=self._seed)
                yield item

        return LocalIterator(
            local_it.base_iterator,
            local_it.shared_metrics,
            local_it.local_transforms + [apply_shuffle],
            name=local_it.name +
            ".shuffle(shuffle_buffer_size={}, seed={})".format(
                self._shuffle_buffer_size,
                str(self._seed) if self._seed is not None else "None",
            ),
        )
Example #6
 def __call__(self,
              batch: SampleBatchType) -> (SampleBatchType, List[dict]):
     _check_sample_batch_type(batch)
     metrics = LocalIterator.get_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
             info = do_minibatch_sgd(batch, self.policies,
                                     self.workers.local_worker(),
                                     self.num_sgd_iter,
                                     self.sgd_minibatch_size, [])
             # TODO(ekl) shouldn't be returning learner stats directly here
             metrics.info[LEARNER_INFO] = info
         else:
             info = self.workers.local_worker().learn_on_batch(batch)
             metrics.info[LEARNER_INFO] = get_learner_stats(info)
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights())
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights, _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return batch, info
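A hedged sketch of how this training step is typically wired into a pipeline, following the pattern used in Example #24 below (TrainOneStep is assumed to be the class whose __call__ is shown above, and workers is assumed to be an existing WorkerSet):

# Assumed wiring, not part of the snippet above: each rollout batch is passed
# through the training step, which yields (batch, learner_info) tuples.
rollouts = ParallelRollouts(workers, mode="bulk_sync")
train_op = rollouts.for_each(TrainOneStep(workers))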
Example #7
 def __call__(self, samples: SampleBatchType):
     _check_sample_batch_type(samples)
     metrics = LocalIterator.get_metrics()
     with metrics.timers[COMPUTE_GRADS_TIMER]:
         grad, info = self.workers.local_worker().compute_gradients(samples)
     metrics.info[LEARNER_INFO] = get_learner_stats(info)
     return grad, samples.count
Example #8
    def __call__(self, item):
        if not isinstance(item, tuple) or len(item) != 2:
            raise ValueError(
                "Input must be a tuple of (grad_dict, count), got {}".format(
                    item))
        gradients, count = item
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_TRAINED_COUNTER] += count

        apply_timer = metrics.timers[APPLY_GRADS_TIMER]
        with apply_timer:
            self.workers.local_worker().apply_gradients(gradients)
            apply_timer.push_units_processed(count)

        if self.update_all:
            if self.workers.remote_workers():
                with metrics.timers[WORKER_UPDATE_TIMER]:
                    weights = ray.put(
                        self.workers.local_worker().get_weights())
                    for e in self.workers.remote_workers():
                        e.set_weights.remote(weights)
        else:
            if metrics.cur_actor is None:
                raise ValueError("Could not find actor to update. When "
                                 "update_all=False, `cur_actor` must be set "
                                 "in the iterator context.")
            with metrics.timers[WORKER_UPDATE_TIMER]:
                weights = self.workers.local_worker().get_weights()
                metrics.cur_actor.set_weights.remote(weights)
Example #9
 def __call__(self, batch: SampleBatch):
     """
     Assumes that batch is ordered by step number, although
     different episodes can be intermixed.
     """
     next_batches = []
     for i, eps_id in enumerate(batch[SampleBatch.EPS_ID]):
         if eps_id not in self.episodes:
             assert not batch[SampleBatch.DONES][i]
             self.episodes[eps_id] = queue.Queue(maxsize=self.n)
         q = self.episodes[eps_id]
         q.put_nowait((batch, i))
         if batch[SampleBatch.DONES][i]:
             # Force calculation of search value with less than n steps
             s = q.qsize()
             for _ in range(s):
                 while q.qsize() < self.n:
                     q.put((batch, i), timeout=0.1)
                 next_batches.append(self.get_next_batch(q))
             del self.episodes[eps_id]
         elif q.qsize() == self.n:
             next_batches.append(self.get_next_batch(q))
     if not next_batches:
         return None
     metrics = LocalIterator.get_metrics()
     metrics.info['calculate_priorities_queue_count'] = len(self.episodes)
     return SampleBatch.concat_samples(next_batches)
Example #10
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     if metrics.counters["n"] > 2:
         assert "foo" in metrics.counters
         assert "bar" in metrics.counters
     return x
Example #11
    def __call__(self, _):
        # Collect worker metrics.
        episodes, self.to_be_collected = collect_episodes(
            self.workers.local_worker(),
            self.workers.remote_workers(),
            self.to_be_collected,
            timeout_seconds=self.timeout_seconds)
        orig_episodes = list(episodes)
        missing = self.min_history - len(episodes)
        if missing > 0:
            episodes.extend(self.episode_history[-missing:])
            assert len(episodes) <= self.min_history
        self.episode_history.extend(orig_episodes)
        self.episode_history = self.episode_history[-self.min_history:]
        res = summarize_episodes(episodes, orig_episodes)

        # Add in iterator metrics.
        metrics = LocalIterator.get_metrics()
        timers = {}
        counters = {}
        info = {}
        info.update(metrics.info)
        for k, counter in metrics.counters.items():
            counters[k] = counter
        for k, timer in metrics.timers.items():
            timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
            if timer.has_units_processed():
                timers["{}_throughput".format(k)] = round(
                    timer.mean_throughput, 3)
        res.update({
            "num_healthy_workers": len(self.workers.remote_workers()),
            "timesteps_total": metrics.counters[STEPS_SAMPLED_COUNTER],
        })
        res["timers"] = timers
        res["info"] = info
        res["info"].update(counters)
        relevant = [
            "info", "custom_metrics", "sampler_perf", "timesteps_total",
            "policy_reward_mean", "episode_len_mean"
        ]

        d = {k: res[k] for k in relevant}
        d["evaluation"] = res.get("evaluation", {})

        if self.log_to_neptune:
            metrics_to_be_logged = ["info", "evaluation"]

            def log_metric(metrics, base_string=''):
                if isinstance(metrics, dict):
                    for k in metrics:
                        log_metric(metrics[k], base_string + '{}_'.format(k))
                else:
                    neptune.log_metric(base_string, metrics)

            for k in d:
                if k in metrics_to_be_logged:
                    log_metric(d[k], base_string='{}_'.format(k))

        return d
Example #12
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     # Check the metrics context is shared.
     if metrics.counters["n"] >= 2:
         assert "foo" in metrics.counters
         assert "bar" in metrics.counters
     return x
Example #13
 def update_prio_and_stats(item):
     actor, prio_dict, count = item
     actor.update_priorities.remote(prio_dict)
     metrics = LocalIterator.get_metrics()
     metrics.counters[STEPS_TRAINED_COUNTER] += count
     metrics.timers["learner_dequeue"] = learner_thread.queue_timer
     metrics.timers["learner_grad"] = learner_thread.grad_timer
     metrics.timers["learner_overall"] = learner_thread.overall_timer
Example #14
 def __call__(self, item):
     (grads, info), count = item
     metrics = LocalIterator.get_metrics()
     metrics.counters[STEPS_SAMPLED_COUNTER] += count
     metrics.info[LEARNER_INFO] = get_learner_stats(info)
     metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                          self.fetch_start_time)
     return grads, count
Example #15
def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially from
    the local worker instance instead.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    def report_timesteps(batch):
        metrics = LocalIterator.get_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler,
                              MetricsContext()).for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async().for_each(report_timesteps)
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))
Example #16
 def __call__(self, _):
     metrics = LocalIterator.get_metrics()
     cur_ts = metrics.counters[self.metric]
     last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
     if cur_ts - last_update > self.target_update_freq:
         self.workers.local_worker().foreach_trainable_policy(
             lambda p, _: p.update_target())
         metrics.counters[NUM_TARGET_UPDATES] += 1
         metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
Example #17
def LocalReplay(replay_buffer: ReplayBuffer, train_batch_size: int):
    """Replay experiences from a local buffer instance.

    This should be combined with the StoreToReplayBuffer operation using the
    Concurrently() operator.

    Arguments:
        replay_buffer (ReplayBuffer): Buffer to replay experiences from.
        train_batch_size (int): Batch size of fetches from the buffer.

    Examples:
        >>> buffer = ReplayBuffer(...)
        >>> replay_op = LocalReplay(buffer, train_batch_size=500)
        >>> next(replay_op)
        SampleBatch(...)
    """
    assert isinstance(replay_buffer, ReplayBuffer)
    replay_buffers = {DEFAULT_POLICY_ID: replay_buffer}
    # TODO(ekl) support more options, or combine with ParallelReplay (?)
    synchronize_sampling = False
    prioritized_replay_beta = None

    def gen_replay(timeout):
        while True:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in replay_buffers.items():
                if synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    metrics = LocalIterator.get_metrics()
                    num_steps_trained = metrics.counters[STEPS_TRAINED_COUNTER]
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=prioritized_replay_beta.value(num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            yield MultiAgentBatch(samples, train_batch_size)

    return LocalIterator(gen_replay, MetricsContext())
Example #18
 def __call__(self, item):
     if self.delay_steps <= 0:
         return True
     metrics = LocalIterator.get_metrics()
     now = metrics.counters[STEPS_SAMPLED_COUNTER]
     if now - self.last_called > self.delay_steps:
         self.last_called = now
         return True
     return False
Example #19
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     # Check the metrics context is shared recursively.
     print(metrics.counters)
     if metrics.counters["n"] >= 3:
         assert "foo" in metrics.counters
         assert "bar" in metrics.counters
         assert "baz" in metrics.counters
     return x
Example #20
 def update_prio_and_stats(item):
     actor, prio_dict, count = item
     actor.update_priorities.remote(prio_dict)
     metrics = LocalIterator.get_metrics()
     # Manually update the steps trained counter since the learner thread
     # is executing outside the pipeline.
     metrics.counters[STEPS_TRAINED_COUNTER] += count
     metrics.timers["learner_dequeue"] = learner_thread.queue_timer
     metrics.timers["learner_grad"] = learner_thread.grad_timer
     metrics.timers["learner_overall"] = learner_thread.overall_timer
Example #21
 def verify_metrics(x):
     metrics = LocalIterator.get_metrics()
     metrics.counters["n"] += 1
     # Check the unioned iterator gets a new metric context.
     assert "foo" not in metrics.counters
     assert "bar" not in metrics.counters
     # Check parent metrics are accessible.
     if metrics.counters["n"] > 2:
         assert "foo" in metrics.parent_metrics[0].counters
         assert "bar" in metrics.parent_metrics[1].counters
     return x
Example #22
 def __call__(self, items):
     for item in items:
         info, count = item
         metrics = LocalIterator.get_metrics()
         metrics.counters[STEPS_SAMPLED_COUNTER] += count
         metrics.counters[STEPS_TRAINED_COUNTER] += count
         metrics.info[LEARNER_INFO] = info
     # Since SGD happens remotely, the time delay between fetch and
     # completion is approximately the SGD step time.
     metrics.timers[LEARN_ON_BATCH_TIMER].push(time.perf_counter() -
                                               self.fetch_start_time)
Example #23
 def __call__(self, item):
     actor, batch = item
     self.steps_since_broadcast += 1
     if (self.steps_since_broadcast >= self.broadcast_interval
             and self.learner_thread.weights_updated):
         self.weights = ray.put(self.workers.local_worker().get_weights())
         self.steps_since_broadcast = 0
         self.learner_thread.weights_updated = False
         # Update metrics.
         metrics = LocalIterator.get_metrics()
         metrics.counters["num_weight_broadcasts"] += 1
     actor.set_weights.remote(self.weights, _get_global_vars())
Example #24
File: nfsp.py Project: indylab/nxdo
def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    """Execution plan of the DQN algorithm. Defines the distributed dataflow.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.
    """

    replay_buffer_actor = ReservoirReplayActor.remote(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        replay_sequence_length=config["replay_sequence_length"],
    )

    # Store a handle for the replay buffer actor in the local worker
    workers.local_worker().replay_buffer_actor = replay_buffer_actor

    # Read and train on experiences from the replay buffer. Every batch
    # returned from the Replay iterator is passed to TrainOneStep to
    # take a SGD step.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

    print("running replay op..")

    def gen_replay(_):
        while True:
            item = ray.get(replay_buffer_actor.replay.remote())
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    replay_op = LocalIterator(gen_replay, SharedMetrics()) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers))

    replay_op = StandardMetricsReporting(replay_op, workers, config)

    replay_op = map(
        lambda x: x
        if not isinstance(x, _NextValueNotReady) else {}, replay_op)

    return replay_op
Example #25
def Replay(
    *,
    local_buffer: Optional[MultiAgentReplayBuffer] = None,
    num_items_to_replay: int = 1,
    actors: Optional[List[ActorHandle]] = None,
    num_async: int = 4,
) -> LocalIterator[SampleBatchType]:
    """Replay experiences from the given buffer or actors.

    This should be combined with the StoreToReplayActors operation using the
    Concurrently() operator.

    Args:
        local_buffer: Local buffer to use. Only one of this and replay_actors
            can be specified.
        num_items_to_replay: Number of items to sample from buffer
        actors: List of replay actors. Only one of this and local_buffer
            can be specified.
        num_async: In async mode, the max number of async requests in flight
            per actor.

    Examples:
        >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
        >>> actors = [ # doctest: +SKIP
        ...     multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors, # doctest: +SKIP
        ...     num_items_to_replay=batch_size)
        >>> next(replay_op) # doctest: +SKIP
        SampleBatch(...)
    """

    # Exactly one of local_buffer and actors must be provided.
    if (local_buffer is None) == (actors is None):
        raise ValueError(
            "Exactly one of local_buffer and replay_actors must be given.")

    if actors is not None:
        for actor in actors:
            actor.make_iterator.remote(num_items_to_replay=num_items_to_replay)
        replay = from_actors(actors)
        return replay.gather_async(
            num_async=num_async).filter(lambda x: x is not None)

    def gen_replay(_):
        while True:
            item = local_buffer.sample(num_items_to_replay)
            if item is None:
                yield _NextValueNotReady()
            else:
                yield item

    return LocalIterator(gen_replay, SharedMetrics())
Example #26
 def __call__(self, batch: SampleBatch) -> List[dict]:
     metrics = LocalIterator.get_metrics()
     learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         info = self.workers.local_worker().learn_on_batch(batch)
         learn_timer.push_units_processed(batch.count)
     metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     metrics.info[LEARNER_INFO] = info[LEARNER_STATS_KEY]
     if self.workers.remote_workers():
         with metrics.timers[WORKER_UPDATE_TIMER]:
             weights = ray.put(self.workers.local_worker().get_weights())
             for e in self.workers.remote_workers():
                 e.set_weights.remote(weights)
     return info
Example #27
 def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
     _check_sample_batch_type(batch)
     self.buffer.append(batch)
     self.count += batch.count
     if self.count >= self.min_batch_size:
         out = SampleBatch.concat_samples(self.buffer)
         timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
         timer.push(time.perf_counter() - self.batch_start_time)
         timer.push_units_processed(self.count)
         self.batch_start_time = None
         self.buffer = []
         self.count = 0
         return [out]
     return []
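A hedged usage sketch for this batching operator (ConcatBatches is the assumed class name for the snippet above; workers and config are assumed to exist): items flow through LocalIterator.combine(), which flattens the lists returned by __call__, so empty lists are skipped and one concatenated batch is emitted once min_batch_size experiences have accumulated.

# Assumed wiring, not part of the snippet above.
rollouts = ParallelRollouts(workers, mode="bulk_sync")
train_op = rollouts \
    .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
    .for_each(TrainOneStep(workers))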
Example #28
def test_metrics(ray_start_regular_shared):
    it = from_items([1, 2, 3, 4], num_shards=1)
    it2 = from_items([1, 2, 3, 4], num_shards=1)

    def f(x):
        metrics = LocalIterator.get_metrics()
        metrics.counters["foo"] += x
        return metrics.counters["foo"]

    it = it.gather_sync().for_each(f)
    it2 = it2.gather_sync().for_each(f)

    # Context cannot be accessed outside the iterator.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()

    # Tests iterators have isolated contexts.
    assert it.take(4) == [1, 3, 6, 10]
    assert it2.take(4) == [1, 3, 6, 10]

    # Context cannot be accessed outside the iterator.
    with pytest.raises(ValueError):
        LocalIterator.get_metrics()
Example #29
 def __call__(self, item):
     actor, batch = item
     self.steps_since_update[actor] += batch.count
     if self.steps_since_update[actor] >= self.max_weight_sync_delay:
         # Note that it's important to pull new weights once
         # updated to avoid excessive correlation between actors.
         if self.weights is None or self.learner_thread.weights_updated:
             self.learner_thread.weights_updated = False
             self.weights = ray.put(
                 self.workers.local_worker().get_weights())
         actor.set_weights.remote(self.weights)
         self.steps_since_update[actor] = 0
         # Update metrics.
         metrics = LocalIterator.get_metrics()
         metrics.counters["num_weight_syncs"] += 1
Example #30
 def __call__(self, batch: SampleBatch) -> List[SampleBatch]:
     if not isinstance(batch, SampleBatch):
         raise ValueError("Expected type SampleBatch, got {}: {}".format(
             type(batch), batch))
     self.buffer.append(batch)
     self.count += batch.count
     if self.count >= self.min_batch_size:
         out = SampleBatch.concat_samples(self.buffer)
         timer = LocalIterator.get_metrics().timers[SAMPLE_TIMER]
         timer.push(time.perf_counter() - self.batch_start_time)
         timer.push_units_processed(self.count)
         self.batch_start_time = None
         self.buffer = []
         self.count = 0
         return [out]
     return []