コード例 #1
0
ファイル: test_taskpool.py プロジェクト: wuisawesome/ray
    def test_reset_workers_pendingFetchesFromFailedWorkersRemoved(self, rayWaitMock):
        """Queued fetches of a worker left out of reset_workers() are dropped,
        while queued fetches of the surviving workers are still returned."""
        pool = TaskPool()

        # Keep every (worker, object_ref) pair so that one specific worker
        # can later be left out of the reset to simulate a crash.
        all_tasks = []
        for idx in range(10):
            worker_and_ref = createMockWorkerAndObjectRef(idx)
            pool.add(*worker_and_ref)
            all_tasks.append(worker_and_ref)

        # Report tasks 0-5 as done and 6-9 as pending, then pull only two
        # results so the remaining ready tasks stay in the fetching queue.
        rayWaitMock.return_value = ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9])
        fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=2)]

        # Report no further completions, so only fetches that are already
        # queued can come back from the next completed_prefetch() call.
        rayWaitMock.return_value = ([], [6, 7, 8, 9])

        crashed_worker_index = 4  # OH NO! WORKER 4 HAS CRASHED!
        surviving_workers = [
            task[0]
            for idx, task in enumerate(all_tasks)
            if idx != crashed_worker_index
        ]
        pool.reset_workers(surviving_workers)

        # Drain what is left in the fetching queue; the crashed worker's
        # entry must be gone.
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, [2, 3, 5])
コード例 #2
0
class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote workers (Ape-X actors), and replay buffer actors.

    This has two modes of operation:
        - normal replay: replays independent samples.
        - batch replay: simplified mode where entire sample batches are
            replayed. This supports RNNs, but not prioritization.

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=512,
                 rollout_fragment_length=50,
                 num_replay_buffer_shards=1,
                 max_weight_sync_delay=400,
                 debug=False,
                 batch_replay=False):
        """Initialize an async replay optimizer.

        Starts the learner thread and creates the replay actors. It queues
        REPLAY_QUEUE_DEPTH replay requests per replay actor. If there are
        remote workers, it also pushes the local weights to them and queues
        their first sample requests.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay.
                NOTE(review): this flag is not read anywhere in this class;
                confirm whether the replay actors should receive it.
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            train_batch_size (int): size of batches to learn on
            rollout_fragment_length (int): size of batches to sample from
                workers. NOTE(review): not read anywhere in this class.
            num_replay_buffer_shards (int): number of actors to use to store
                replay samples
            max_weight_sync_delay (int): update the weights of a rollout worker
                after collecting this number of timesteps from it
            debug (bool): return extra debug stats
            batch_replay (bool): replay entire sequential batches of
                experiences instead of sampling steps individually
        """
        PolicyOptimizer.__init__(self, workers)

        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.workers.local_worker())
        self.learner.start()

        if self.batch_replay:
            replay_cls = BatchReplayActor
        else:
            replay_cls = ReplayActor
        self.replay_actors = create_colocated(replay_cls, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates.
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.workers.remote_workers():
            self._set_workers(self.workers.remote_workers())

    @override(PolicyOptimizer)
    def step(self):
        """Run one event-loop iteration and update timers and step counters.

        The iteration's wall time is always pushed to the "sample" timer.
        It is also pushed to the "train" timer once any iteration has
        reported trained timesteps.
        """
        assert self.learner.is_alive()
        assert len(self.workers.remote_workers()) > 0
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        """Terminate the replay actors and set the learner's stopped flag."""
        for r in self.replay_actors:
            r.__ray_terminate__.remote()
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_workers):
        """Swap in a new set of remote workers.

        Pending sample tasks of workers that are not in ``remote_workers``
        are dropped from the sample task pool.

        NOTE(review): workers that are new in ``remote_workers`` get no
        ``steps_since_update`` entry and no sample requests here.
        """
        self.workers.reset(remote_workers)
        self.sample_tasks.reset_workers(remote_workers)

    @override(PolicyOptimizer)
    def stats(self):
        """Return optimizer stats.

        The result includes the stats of replay shard 0. When ``debug`` is
        set, it also includes a timing breakdown and the counts of pending
        sample and replay tasks.
        """
        replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
            self.debug))
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(self.timers["sample"].mean_throughput,
                                       3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_samples_dropped": self.num_samples_dropped,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }
        debug_stats = {
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
        }
        if self.debug:
            stats.update(debug_stats)
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    # For https://github.com/ray-project/ray/issues/2541 only
    def _set_workers(self, remote_workers):
        """Install ``remote_workers`` and start sampling on them.

        Pushes the local worker's weights to each remote worker and zeroes
        its ``steps_since_update`` counter. It then queues
        SAMPLE_QUEUE_DEPTH sample requests per worker.
        """
        self.workers.reset(remote_workers)
        weights = self.workers.local_worker().get_weights()
        for ev in self.workers.remote_workers():
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def _fetch_sample_counts(self, completed):
        """Resolve the step counts of completed sample tasks.

        Tries one batched get first. If that raises a RayError (e.g. a
        worker died), it falls back to fetching each count individually so
        that results from healthy workers are still recovered.

        Arguments:
            completed (list): ``(worker, (sample_batch, count))`` entries
                taken from the sample task pool.

        Returns:
            tuple: ``(counts, ray_error)``. ``counts`` maps an index into
            ``completed`` to its step count for every task that succeeded.
            ``ray_error`` is the first RayError seen, or None.
        """
        try:
            counts = dict(
                enumerate(ray_get_and_free([c[1][1] for c in completed])))
            return counts, None
        except RayError:
            counts = {}
            ray_error = None
            for i, c in enumerate(completed):
                try:
                    counts[i] = ray_get_and_free(c[1][1])
                except RayError as e:
                    logger.exception(
                        "Error in completed task: {}".format(e))
                    if ray_error is None:
                        ray_error = e
            return counts, ray_error

    def _step(self):
        """Run one iteration of the event loop.

        1. For every finished sample task: send its batch to a randomly
           chosen replay actor. Push fresh weights to the worker once it
           has produced at least ``max_weight_sync_delay`` steps since its
           last sync. Then queue a new sample request for it.
        2. For every finished replay task: re-queue it and hand the replayed
           samples to the learner. If the learner's input queue is full,
           count them in ``num_samples_dropped`` instead.
        3. Forward the priority updates produced by the learner to the
           replay actors.

        Returns:
            tuple: ``(sample_timesteps, train_timesteps)`` processed in this
            iteration.

        Raises:
            RayError: the first error raised by a failed sample task. It is
                raised only after every healthy worker has a new sample
                request queued. Replay processing and priority updates are
                then skipped for this iteration.
        """
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            completed = list(self.sample_tasks.completed())
            # Failed tasks are simply absent from `counts`; the first error
            # is kept so it can be raised once healthy workers are re-queued.
            counts, ray_error = self._fetch_sample_counts(completed)

            for i, (ev, (sample_batch, count)) in enumerate(completed):
                # Skip failed tasks.
                if i not in counts:
                    continue

                sample_timesteps += counts[i]
                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed.
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors.
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.workers.local_worker().get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request.
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

            # Now that all still good tasks have been kicked off again,
            # we can throw the error.
            if ray_error is not None:
                raise ray_error

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                else:
                    with self.timers["get_samples"]:
                        samples = ray_get_and_free(replay)
                    # Defensive copy against plasma crashes, see #2610 #3452
                    self.learner.inqueue.put((ra, samples and samples.copy()))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, prio_dict, count = self.learner.outqueue.get()
                ra.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
コード例 #3
0
class AsyncReplayOptimizer(PolicyOptimizer):
    """Main event loop of the Ape-X optimizer (async sampling with replay).

    This class coordinates the data transfers between the learner thread,
    remote workers (Ape-X actors), and replay buffer actors.

    This has two modes of operation:
        - normal replay: replays independent samples.
        - batch replay: simplified mode where entire sample batches are
            replayed. This supports RNNs, but not prioritization.

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=512,
                 sample_batch_size=50,
                 num_replay_buffer_shards=1,
                 max_weight_sync_delay=400,
                 debug=False,
                 batch_replay=False):
        """Initialize an async replay optimizer.

        Starts the learner thread and creates the replay actors. It queues
        REPLAY_QUEUE_DEPTH replay requests per replay actor. If there are
        remote workers, it also pushes the local weights to them and queues
        their first sample requests.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): forwarded to the replay actors.
            buffer_size (int): forwarded to the replay actors.
            prioritized_replay (bool): NOTE(review): not read anywhere in
                this class; confirm whether the replay actors should get it.
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            train_batch_size (int): forwarded to the replay actors.
            sample_batch_size (int): NOTE(review): not read anywhere in this
                class.
            num_replay_buffer_shards (int): number of replay actors to create
            max_weight_sync_delay (int): push new weights to a rollout worker
                after collecting at least this many timesteps from it
            debug (bool): return extra debug stats
            batch_replay (bool): use BatchReplayActor instead of ReplayActor
        """
        PolicyOptimizer.__init__(self, workers)

        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.workers.local_worker())
        self.learner.start()

        if self.batch_replay:
            replay_cls = BatchReplayActor
        else:
            replay_cls = ReplayActor
        self.replay_actors = create_colocated(replay_cls, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates.
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.workers.remote_workers():
            self._set_workers(self.workers.remote_workers())

    @override(PolicyOptimizer)
    def step(self):
        """Run one event-loop iteration and update timers and step counters.

        The iteration's wall time is always pushed to the "sample" timer.
        It is also pushed to the "train" timer once any iteration has
        reported trained timesteps.
        """
        assert self.learner.is_alive()
        assert len(self.workers.remote_workers()) > 0
        start = time.time()
        sample_timesteps, train_timesteps = self._step()
        time_delta = time.time() - start
        self.timers["sample"].push(time_delta)
        self.timers["sample"].push_units_processed(sample_timesteps)
        if train_timesteps > 0:
            self.learning_started = True
        if self.learning_started:
            self.timers["train"].push(time_delta)
            self.timers["train"].push_units_processed(train_timesteps)
        self.num_steps_sampled += sample_timesteps
        self.num_steps_trained += train_timesteps

    @override(PolicyOptimizer)
    def stop(self):
        """Terminate the replay actors and set the learner's stopped flag."""
        for r in self.replay_actors:
            r.__ray_terminate__.remote()
        self.learner.stopped = True

    @override(PolicyOptimizer)
    def reset(self, remote_workers):
        """Swap in a new set of remote workers.

        Pending sample tasks of workers that are not in ``remote_workers``
        are dropped from the sample task pool.

        NOTE(review): workers that are new in ``remote_workers`` get no
        ``steps_since_update`` entry and no sample requests here.
        """
        self.workers.reset(remote_workers)
        self.sample_tasks.reset_workers(remote_workers)

    @override(PolicyOptimizer)
    def stats(self):
        """Return optimizer stats.

        The result includes the stats of replay shard 0. When ``debug`` is
        set, it also includes a timing breakdown and the counts of pending
        sample and replay tasks.
        """
        replay_stats = ray_get_and_free(self.replay_actors[0].stats.remote(
            self.debug))
        timing = {
            "{}_time_ms".format(k): round(1000 * self.timers[k].mean, 3)
            for k in self.timers
        }
        timing["learner_grad_time_ms"] = round(
            1000 * self.learner.grad_timer.mean, 3)
        timing["learner_dequeue_time_ms"] = round(
            1000 * self.learner.queue_timer.mean, 3)
        stats = {
            "sample_throughput": round(self.timers["sample"].mean_throughput,
                                       3),
            "train_throughput": round(self.timers["train"].mean_throughput, 3),
            "num_weight_syncs": self.num_weight_syncs,
            "num_samples_dropped": self.num_samples_dropped,
            "learner_queue": self.learner.learner_queue_size.stats(),
            "replay_shard_0": replay_stats,
        }
        debug_stats = {
            "timing_breakdown": timing,
            "pending_sample_tasks": self.sample_tasks.count,
            "pending_replay_tasks": self.replay_tasks.count,
        }
        if self.debug:
            stats.update(debug_stats)
        if self.learner.stats:
            stats["learner"] = self.learner.stats
        return dict(PolicyOptimizer.stats(self), **stats)

    # For https://github.com/ray-project/ray/issues/2541 only
    def _set_workers(self, remote_workers):
        """Install ``remote_workers`` and start sampling on them.

        Pushes the local worker's weights to each remote worker and zeroes
        its ``steps_since_update`` counter. It then queues
        SAMPLE_QUEUE_DEPTH sample requests per worker.
        """
        self.workers.reset(remote_workers)
        weights = self.workers.local_worker().get_weights()
        for ev in self.workers.remote_workers():
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

    def _fetch_sample_counts(self, completed):
        """Resolve the step counts of completed sample tasks.

        Tries one batched get first. If that raises a RayError (e.g. a
        worker died), it falls back to fetching each count individually so
        that results from healthy workers are still recovered.

        Arguments:
            completed (list): ``(worker, (sample_batch, count))`` entries
                taken from the sample task pool.

        Returns:
            tuple: ``(counts, ray_error)``. ``counts`` maps an index into
            ``completed`` to its step count for every task that succeeded.
            ``ray_error`` is the first RayError seen, or None.
        """
        # Local imports: only this helper needs these names.
        import logging
        from ray.exceptions import RayError

        try:
            counts = dict(
                enumerate(ray_get_and_free([c[1][1] for c in completed])))
            return counts, None
        except RayError:
            logger = logging.getLogger(__name__)
            counts = {}
            ray_error = None
            for i, c in enumerate(completed):
                try:
                    counts[i] = ray_get_and_free(c[1][1])
                except RayError as e:
                    logger.exception(
                        "Error in completed task: {}".format(e))
                    if ray_error is None:
                        ray_error = e
            return counts, ray_error

    def _step(self):
        """Run one iteration of the event loop.

        1. For every finished sample task: send its batch to a randomly
           chosen replay actor. Push fresh weights to the worker once it
           has produced at least ``max_weight_sync_delay`` steps since its
           last sync. Then queue a new sample request for it.
        2. For every finished replay task: re-queue it and hand the replayed
           samples to the learner. If the learner's input queue is full,
           count them in ``num_samples_dropped`` instead.
        3. Forward the priority updates produced by the learner to the
           replay actors.

        Returns:
            tuple: ``(sample_timesteps, train_timesteps)`` processed in this
            iteration.

        Raises:
            RayError: the first error raised by a failed sample task. It is
                raised only after every healthy worker has a new sample
                request queued.
        """
        sample_timesteps, train_timesteps = 0, 0
        weights = None

        with self.timers["sample_processing"]:
            # completed() hands over ownership of these tasks. If one failing
            # task aborted the whole step, the healthy workers in this list
            # would never be re-queued and would stall for good. So recover
            # what we can first and raise afterwards.
            completed = list(self.sample_tasks.completed())
            counts, ray_error = self._fetch_sample_counts(completed)

            for i, (ev, (sample_batch, count)) in enumerate(completed):
                # Skip failed tasks.
                if i not in counts:
                    continue

                sample_timesteps += counts[i]

                # Send the data to the replay buffer
                random.choice(
                    self.replay_actors).add_batch.remote(sample_batch)

                # Update weights if needed
                self.steps_since_update[ev] += counts[i]
                if self.steps_since_update[ev] >= self.max_weight_sync_delay:
                    # Note that it's important to pull new weights once
                    # updated to avoid excessive correlation between actors
                    if weights is None or self.learner.weights_updated:
                        self.learner.weights_updated = False
                        with self.timers["put_weights"]:
                            weights = ray.put(
                                self.workers.local_worker().get_weights())
                    ev.set_weights.remote(weights)
                    self.num_weight_syncs += 1
                    self.steps_since_update[ev] = 0

                # Kick off another sample request
                self.sample_tasks.add(ev, ev.sample_with_count.remote())

            # Every healthy worker has been re-queued; now surface the error.
            if ray_error is not None:
                raise ray_error

        with self.timers["replay_processing"]:
            for ra, replay in self.replay_tasks.completed():
                self.replay_tasks.add(ra, ra.replay.remote())
                if self.learner.inqueue.full():
                    self.num_samples_dropped += 1
                else:
                    with self.timers["get_samples"]:
                        samples = ray_get_and_free(replay)
                    # Defensive copy against plasma crashes, see #2610 #3452
                    self.learner.inqueue.put((ra, samples and samples.copy()))

        with self.timers["update_priorities"]:
            while not self.learner.outqueue.empty():
                ra, prio_dict, count = self.learner.outqueue.get()
                ra.update_priorities.remote(prio_dict)
                train_timesteps += count

        return sample_timesteps, train_timesteps
コード例 #4
0
class AggregationWorkerBase:
    """Aggregators should extend from this class.

    Collects sample batches from its remote workers, optionally mixes in
    batches replayed from a fixed-size local buffer, and concatenates them
    into train batches of at least ``train_batch_size`` steps."""

    def __init__(self, initial_weights_obj_id, remote_workers,
                 max_sample_requests_in_flight_per_worker, replay_proportion,
                 replay_buffer_num_slots, train_batch_size,
                 rollout_fragment_length):
        """Initialize an aggregator.

        Validates the replay settings, pushes the initial weights to every
        remote worker and queues
        ``max_sample_requests_in_flight_per_worker`` sample requests each.

        Arguments:
            initial_weights_obj_id (ObjectID): initial worker weights
            remote_workers (list): set of remote workers assigned to this agg
            max_sample_requests_in_flight_per_worker (int): max queue size per
                worker
            replay_proportion (float): ratio of replay to sampled outputs
            replay_buffer_num_slots (int): max number of sample batches to
                store in the replay buffer
            train_batch_size (int): size of batches to learn on
            rollout_fragment_length (int): size of batches to sample from
                workers.

        Raises:
            ValueError: if replay is enabled but
                ``replay_buffer_num_slots * rollout_fragment_length`` is not
                larger than ``train_batch_size``.
        """

        self.broadcasted_weights = initial_weights_obj_id
        self.remote_workers = remote_workers
        self.rollout_fragment_length = rollout_fragment_length
        self.train_batch_size = train_batch_size

        if replay_proportion:
            # NOTE(review): _augment_with_replay only replays once the buffer
            # holds more than ceil(train_batch_size / rollout_fragment_length)
            # batches. When that ratio is not an integer, this check accepts
            # some slot counts for which replay can never trigger.
            if (replay_buffer_num_slots * rollout_fragment_length <=
                    train_batch_size):
                # Format the values into the message; passing them as extra
                # ValueError args would render the message as a tuple.
                raise ValueError(
                    "Replay buffer size is too small to produce a train "
                    "batch, please increase replay_buffer_num_slots "
                    "(replay_buffer_num_slots={}, "
                    "rollout_fragment_length={}, "
                    "train_batch_size={}).".format(
                        replay_buffer_num_slots, rollout_fragment_length,
                        train_batch_size))

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        for ev in self.remote_workers:
            ev.set_weights.remote(self.broadcasted_weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        # Samples accumulated towards the next train batch.
        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        # Circular buffer of past sample batches; replay_index is the next
        # slot to overwrite once the buffer is full.
        self.replay_batches = []
        self.replay_index = 0
        self.num_sent_since_broadcast = 0
        self.num_weight_syncs = 0
        self.num_replayed = 0

    @override(Aggregator)
    def iter_train_batches(self, max_yield=999):
        """Iterate over train batches.

        Waits (blocking) for completed sample tasks, passing ``max_yield``
        to the task pool. A train batch is yielded whenever the buffered
        samples (fresh plus replayed) reach ``train_batch_size`` steps.
        Every fresh batch is also stored in the replay buffer (if enabled),
        and its worker gets the current broadcast weights and a new sample
        request.

        Arguments:
            max_yield (int): Max number of batches to iterate over in this
                cycle. Setting this avoids iter_train_batches returning too
                much data at once.

        Yields:
            A train batch whose ``count`` is at least ``train_batch_size``.
        """

        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch(blocking_wait=True,
                                                     max_yield=max_yield)):
            sample_batch.decompress_if_needed()
            self.batch_buffer.append(sample_batch)
            if sum(b.count
                   for b in self.batch_buffer) >= self.train_batch_size:
                if len(self.batch_buffer) == 1:
                    # make a defensive copy to avoid sharing plasma memory
                    # across multiple threads
                    train_batch = self.batch_buffer[0].copy()
                else:
                    train_batch = self.batch_buffer[0].concat_samples(
                        self.batch_buffer)
                yield train_batch
                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                if len(self.replay_batches) < self.replay_buffer_num_slots:
                    self.replay_batches.append(sample_batch)
                else:
                    self.replay_batches[self.replay_index] = sample_batch
                    self.replay_index += 1
                    self.replay_index %= self.replay_buffer_num_slots

            ev.set_weights.remote(self.broadcasted_weights)
            self.num_weight_syncs += 1
            self.num_sent_since_broadcast += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

    @override(Aggregator)
    def stats(self):
        """Return weight-sync and replay counters."""
        return {
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
        }

    @override(Aggregator)
    def reset(self, remote_workers):
        """Drop pending sample tasks of workers not in ``remote_workers``.

        NOTE(review): ``self.remote_workers`` is not updated here.
        """
        self.sample_tasks.reset_workers(remote_workers)

    def _augment_with_replay(self, sample_futures):
        """Resolve sample futures and interleave replayed batches.

        Yields ``(worker, batch)`` for each fresh sample. After each fresh
        sample, if the replay buffer holds more than
        ``ceil(train_batch_size / rollout_fragment_length)`` batches, it
        also yields ``(None, batch)`` for randomly chosen replayed batches.
        The expected number of replays per fresh sample is
        ``replay_proportion``.
        """
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.rollout_fragment_length))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray.get(sample_batch)
            yield ev, sample_batch

            if can_replay():
                # Replays floor(p) batches, plus one more with probability
                # frac(p).
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch
コード例 #5
0
class DRAggregatorBase:
    """Aggregators should extend from this class.

    Collects sample batches from its remote workers, optionally mixes in
    batches replayed from a fixed-size local buffer, and concatenates them
    into train batches of at least ``train_batch_size`` steps. Sampling is
    started explicitly via ``start()``."""

    def __init__(self,
                 initial_weights_obj_id,
                 remote_workers,
                 max_sample_requests_in_flight_per_worker,
                 replay_proportion,
                 replay_buffer_num_slots,
                 train_batch_size,
                 sample_batch_size,
                 sync_sampling=False):
        """Initialize an aggregator.

        Validates the replay settings. No sample requests are issued until
        ``start()`` is called.

        Arguments:
            initial_weights_obj_id (ObjectID): initial worker weights
            remote_workers (list): set of remote workers assigned to this agg
            max_sample_requests_in_flight_per_worker (int): max queue size per
                worker
            replay_proportion (float): ratio of replay to sampled outputs
            replay_buffer_num_slots (int): max number of sample batches to
                store in the replay buffer
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            sync_sampling (bool): if True, stop issuing new sample requests
                once a train batch has been yielded, and drain the task pool
                at the end of ``iter_train_batches``; ``start()`` must then
                be called again to resume sampling.

        Raises:
            ValueError: if replay is enabled but
                ``replay_buffer_num_slots * sample_batch_size`` is not larger
                than ``train_batch_size``.
        """

        self.broadcasted_weights = initial_weights_obj_id
        self.remote_workers = remote_workers
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size

        if replay_proportion:
            # NOTE(review): _augment_with_replay only replays once the buffer
            # holds more than ceil(train_batch_size / sample_batch_size)
            # batches. When that ratio is not an integer, this check accepts
            # some slot counts for which replay can never trigger.
            if replay_buffer_num_slots * sample_batch_size <= train_batch_size:
                # Format the values into the message; passing them as extra
                # ValueError args would render the message as a tuple.
                raise ValueError(
                    "Replay buffer size is too small to produce a train "
                    "batch, please increase replay_buffer_num_slots "
                    "(replay_buffer_num_slots={}, sample_batch_size={}, "
                    "train_batch_size={}).".format(
                        replay_buffer_num_slots, sample_batch_size,
                        train_batch_size))

        # Samples accumulated towards the next train batch.
        self.batch_buffer = []

        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.max_sample_requests_in_flight_per_worker = \
            max_sample_requests_in_flight_per_worker
        self.started = False
        self.sample_tasks = TaskPool()
        # Circular buffer of past sample batches; replay_index is the next
        # slot to overwrite once the buffer is full.
        self.replay_batches = []
        self.replay_index = 0
        self.num_sent_since_broadcast = 0
        self.num_weight_syncs = 0
        self.num_replayed = 0
        self.sample_timesteps = 0

        self.sync_sampling = sync_sampling

    def start(self):
        """Start sampling on all remote workers.

        Pushes the broadcast weights to every remote worker and queues
        ``max_sample_requests_in_flight_per_worker`` sample requests each.
        It must run before ``iter_train_batches``, which asserts
        ``self.started``.
        """
        # Kick off async background sampling
        for ev in self.remote_workers:
            ev.set_weights.remote(self.broadcasted_weights)
            for _ in range(self.max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())
        self.started = True

    @override(Aggregator)
    def iter_train_batches(self, max_yield=999, _recursive_called=False):
        """Iterate over train batches.

        Waits (blocking) for completed sample tasks, passing ``max_yield``
        to the task pool. A train batch is yielded whenever the buffered
        samples (fresh plus replayed) reach ``train_batch_size`` steps.
        Every fresh batch is stored in the replay buffer (if enabled), and
        its worker gets the current broadcast weights and a new sample
        request.

        In sync sampling mode, once this call has yielded a train batch,
        further fresh samples in the same call are still buffered and
        stored for replay, but their workers get no new weights or sample
        requests. The top-level call then re-enters itself until the task
        pool is empty.

        NOTE(review): ``already_sent_out`` starts as False on every recursive
        call, so workers are re-queued during the drain until that nested
        call yields a batch; confirm this is intended.

        Arguments:
            max_yield (int): Max number of batches to iterate over in this
                cycle. Setting this avoids iter_train_batches returning too
                much data at once.
            _recursive_called (bool): internal flag set when this method
                calls itself to drain the task pool; callers should leave it
                at the default.

        Yields:
            A train batch whose ``count`` is at least ``train_batch_size``.
        """
        assert self.started
        # Set once this call has yielded a train batch in sync sampling mode.
        already_sent_out = False
        # ev is the rollout worker (None for replayed batches)
        for ev, sample_batch in self._augment_with_replay(
                self.sample_tasks.completed_prefetch(blocking_wait=True,
                                                     max_yield=max_yield)):
            sample_batch.decompress_if_needed()
            self.batch_buffer.append(sample_batch)
            if sum(b.count
                   for b in self.batch_buffer) >= self.train_batch_size:
                if len(self.batch_buffer) == 1:
                    # make a defensive copy to avoid sharing plasma memory
                    # across multiple threads
                    train_batch = self.batch_buffer[0].copy()
                else:
                    train_batch = self.batch_buffer[0].concat_samples(
                        self.batch_buffer)
                self.sample_timesteps += train_batch.count

                if self.sync_sampling:
                    # In sync mode, stop issuing new sample requests once a
                    # train batch has gone out; the caller restarts sampling
                    # via start().
                    already_sent_out = True
                yield train_batch

                self.batch_buffer = []

            # If the batch was replayed, skip the update below.
            if ev is None:
                continue

            # Put in replay buffer if enabled
            if self.replay_buffer_num_slots > 0:
                if len(self.replay_batches) < self.replay_buffer_num_slots:
                    self.replay_batches.append(sample_batch)
                else:
                    self.replay_batches[self.replay_index] = sample_batch
                    self.replay_index += 1
                    self.replay_index %= self.replay_buffer_num_slots

            if already_sent_out and self.sync_sampling:
                continue

            ev.set_weights.remote(self.broadcasted_weights)
            self.num_weight_syncs += 1
            self.num_sent_since_broadcast += 1

            # Kick off another sample request
            self.sample_tasks.add(ev, ev.sample.remote())

        if self.sync_sampling and (not _recursive_called):
            while self.sample_tasks.count > 0:
                # Re-enter to consume the remaining in-flight tasks so the
                # task pool is drained before returning.
                for train_batch in self.iter_train_batches(
                        max_yield, _recursive_called=True):
                    yield train_batch

    @override(Aggregator)
    def stats(self):
        """Return weight-sync, replay and sampled-timestep counters."""
        return {
            "num_weight_syncs": self.num_weight_syncs,
            "num_steps_replayed": self.num_replayed,
            "sample_timesteps": self.sample_timesteps
        }

    @override(Aggregator)
    def reset(self, remote_workers):
        """Drop pending sample tasks of workers not in ``remote_workers``.

        NOTE(review): ``self.remote_workers`` (used by ``start()``) is not
        updated here.
        """
        self.sample_tasks.reset_workers(remote_workers)

    def _augment_with_replay(self, sample_futures):
        """Resolve sample futures and interleave replayed batches.

        Yields ``(worker, batch)`` for each fresh sample. After each fresh
        sample, if the replay buffer holds more than
        ``ceil(train_batch_size / sample_batch_size)`` batches, it also
        yields ``(None, batch)`` for randomly chosen replayed batches. The
        expected number of replays per fresh sample is
        ``replay_proportion``.
        """
        def can_replay():
            num_needed = int(
                np.ceil(self.train_batch_size / self.sample_batch_size))
            return len(self.replay_batches) > num_needed

        for ev, sample_batch in sample_futures:
            sample_batch = ray_get_and_free(sample_batch)
            yield ev, sample_batch

            if can_replay():
                # Replays floor(p) batches, plus one more with probability
                # frac(p).
                f = self.replay_proportion
                while random.random() < f:
                    f -= 1
                    replay_batch = random.choice(self.replay_batches)
                    self.num_replayed += replay_batch.count
                    yield None, replay_batch