Example #1
class AggregatorWorker:
    """A worker for doing tree aggregation of collected episodes"""
    def __init__(self, config: AlgorithmConfigDict):
        self.config = config
        self._mixin_buffer = MixInMultiAgentReplayBuffer(
            capacity=(self.config["replay_buffer_num_slots"]
                      if self.config["replay_buffer_num_slots"] > 0 else 1),
            replay_ratio=self.config["replay_ratio"],
            replay_mode=ReplayMode.LOCKSTEP,
        )

    def process_episodes(self, batch: SampleBatchType) -> SampleBatchType:
        batch = batch.decompress_if_needed()
        self._mixin_buffer.add_batch(batch)
        processed_batches = self._mixin_buffer.replay(_ALL_POLICIES)
        return processed_batches

    def apply(
        self,
        func: Callable[["AggregatorWorker", Optional[Any], Optional[Any]], T],
        *_args,
        **kwargs,
    ) -> T:
        """Calls the given function with this AggregatorWorker instance."""
        return func(self, *_args, **kwargs)

    def get_host(self) -> str:
        return platform.node()
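
A minimal usage sketch for the class above, assuming the snippet's own imports (MixInMultiAgentReplayBuffer, ReplayMode, platform) are in scope and using a hypothetical two-key config dict (a real AlgorithmConfigDict contains many more entries, but the constructor only reads these two):

# Hypothetical minimal config; only the keys read by AggregatorWorker.__init__.
config = {"replay_buffer_num_slots": 100, "replay_ratio": 0.5}
worker = AggregatorWorker(config)

# `apply()` simply calls the given function with the worker as its first
# argument; this is how the Algorithm invokes work on the remote aggregators.
print(worker.apply(lambda w: w.get_host()))  # -> name of the current node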
Example #2
    def __init__(self, config: TrainerConfigDict):
        self.config = config
        self._mixin_buffer = MixInMultiAgentReplayBuffer(
            capacity=(self.config["replay_buffer_num_slots"]
                      if self.config["replay_buffer_num_slots"] > 0 else 1),
            replay_ratio=self.config["replay_ratio"],
        )
Example #3
    def __init__(self, config: AlgorithmConfigDict):
        self.config = config
        self._mixin_buffer = MixInMultiAgentReplayBuffer(
            capacity=(self.config["replay_buffer_num_slots"]
                      if self.config["replay_buffer_num_slots"] > 0 else 1),
            replay_ratio=self.config["replay_ratio"],
            replay_mode=ReplayMode.LOCKSTEP,
        )
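
The `capacity=(... if ... > 0 else 1)` expression in the snippets above is a guard: when `replay_buffer_num_slots` is 0 (replay effectively disabled), the buffer is still constructed with a dummy capacity of 1. A hypothetical helper making the guard explicit:

def mixin_buffer_capacity(replay_buffer_num_slots: int) -> int:
    # Fall back to a dummy capacity of 1 when no replay slots are configured.
    return replay_buffer_num_slots if replay_buffer_num_slots > 0 else 1

assert mixin_buffer_capacity(0) == 1
assert mixin_buffer_capacity(500) == 500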
Example #4
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)
        self.remote_sampling_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

        if self.config["_disable_execution_plan_api"]:
            # Create extra aggregation workers and assign each rollout worker to
            # one of them.
            self.batches_to_place_on_learner = []
            self.batch_being_built = []
            if self.config["num_aggregation_workers"] > 0:
                # This spawns `num_aggregation_workers` actors that aggregate
                # experiences coming from RolloutWorkers in parallel. We force
                # colocation on the same node (localhost) to maximize data bandwidth
                # between them and the learner.
                localhost = platform.node()
                assert localhost != "", (
                    "ERROR: Cannot determine local node name! "
                    "`platform.node()` returned empty string.")
                all_co_located = create_colocated_actors(
                    actor_specs=[
                        # (class, args, kwargs={}, count=1)
                        (
                            AggregatorWorker,
                            [
                                self.config,
                            ],
                            {},
                            self.config["num_aggregation_workers"],
                        )
                    ],
                    node=localhost,
                )
                self.aggregator_workers = [
                    actor for actor_groups in all_co_located
                    for actor in actor_groups
                ]
                self.remote_aggregator_requests_in_flight: DefaultDict[
                    ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

            else:
                # Create our local mixin buffer if the number of aggregation workers is 0.
                self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
                    capacity=(self.config["replay_buffer_num_slots"]
                              if self.config["replay_buffer_num_slots"] > 0
                              else 1),
                    replay_ratio=self.config["replay_ratio"],
                )

            # Create and start the learner thread.
            self._learner_thread = make_learner_thread(
                self.workers.local_worker(), self.config)
            self._learner_thread.start()
            self.workers_that_need_updates = set()
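
`create_colocated_actors` is an internal RLlib utility that turns `AggregatorWorker` into Ray actors and pins them to the given node. As a rough sketch of what the spawned actors then do (ignoring the colocation constraint), one might wrap the class manually; the config below is a hypothetical stand-in for the real one:

import ray

ray.init(ignore_reinit_error=True)

# Hypothetical minimal config; a real config contains many more keys.
config = {"replay_buffer_num_slots": 100, "replay_ratio": 0.5}

# Wrap the class as a Ray actor and spawn two aggregators (1 CPU each).
RemoteAggregator = ray.remote(num_cpus=1)(AggregatorWorker)
aggregators = [RemoteAggregator.remote(config) for _ in range(2)]

# Each rollout batch would then be routed to one of the aggregators, e.g.:
# ref = aggregators[0].process_episodes.remote(sample_batch)
# processed_batch = ray.get(ref)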
Example #5
class Impala(Algorithm):
    """Importance weighted actor/learner architecture (IMPALA) Algorithm

    == Overview of data flow in IMPALA ==
    1. Policy evaluation in parallel across `num_workers` actors produces
       batches of size `rollout_fragment_length * num_envs_per_worker`.
    2. If enabled, the replay buffer stores and produces batches of size
       `rollout_fragment_length * num_envs_per_worker`.
    3. If enabled, the minibatch ring buffer stores and replays batches of
       size `train_batch_size` up to `num_sgd_iter` times per batch.
    4. The learner thread executes data parallel SGD across `num_gpus` GPUs
       on batches of size `train_batch_size`.
    """
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return ImpalaConfig().to_dict()

    @override(Algorithm)
    def get_default_policy_class(
            self,
            config: PartialAlgorithmConfigDict) -> Optional[Type[Policy]]:
        if config["framework"] == "torch":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_torch_policy import (
                    ImpalaTorchPolicy, )

                return ImpalaTorchPolicy
            else:
                from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy

                return A3CTorchPolicy
        elif config["framework"] == "tf":
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF1Policy

                return ImpalaTF1Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy
        else:
            if config["vtrace"]:
                from ray.rllib.algorithms.impala.impala_tf_policy import ImpalaTF2Policy

                return ImpalaTF2Policy
            else:
                from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy

                return A3CTFPolicy

    @override(Algorithm)
    def validate_config(self, config):
        # Call the super class' validation method first.
        super().validate_config(config)

        # Check the IMPALA specific config.

        if config["num_data_loader_buffers"] != DEPRECATED_VALUE:
            deprecation_warning("num_data_loader_buffers",
                                "num_multi_gpu_tower_stacks",
                                error=False)
            config["num_multi_gpu_tower_stacks"] = config[
                "num_data_loader_buffers"]

        if config["entropy_coeff"] < 0.0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")

        # Check whether worker to aggregation-worker ratio makes sense.
        if config["num_aggregation_workers"] > config["num_workers"]:
            raise ValueError(
                "`num_aggregation_workers` must be smaller than or equal "
                "`num_workers`! Aggregation makes no sense otherwise.")
        elif config["num_aggregation_workers"] > config["num_workers"] / 2:
            logger.warning(
                "`num_aggregation_workers` should be significantly smaller "
                "than `num_workers`! Try setting it to 0.5*`num_workers` or "
                "less.")

        # If two separate optimizers/loss terms used for tf, must also set
        # `_tf_policy_handles_more_than_one_loss` to True.
        if config["_separate_vf_optimizer"] is True:
            # Only supported for tf so far.
            # TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
            #  models to return separate sets of weights in order to create
            #  the different torch optimizers).
            if config["framework"] not in ["tf", "tf2", "tfe"]:
                raise ValueError(
                    "`_separate_vf_optimizer` only supported to tf so far!")
            if config["_tf_policy_handles_more_than_one_loss"] is False:
                logger.warning(
                    "`_tf_policy_handles_more_than_one_loss` must be set to "
                    "True, for TFPolicy to support more than one loss "
                    "term/optimizer! Auto-setting it to True.")
                config["_tf_policy_handles_more_than_one_loss"] = True

    @override(Algorithm)
    def setup(self, config: PartialAlgorithmConfigDict):
        super().setup(config)

        if self.config["_disable_execution_plan_api"]:
            # Create extra aggregation workers and assign each rollout worker to
            # one of them.
            self.batches_to_place_on_learner = []
            self.batch_being_built = []
            if self.config["num_aggregation_workers"] > 0:
                # This spawns `num_aggregation_workers` actors that aggregate
                # experiences coming from RolloutWorkers in parallel. We force
                # colocation on the same node (localhost) to maximize data bandwidth
                # between them and the learner.
                localhost = platform.node()
                assert localhost != "", (
                    "ERROR: Cannot determine local node name! "
                    "`platform.node()` returned empty string.")
                all_co_located = create_colocated_actors(
                    actor_specs=[
                        # (class, args, kwargs={}, count=1)
                        (
                            AggregatorWorker,
                            [
                                self.config,
                            ],
                            {},
                            self.config["num_aggregation_workers"],
                        )
                    ],
                    node=localhost,
                )
                self._aggregator_workers = [
                    actor for actor_groups in all_co_located
                    for actor in actor_groups
                ]
                self._aggregator_actor_manager = AsyncRequestsManager(
                    self._aggregator_workers,
                    max_remote_requests_in_flight_per_worker=self.config[
                        "max_requests_in_flight_per_aggregator_worker"],
                    ray_wait_timeout_s=self.config[
                        "timeout_s_aggregator_manager"],
                )

            else:
                # Create our local mixin buffer if the number of aggregation workers is 0.
                self.local_mixin_buffer = MixInMultiAgentReplayBuffer(
                    capacity=(self.config["replay_buffer_num_slots"]
                              if self.config["replay_buffer_num_slots"] > 0
                              else 1),
                    replay_ratio=self.config["replay_ratio"],
                    replay_mode=ReplayMode.LOCKSTEP,
                )

            self._sampling_actor_manager = AsyncRequestsManager(
                self.workers.remote_workers(),
                max_remote_requests_in_flight_per_worker=self.config[
                    "max_requests_in_flight_per_sampler_worker"],
                return_object_refs=True,
                ray_wait_timeout_s=self.config["timeout_s_sampler_manager"],
            )

            # Create and start the learner thread.
            self._learner_thread = make_learner_thread(
                self.workers.local_worker(), self.config)
            self._learner_thread.start()
            self.workers_that_need_updates = set()

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        unprocessed_sample_batches = self.get_samples_from_workers()

        self.workers_that_need_updates |= unprocessed_sample_batches.keys()

        if self.config["num_aggregation_workers"] > 0:
            batch = self.process_experiences_tree_aggregation(
                unprocessed_sample_batches)
        else:
            batch = self.process_experiences_directly(
                unprocessed_sample_batches)

        self.concatenate_batches_and_pre_queue(batch)
        self.place_processed_samples_on_learner_queue()
        train_results = self.process_trained_results()

        self.update_workers_if_necessary()

        return train_results

    @staticmethod
    @override(Algorithm)
    def execution_plan(workers, config, **kwargs):
        assert (
            len(kwargs) == 0
        ), "IMPALA execution_plan does NOT take any additional parameters"

        if config["num_aggregation_workers"] > 0:
            train_batches = gather_experiences_tree_aggregation(
                workers, config)
        else:
            train_batches = gather_experiences_directly(workers, config)

        # Start the learner thread.
        learner_thread = make_learner_thread(workers.local_worker(), config)
        learner_thread.start()

        # This sub-flow sends experiences to the learner.
        enqueue_op = train_batches.for_each(Enqueue(learner_thread.inqueue))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            enqueue_op = enqueue_op.zip_with_source_actor().for_each(
                BroadcastUpdateLearnerWeights(
                    learner_thread,
                    workers,
                    broadcast_interval=config["broadcast_interval"],
                ))

        def record_steps_trained(item):
            count, fetches, _ = item
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
            metrics.counters[STEPS_TRAINED_COUNTER] += count
            return item

        # This sub-flow updates the steps trained counter based on learner
        # output.
        dequeue_op = Dequeue(
            learner_thread.outqueue,
            check=learner_thread.is_alive).for_each(record_steps_trained)

        merged_op = Concurrently([enqueue_op, dequeue_op],
                                 mode="async",
                                 output_indexes=[1])

        # Callback for APPO to use to update KL, target network periodically.
        # The input to the callback is the learner fetches dict.
        if config["after_train_step"]:
            merged_op = merged_op.for_each(lambda t: t[1]).for_each(
                config["after_train_step"](workers, config))

        return StandardMetricsReporting(merged_op, workers, config).for_each(
            learner_thread.add_learner_metrics)

    @classmethod
    @override(Algorithm)
    def default_resource_request(cls, config):
        cf = dict(cls.get_default_config(), **config)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[{
                # Driver + Aggregation Workers:
                # Force to be on same node to maximize data bandwidth
                # between aggregation workers and the learner (driver).
                # Aggregation workers tree-aggregate experiences collected
                # from RolloutWorkers (n rollout workers map to m
                # aggregation workers, where m < n) and always use 1 CPU
                # each.
                "CPU":
                cf["num_cpus_for_driver"] + cf["num_aggregation_workers"],
                "GPU":
                0 if cf["_fake_gpus"] else cf["num_gpus"],
            }] + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                    **cf["custom_resources_per_worker"],
                } for _ in range(cf["num_workers"])
            ] + ([
                {
                    # Evaluation (remote) workers.
                    # Note: The local eval worker is located on the driver
                    # CPU or is not even created if there are > 0 remote
                    # eval workers.
                    "CPU":
                    eval_config.get("num_cpus_per_worker",
                                    cf["num_cpus_per_worker"]),
                    "GPU":
                    eval_config.get("num_gpus_per_worker",
                                    cf["num_gpus_per_worker"]),
                    **eval_config.get(
                        "custom_resources_per_worker",
                        cf["custom_resources_per_worker"],
                    ),
                } for _ in range(cf["evaluation_num_workers"])
            ] if cf["evaluation_interval"] else []),
            strategy=config.get("placement_strategy", "PACK"),
        )

    def concatenate_batches_and_pre_queue(self, batches: List[SampleBatch]):
        """Concatenate batches that are being returned from rollout workers

        Args:
            batches: batches of experiences from rollout workers

        """
        def aggregate_into_larger_batch():
            if (sum(b.count for b in self.batch_being_built) >=
                    self.config["train_batch_size"]):
                batch_to_add = SampleBatch.concat_samples(
                    self.batch_being_built)
                self.batches_to_place_on_learner.append(batch_to_add)
                self.batch_being_built = []

        for batch in batches:
            self.batch_being_built.append(batch)
            aggregate_into_larger_batch()

    def get_samples_from_workers(self) -> Dict[ActorHandle, List[SampleBatch]]:
        # Perform asynchronous sampling on all (remote) rollout workers.
        if self.workers.remote_workers():
            self._sampling_actor_manager.call_on_all_available(
                lambda worker: worker.sample())
            sample_batches: Dict[
                ActorHandle,
                List[ObjectRef]] = self._sampling_actor_manager.get_ready()
        else:
            # Only sampling on the local worker.
            sample_batches = {
                self.workers.local_worker():
                [self.workers.local_worker().sample()]
            }
        return sample_batches

    def place_processed_samples_on_learner_queue(self) -> None:
        self._counters["num_samples_added_to_queue"] = 0

        while self.batches_to_place_on_learner:
            batch = self.batches_to_place_on_learner[0]
            try:
                self._learner_thread.inqueue.put(batch, block=False)
                self.batches_to_place_on_learner.pop(0)
                self._counters[NUM_ENV_STEPS_SAMPLED] += batch.count
                self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
                self._counters["num_samples_added_to_queue"] = batch.count
            except queue.Full:
                self._counters["num_times_learner_queue_full"] += 1

    def process_trained_results(self) -> ResultDict:
        # Get learner outputs/stats from output queue.
        final_learner_info = {}
        learner_infos = []
        num_env_steps_trained = 0
        num_agent_steps_trained = 0

        for _ in range(self._learner_thread.outqueue.qsize()):
            if self._learner_thread.is_alive():
                (
                    env_steps,
                    agent_steps,
                    learner_results,
                ) = self._learner_thread.outqueue.get(timeout=0.001)
                num_env_steps_trained += env_steps
                num_agent_steps_trained += agent_steps
                if learner_results:
                    learner_infos.append(learner_results)
            else:
                raise RuntimeError("The learner thread died in while training")
        if not learner_infos:
            final_learner_info = copy.deepcopy(
                self._learner_thread.learner_info)
        else:
            builder = LearnerInfoBuilder()
            for info in learner_infos:
                builder.add_learn_on_batch_results_multi_agent(info)
            final_learner_info = builder.finalize()

        # Update the steps trained counters.
        self._counters[
            STEPS_TRAINED_THIS_ITER_COUNTER] = num_agent_steps_trained
        self._counters[NUM_ENV_STEPS_TRAINED] += num_env_steps_trained
        self._counters[NUM_AGENT_STEPS_TRAINED] += num_agent_steps_trained

        return final_learner_info

    def process_experiences_directly(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        processed_batches = []
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        if not batches:
            return processed_batches
        if batches and isinstance(batches[0], ray.ObjectRef):
            batches = ray.get(batches)
        for batch in batches:
            batch = batch.decompress_if_needed()
            self.local_mixin_buffer.add_batch(batch)
            batch = self.local_mixin_buffer.replay(_ALL_POLICIES)
            if batch:
                processed_batches.append(batch)
        return processed_batches

    def process_experiences_tree_aggregation(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> List[SampleBatchType]:
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        ready_processed_batches = []
        for batch in batches:
            self._aggregator_actor_manager.call(
                lambda actor, b: actor.process_episodes(b),
                fn_kwargs={"b": batch})

        waiting_processed_sample_batches: Dict[
            ActorHandle,
            List[ObjectRef]] = self._aggregator_actor_manager.get_ready()
        for ready_sub_batches in waiting_processed_sample_batches.values():
            ready_processed_batches.extend(ready_sub_batches)

        return ready_processed_batches

    def update_workers_if_necessary(self) -> None:
        # Only need to update workers if there are remote workers.
        global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_TRAINED]}
        self._counters["steps_since_broadcast"] += 1
        if (self.workers.remote_workers()
                and self._counters["steps_since_broadcast"] >=
                self.config["broadcast_interval"]
                and self.workers_that_need_updates):
            weights = ray.put(self.workers.local_worker().get_weights())
            self._counters["steps_since_broadcast"] = 0
            self._learner_thread.weights_updated = False
            self._counters["num_weight_broadcasts"] += 1

            for worker in self.workers_that_need_updates:
                worker.set_weights.remote(weights, global_vars)
            self.workers_that_need_updates = set()

        # Update global vars of the local worker.
        self.workers.local_worker().set_global_vars(global_vars)

    @override(Algorithm)
    def on_worker_failures(self, removed_workers: List[ActorHandle],
                           new_workers: List[ActorHandle]):
        """Handle the failures of remote sampling workers

        Args:
            removed_workers: removed worker ids.
            new_workers: ids of newly created workers.
        """
        if self.config["_disable_execution_plan_api"]:
            self._sampling_actor_manager.remove_workers(
                removed_workers, remove_in_flight_requests=True)
            self._sampling_actor_manager.add_workers(new_workers)

    @override(Algorithm)
    def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
        result = super()._compile_iteration_results(
            step_ctx=step_ctx, iteration_results=iteration_results)
        result = self._learner_thread.add_learner_metrics(
            result, overwrite_learner_info=False)
        return result
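
A back-of-the-envelope sketch of the batch sizes described in the data-flow overview at the top of Example #5 (illustrative numbers, not RLlib defaults):

# Illustrative settings (not defaults).
rollout_fragment_length = 50
num_envs_per_worker = 5
train_batch_size = 500

# 1. + 2. Each rollout worker (and, if enabled, the replay buffer) produces
# batches of this many env steps:
per_worker_batch = rollout_fragment_length * num_envs_per_worker  # 250

# concatenate_batches_and_pre_queue() keeps stacking such batches until at
# least `train_batch_size` env steps have accumulated:
batches_per_train_batch = -(-train_batch_size // per_worker_batch)  # ceil -> 2

# 3. + 4. The minibatch ring buffer replays each ~`train_batch_size`-sized
# batch up to `num_sgd_iter` times, and the learner thread runs data-parallel
# SGD on it across `num_gpus` GPUs.
print(per_worker_batch, batches_per_train_batch)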
Example #6
    def test_mixin_sampling(self):
        # 50% replay ratio.
        buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity,
                                             replay_ratio=0.5)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect at least 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) >= 1)
        # If we insert and replay n times, expect returned batches of roughly
        # length 2 (replay_ratio=0.5 -> 50% replayed samples -> 1 new and
        # 1 old sample on average per returned batch).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 2.0)

        # 33% replay ratio.
        buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity,
                                             replay_ratio=0.333)
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect at least 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) >= 1)
        # If we insert twice and replay n times, expect returned batches of
        # roughly length 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new
        # and 1 old sample on average per returned batch).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 3.0, delta=0.1)

        # If we insert once and replay n times, expect returned batches of
        # roughly length 1.5 (replay_ratio=0.33 -> 33% replayed samples ->
        # 1 new and 0.5 old samples on average per returned batch).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 1.5, delta=0.1)

        # 90% replay ratio.
        buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity,
                                             replay_ratio=0.9)
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect at least 2 samples to be returned (new one plus at least one
        # replay sample).
        sample = buffer.replay()
        self.assertTrue(len(sample) >= 2)
        # If we insert and replay n times, expect returned batches of roughly
        # length 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and
        # 9 old samples on average per returned batch).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 10.0, delta=0.1)

        # 0% replay ratio -> Only new samples.
        buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity,
                                             replay_ratio=0.0)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect exactly 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) == 1)
        # Expect exactly 0 samples to be returned (nothing new to be returned;
        # no replay allowed (replay_ratio=0.0)).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # If we insert and replay n times, expect returned batches of roughly
        # length 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and
        # 0 old samples on average per returned batch).
        results = []
        for _ in range(100):
            buffer.add_batch(batch)
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 1.0)

        # 100% replay ratio -> Only replayed (old) samples.
        buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity,
                                             replay_ratio=1.0)
        # Expect exactly 0 samples to be returned (buffer empty).
        sample = buffer.replay()
        self.assertTrue(sample is None)
        # Add a new batch.
        batch = self._generate_data()
        buffer.add_batch(batch)
        # Expect exactly 1 sample to be returned (the new batch).
        sample = buffer.replay()
        self.assertTrue(len(sample) == 1)
        # Another replay -> Expect exactly 1 sample to be returned.
        sample = buffer.replay()
        self.assertTrue(len(sample) == 1)
        # If we replay n times, expect returned batches of roughly length 1
        # (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old sample
        # on average per returned batch).
        results = []
        for _ in range(100):
            sample = buffer.replay()
            results.append(len(sample))
        self.assertAlmostEqual(np.mean(results), 1.0)
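
The expected batch lengths asserted above follow a simple relation, sketched below: if a fraction `replay_ratio` of each returned sample is replayed (old) data and `n_new` batches were inserted before the `replay()` call, the expected returned length is `n_new / (1 - replay_ratio)` (for ratios below 1.0; at exactly 1.0 only old data is returned, one batch per call).

def expected_replay_len(n_new: int, replay_ratio: float) -> float:
    # Expected number of batches returned per replay() call, assuming
    # `n_new` freshly inserted batches and replay_ratio < 1.0.
    return n_new / (1.0 - replay_ratio)

# The cases exercised by the test above:
for n_new, ratio, expected in [(1, 0.5, 2.0), (2, 1 / 3, 3.0), (1, 1 / 3, 1.5),
                               (1, 0.9, 10.0), (1, 0.0, 1.0)]:
    assert abs(expected_replay_len(n_new, ratio) - expected) < 1e-6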