    def __init__(self,
                 capacity: int = 10000,
                 storage_unit: str = "timesteps",
                 num_shards: int = 1,
                 learning_starts: int = 1000,
                 replay_batch_size: int = 1,
                 prioritized_replay_alpha: float = 0.6,
                 prioritized_replay_beta: float = 0.4,
                 prioritized_replay_eps: float = 1e-6,
                 replay_mode: str = "independent",
                 replay_sequence_length: int = 1,
                 replay_burn_in: int = 0,
                 replay_zero_init_states: bool = True,
                 **kwargs):
        """Initializes a MultiAgentReplayBuffer instance.

        Args:
            num_shards: The number of buffer shards that exist in total
                (including this one).
            learning_starts: Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            capacity: The capacity of the buffer. Note that when
                `replay_sequence_length` > 1, this is the number of sequences
                (not single timesteps) stored.
            replay_batch_size: The batch size to be sampled (in timesteps).
                Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            prioritized_replay_alpha: Alpha parameter for a prioritized
                replay buffer. Use 0.0 for no prioritization.
            prioritized_replay_beta: Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps: Epsilon parameter for a prioritized
                replay buffer.
            replay_mode: One of "independent" or "lockstep". Determines
                whether, in the multi-agent case, sampling is done across
                all agents/policies equally.
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in: The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from all zeros for each RNN rollout.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
            **kwargs: Forward compatibility kwargs.
        """
        shard_capacity = capacity // num_shards
        ReplayBuffer.__init__(self, shard_capacity, storage_unit)

        self.replay_starts = learning_starts // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if replay_sequence_length > 1:
            self.replay_batch_size = int(
                max(1, replay_batch_size // replay_sequence_length))
            logger.info(
                "Since replay_sequence_length={} and replay_batch_size={}, "
                "we will replay {} sequences at a time.".format(
                    replay_sequence_length, replay_batch_size,
                    self.replay_batch_size))

        if replay_mode not in ["lockstep", "independent"]:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        def new_buffer():
            if prioritized_replay_alpha == 0.0:
                return ReplayBuffer(self.capacity)
            else:
                return PrioritizedReplayBuffer(self.capacity,
                                               alpha=prioritized_replay_alpha)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self._num_added = 0

        # Make externally accessible for testing.
        global _local_replay_buffer
        _local_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None
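
# Note on the `new_buffer` closure above: the per-policy buffers are created
# lazily through `collections.defaultdict(new_buffer)`. The first time an
# unknown policy ID is indexed, `new_buffer()` runs and attaches either a
# plain ReplayBuffer (alpha == 0.0) or a PrioritizedReplayBuffer to that key.
# Below is a minimal, self-contained sketch of the same pattern, using
# hypothetical stand-in classes instead of the real RLlib buffers.

import collections


class _PlainBuffer:
    """Stand-in for ReplayBuffer, for illustration only."""

    def __init__(self, capacity):
        self.capacity = capacity


class _PrioritizedBuffer:
    """Stand-in for PrioritizedReplayBuffer, for illustration only."""

    def __init__(self, capacity, alpha):
        self.capacity = capacity
        self.alpha = alpha


def make_per_policy_buffers(capacity, prioritized_replay_alpha=0.6):
    def new_buffer():
        # An alpha of 0.0 disables prioritization entirely.
        if prioritized_replay_alpha == 0.0:
            return _PlainBuffer(capacity)
        return _PrioritizedBuffer(capacity, alpha=prioritized_replay_alpha)

    return collections.defaultdict(new_buffer)


buffers = make_per_policy_buffers(capacity=10000)
# Indexing an unseen policy ID instantiates its buffer on the fly.
assert isinstance(buffers["default_policy"], _PrioritizedBuffer)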
Example #4
    def test_multi_agent_batches(self):
        """Tests buffer with storage of MultiAgentBatches."""
        self.batch_id = 0

        def _add_multi_agent_batch_to_buffer(buffer,
                                             num_policies,
                                             num_batches=5,
                                             seq_lens=False,
                                             **kwargs):
            def _generate_data(policy_id):
                batch = SampleBatch({
                    SampleBatch.T: [0, 1],
                    SampleBatch.ACTIONS:
                    2 * [np.random.choice([0, 1])],
                    SampleBatch.REWARDS:
                    2 * [np.random.rand()],
                    SampleBatch.OBS:
                    2 * [np.random.random((4, ))],
                    SampleBatch.NEXT_OBS:
                    2 * [np.random.random((4, ))],
                    SampleBatch.DONES: [False, True],
                    SampleBatch.EPS_ID:
                    2 * [self.batch_id],
                    SampleBatch.AGENT_INDEX:
                    2 * [0],
                    SampleBatch.SEQ_LENS: [2],
                    "batch_id":
                    2 * [self.batch_id],
                    "policy_id":
                    2 * [policy_id],
                })
                if not seq_lens:
                    del batch[SampleBatch.SEQ_LENS]
                self.batch_id += 1
                return batch

            for i in range(num_batches):
                # Generate a few policy batches.
                policy_batches = {
                    idx: _generate_data(idx)
                    for idx, _ in enumerate(range(num_policies))
                }
                batch = MultiAgentBatch(policy_batches, num_batches * 2)
                buffer.add(batch, **kwargs)

        buffer = ReplayBuffer(capacity=100, storage_unit="timesteps")

        # Test add/sample
        _add_multi_agent_batch_to_buffer(buffer, num_policies=2, num_batches=2)

        # After adding these two batches, the buffer should not be full yet
        assert len(buffer) == 2
        assert buffer._num_timesteps_added == 8
        assert buffer._num_timesteps_added_wrap == 8
        assert buffer._next_idx == 2
        assert buffer._eviction_started is False

        # Sampling 3 items should account for 12 sampled timesteps
        # (3 items x 4 env steps each)
        buffer.sample(3)
        assert buffer._num_timesteps_sampled == 12

        _add_multi_agent_batch_to_buffer(buffer,
                                         batch_size=100,
                                         num_policies=3,
                                         num_batches=3)

        # After adding three more batches, the buffer should hold five items
        assert len(buffer) == 5
        assert buffer._num_timesteps_added == 26
        assert buffer._num_timesteps_added_wrap == 26
        assert buffer._next_idx == 5
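
# The timestep counts asserted above follow directly from how the helper
# builds each batch: `MultiAgentBatch(policy_batches, num_batches * 2)`
# records `num_batches * 2` env steps per stored item. A quick sanity check
# of that arithmetic:
#
#   first add:  2 batches x (2 * 2) env steps = 8
#   second add: 3 batches x (3 * 2) env steps = 18, for a total of 26
assert 2 * (2 * 2) == 8 and 2 * (2 * 2) + 3 * (3 * 2) == 26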
Example #5
    def __init__(self,
                 capacity: int = 10000,
                 storage_unit: str = "timesteps",
                 num_shards: int = 1,
                 learning_starts: int = 1000,
                 replay_mode: str = "independent",
                 replay_sequence_override: bool = True,
                 replay_sequence_length: int = 1,
                 replay_burn_in: int = 0,
                 replay_zero_init_states: bool = True,
                 underlying_buffer_config: dict = None,
                 **kwargs):
        """Initializes a MultiAgentReplayBuffer instance.

        Args:
            capacity: The capacity of the buffer, measured in `storage_unit`.
            storage_unit: Either 'timesteps', 'sequences' or
                'episodes'. Specifies how experiences are stored. If they
                are stored in episodes, replay_sequence_length is ignored.
            num_shards: The number of buffer shards that exist in total
                (including this one).
            learning_starts: Number of timesteps after which a call to
                `sample()` will yield samples (before that, `sample()` will
                return None).
            replay_mode: One of "independent" or "lockstep". Determines
                whether batches are sampled independently or in lockstep
                (equal amounts across policies).
            replay_sequence_override: If True, ignore sequences found in incoming
                batches, slicing them into sequences as specified by
                `replay_sequence_length` and `replay_sequence_burn_in`. This only has
                an effect if storage_unit is `sequences`.
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer. This
                only has an effect if storage_unit is 'timesteps'.
            replay_burn_in: This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 each RNN rollout. This only has an effect
                if storage_unit is `sequences`.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
            underlying_buffer_config: A config that contains all necessary
                constructor arguments and arguments for methods to call on
                the underlying buffers.
            ``**kwargs``: Forward compatibility kwargs.
        """
        shard_capacity = capacity // num_shards
        ReplayBuffer.__init__(self, capacity, storage_unit)

        # If the user provides an underlying buffer config, we use it to
        # instantiate and interact with the underlying buffers.
        self.underlying_buffer_config = underlying_buffer_config
        if self.underlying_buffer_config is not None:
            self.underlying_buffer_call_args = self.underlying_buffer_config
        else:
            self.underlying_buffer_call_args = {}
        self.replay_sequence_override = replay_sequence_override
        self.replay_starts = learning_starts // num_shards
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if (replay_sequence_length > 1
                and self.storage_unit is not StorageUnit.SEQUENCES):
            logger.warning(
                "MultiAgentReplayBuffer configured with "
                "`replay_sequence_length={}`, but `storage_unit={}`. "
                "replay_sequence_length will be ignored and set to 1.".format(
                    replay_sequence_length, storage_unit))
            self.replay_sequence_length = 1

        if replay_sequence_length == 1 and self.storage_unit is StorageUnit.SEQUENCES:
            logger.warning(
                "MultiAgentReplayBuffer configured with "
                "`replay_sequence_length={}`, but `storage_unit={}`. "
                "This will result in sequences equal to timesteps.".format(
                    replay_sequence_length, storage_unit))

        if replay_mode in ["lockstep", ReplayMode.LOCKSTEP]:
            self.replay_mode = ReplayMode.LOCKSTEP
            if self.storage_unit in [
                    StorageUnit.EPISODES, StorageUnit.SEQUENCES
            ]:
                raise ValueError("MultiAgentReplayBuffer does not support "
                                 "lockstep mode with storage unit `episodes`"
                                 "or `sequences`.")
        elif replay_mode in ["independent", ReplayMode.INDEPENDENT]:
            self.replay_mode = ReplayMode.INDEPENDENT
        else:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        if self.underlying_buffer_config:
            ctor_args = {
                **{
                    "capacity": shard_capacity,
                    "storage_unit": StorageUnit.FRAGMENTS
                },
                **self.underlying_buffer_config,
            }

            def new_buffer():
                return from_config(self.underlying_buffer_config["type"],
                                   ctor_args)

        else:
            # Default case
            def new_buffer():
                self.underlying_buffer_call_args = {}
                return ReplayBuffer(
                    self.capacity,
                    storage_unit=StorageUnit.FRAGMENTS,
                )

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = _Timer()
        self.replay_timer = _Timer()
        self._num_added = 0
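
# A hedged usage sketch for the constructor above: passing an
# `underlying_buffer_config` makes every per-policy buffer be built from that
# config via `from_config(...)`. The import below assumes the Ray 2.x package
# `ray.rllib.utils.replay_buffers`; exact paths may differ between versions.

from ray.rllib.utils.replay_buffers import (
    MultiAgentReplayBuffer,
    PrioritizedReplayBuffer,
)

buffer = MultiAgentReplayBuffer(
    capacity=50000,
    storage_unit="timesteps",
    # Each per-policy buffer created by `new_buffer()` is instantiated from
    # this config (merged with the shard capacity by the code above).
    underlying_buffer_config={
        "type": PrioritizedReplayBuffer,
        "alpha": 0.6,
    },
)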
Example #6
    def test_episodes_unit(self):
        """Tests adding, sampling, and eviction of episodes."""
        buffer = ReplayBuffer(capacity=18, storage_unit="episodes")

        batches = [
            SampleBatch({
                SampleBatch.T: [0, 1, 2, 3],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, False, False, True],
                SampleBatch.SEQ_LENS: [4],
                SampleBatch.EPS_ID: 4 * [i],
            }) for i in range(3)
        ]

        batches.append(
            SampleBatch({
                SampleBatch.T: [0, 1, 0, 1],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, True, False, True],
                SampleBatch.SEQ_LENS: [2, 2],
                SampleBatch.EPS_ID: [3, 3, 4, 4],
            }))

        for batch in batches:
            buffer.add(batch)

        num_sampled_dict = {_id: 0 for _id in range(5)}
        num_samples = 200
        for i in range(num_samples):
            sample = buffer.sample(1)
            _id = sample[SampleBatch.EPS_ID][0]
            assert len(sample[SampleBatch.SEQ_LENS]) == 1
            num_sampled_dict[_id] += 1

        # All episodes, even though they are in different batches, should be
        # sampled equally often
        assert np.allclose(
            np.array(list(num_sampled_dict.values())) / num_samples,
            [1 / 5, 1 / 5, 1 / 5, 1 / 5, 1 / 5],
            atol=0.1,
        )

        # The episode with eps_id=6 does not end inside this batch, so it
        # should not be added to the buffer
        buffer.add(
            SampleBatch({
                SampleBatch.T: [0, 1, 0, 1],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, True, False, False],
                SampleBatch.SEQ_LENS: [2, 2],
                SampleBatch.EPS_ID: [5, 5, 6, 6],
            }))

        num_sampled_dict = {_id: 0 for _id in range(7)}
        num_samples = 200
        for i in range(num_samples):
            sample = buffer.sample(1)
            _id = sample[SampleBatch.EPS_ID][0]
            assert len(sample[SampleBatch.SEQ_LENS]) == 1
            num_sampled_dict[_id] += 1

        # The episode with eps_id=6 should have been dropped for not
        # ending inside the batch
        assert np.allclose(
            np.array(list(num_sampled_dict.values())) / num_samples,
            [1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 0],
            atol=0.1,
        )

        # Add another batch to evict the first batch
        buffer.add(
            SampleBatch({
                SampleBatch.T: [0, 1, 2, 3],
                SampleBatch.ACTIONS: 4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS: 4 * [np.random.rand()],
                SampleBatch.DONES: [False, False, False, True],
                SampleBatch.SEQ_LENS: [4],
                SampleBatch.EPS_ID: 4 * [7],
            }))

        # After adding 1 more batch, eviction has started with 24
        # timesteps added in total, 2 of which were discarded
        assert len(buffer) == 6
        assert buffer._num_timesteps_added == 4 * 6 - 2
        assert buffer._num_timesteps_added_wrap == 4
        assert buffer._next_idx == 1
        assert buffer._eviction_started is True

        num_sampled_dict = {_id: 0 for _id in range(8)}
        num_samples = 200
        for i in range(num_samples):
            sample = buffer.sample(1)
            _id = sample[SampleBatch.EPS_ID][0]
            assert len(sample[SampleBatch.SEQ_LENS]) == 1
            num_sampled_dict[_id] += 1

        assert np.allclose(
            np.array(list(num_sampled_dict.values())) / num_samples,
            [0, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 1 / 6, 0, 1 / 6],
            atol=0.1,
        )
Example #7
    def test_sequences_unit(self):
        """Tests adding, sampling and eviction of sequences."""
        buffer = ReplayBuffer(capacity=10, storage_unit="sequences")

        batches = [
            SampleBatch({
                SampleBatch.T:
                i * [np.random.random((4, ))],
                SampleBatch.ACTIONS:
                i * [np.random.choice([0, 1])],
                SampleBatch.REWARDS:
                i * [np.random.rand()],
                SampleBatch.DONES:
                i * [np.random.choice([False, True])],
                SampleBatch.SEQ_LENS: [i],
                "batch_id":
                i * [i],
            }) for i in range(1, 4)
        ]

        batches.append(
            SampleBatch({
                SampleBatch.T:
                4 * [np.random.random((4, ))],
                SampleBatch.ACTIONS:
                4 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS:
                4 * [np.random.rand()],
                SampleBatch.DONES:
                4 * [np.random.choice([False, True])],
                SampleBatch.SEQ_LENS: [2, 2],
                "batch_id":
                4 * [4],
            }))

        for batch in batches:
            buffer.add(batch)

        num_sampled_dict = {_id: 0 for _id in range(1, 5)}
        num_samples = 200
        for i in range(num_samples):
            sample = buffer.sample(1)
            _id = sample["batch_id"][0]
            assert len(sample[SampleBatch.SEQ_LENS]) == 1
            num_sampled_dict[_id] += 1

        # Out of the five stored sequences, the two from the last batch share
        # a batch_id, so that id should be sampled twice as often
        assert np.allclose(
            np.array(list(num_sampled_dict.values())) / num_samples,
            [1 / 5, 1 / 5, 1 / 5, 2 / 5],
            atol=0.1,
        )

        # Add another batch to evict
        buffer.add(
            SampleBatch({
                SampleBatch.T:
                5 * [np.random.random((4, ))],
                SampleBatch.ACTIONS:
                5 * [np.random.choice([0, 1])],
                SampleBatch.REWARDS:
                5 * [np.random.rand()],
                SampleBatch.DONES:
                5 * [np.random.choice([False, True])],
                SampleBatch.SEQ_LENS: [5],
                "batch_id":
                5 * [5],
            }))

        # After adding 1 more batch, eviction has started with 15
        # timesteps added in total
        assert len(buffer) == 5
        assert buffer._num_timesteps_added == sum(range(1, 6))
        assert buffer._num_timesteps_added_wrap == 5
        assert buffer._next_idx == 1
        assert buffer._eviction_started is True

        # The first batch should not be sampled anymore; the other batches
        # should be sampled as often as before
        num_sampled_dict = {_id: 0 for _id in range(2, 6)}
        num_samples = 200
        for i in range(num_samples):
            sample = buffer.sample(1)
            _id = sample["batch_id"][0]
            assert len(sample[SampleBatch.SEQ_LENS]) == 1
            num_sampled_dict[_id] += 1

        assert np.allclose(
            np.array(list(num_sampled_dict.values())) / num_samples,
            [1 / 5, 1 / 5, 2 / 5, 1 / 5],
            atol=0.1,
        )
Example #8
    def test_timesteps_unit(self):
        """Tests adding, sampling, get-/set state, and eviction with
        experiences stored by timesteps.
        """
        self.batch_id = 0

        def _add_data_to_buffer(_buffer, batch_size, num_batches=5, **kwargs):
            def _generate_data():
                return SampleBatch({
                    SampleBatch.T: [np.random.random((4, ))],
                    SampleBatch.ACTIONS: [np.random.choice([0, 1])],
                    SampleBatch.OBS: [np.random.random((4, ))],
                    SampleBatch.NEXT_OBS: [np.random.random((4, ))],
                    SampleBatch.REWARDS: [np.random.rand()],
                    SampleBatch.DONES: [np.random.choice([False, True])],
                    "batch_id": [self.batch_id],
                })

            for i in range(num_batches):
                data = [_generate_data() for _ in range(batch_size)]
                self.batch_id += 1
                batch = SampleBatch.concat_samples(data)
                _buffer.add(batch, **kwargs)

        batch_size = 5
        buffer_size = 15

        buffer = ReplayBuffer(capacity=buffer_size)

        # Test add/sample
        _add_data_to_buffer(buffer, batch_size=batch_size, num_batches=1)

        _add_data_to_buffer(buffer, batch_size=batch_size, num_batches=2)

        # Sampling from it now should yield our first batch 1/3 of the time
        num_sampled_dict = {_id: 0 for _id in range(self.batch_id)}
        num_samples = 200
        for i in range(num_samples):
            _id = buffer.sample(1)["batch_id"][0]
            num_sampled_dict[_id] += 1
        assert np.allclose(
            np.array(list(num_sampled_dict.values())) / num_samples,
            len(num_sampled_dict) * [1 / 3],
            atol=0.1,
        )

        # Test set/get state
        state = buffer.get_state()
        other_buffer = ReplayBuffer(capacity=buffer_size)
        _add_data_to_buffer(other_buffer, 1)
        other_buffer.set_state(state)

        assert other_buffer._storage == buffer._storage
        assert other_buffer._next_idx == buffer._next_idx
        assert other_buffer._num_timesteps_added == buffer._num_timesteps_added
        assert (other_buffer._num_timesteps_added_wrap ==
                buffer._num_timesteps_added_wrap)
        assert other_buffer._num_timesteps_sampled == buffer._num_timesteps_sampled
        assert other_buffer._eviction_started == buffer._eviction_started
        assert other_buffer._est_size_bytes == buffer._est_size_bytes
        assert len(other_buffer) == len(buffer)
    def __init__(self,
                 capacity: int = 10000,
                 storage_unit: str = "timesteps",
                 num_shards: int = 1,
                 replay_batch_size: int = 1,
                 learning_starts: int = 1000,
                 replay_mode: str = "independent",
                 replay_sequence_length: int = 1,
                 replay_burn_in: int = 0,
                 replay_zero_init_states: bool = True,
                 underlying_buffer_config: dict = None,
                 **kwargs):
        """Initializes a MultiAgentReplayBuffer instance.

        Args:
            num_shards: The number of buffer shards that exist in total
                (including this one).
            storage_unit: Either 'timesteps', 'sequences' or
                'episodes'. Specifies how experiences are stored. If they
                are stored in episodes, replay_sequence_length is ignored.
            learning_starts: Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            capacity: Max number of total timesteps in all policy buffers.
                After reaching this number, older samples will be
                dropped to make space for new ones.
            replay_batch_size: The batch size to be sampled (in timesteps).
                Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            replay_mode: One of "independent" or "lockstep". Determines
                whether batches are sampled independently or in lockstep
                (equal amounts across policies).
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer. This
                only has an effect if storage_unit is 'timesteps'.
            replay_burn_in: The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 each RNN rollout. This only has an effect
                if storage_unit is 'timesteps'.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
            underlying_buffer_config: A config that contains all necessary
                constructor arguments and arguments for methods to call on
                the underlying buffers.
            **kwargs: Forward compatibility kwargs.
        """
        shard_capacity = capacity // num_shards
        ReplayBuffer.__init__(self, capacity, storage_unit)

        # If the user provides an underlying buffer config, we use it to
        # instantiate and interact with the underlying buffers.
        self.underlying_buffer_config = underlying_buffer_config
        if self.underlying_buffer_config is not None:
            self.underlying_buffer_call_args = self.underlying_buffer_config
        else:
            self.underlying_buffer_call_args = {}

        self.replay_batch_size = replay_batch_size
        self.replay_starts = learning_starts // num_shards
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if replay_mode in ["lockstep", ReplayMode.LOCKSTEP]:
            self.replay_mode = ReplayMode.LOCKSTEP
            if self._storage_unit in [
                    StorageUnit.EPISODES, StorageUnit.SEQUENCES
            ]:
                raise ValueError("MultiAgentReplayBuffer does not support "
                                 "lockstep mode with storage unit `episodes`"
                                 "or `sequences`.")
        elif replay_mode in ["independent", ReplayMode.INDEPENDENT]:
            self.replay_mode = ReplayMode.INDEPENDENT
        else:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        if self.underlying_buffer_config:
            ctor_args = {
                **{
                    "capacity": shard_capacity,
                    "storage_unit": storage_unit
                },
                **self.underlying_buffer_config,
            }

            def new_buffer():
                return from_config(self.underlying_buffer_config["type"],
                                   ctor_args)

        else:
            # Default case
            def new_buffer():
                self.underlying_buffer_call_args = {}
                return ReplayBuffer(
                    self.capacity,
                    storage_unit=storage_unit,
                )

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self._num_added = 0
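
# The sharding arithmetic in the constructors above is plain integer
# division; a tiny sketch with hypothetical numbers:

capacity = 100000        # total capacity across all buffer shards
learning_starts = 1000   # total timesteps to collect before sampling
num_shards = 4           # e.g. one shard per replay actor

shard_capacity = capacity // num_shards        # -> 25000 per shard
replay_starts = learning_starts // num_shards  # -> 250 per shard
assert (shard_capacity, replay_starts) == (25000, 250)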