def add_batch(self, batch):
    """Split a batch into time slices and store them in the replay buffers.

    A ``SampleBatch`` is first wrapped into a ``MultiAgentBatch`` under
    ``DEFAULT_POLICY_ID`` so that both input kinds share one code path.

    Args:
        batch: A ``SampleBatch``, or an object that already provides
            ``copy()``, ``timeslices()``, ``policy_batches`` and ``count``
            (i.e. a multi-agent batch).
    """
    # Work on a copy so the replay buffer doesn't pin plasma memory.
    copied = batch.copy()

    # Treat every input as multi-agent from here on.
    if isinstance(copied, SampleBatch):
        ma_batch = MultiAgentBatch({DEFAULT_POLICY_ID: copied}, copied.count)
    else:
        ma_batch = copied

    with self.add_batch_timer:
        if self.replay_mode != "lockstep":
            # Independent mode: each policy's data goes to its own buffer.
            for pid, policy_batch in ma_batch.policy_batches.items():
                slices = policy_batch.timeslices(self.replay_sequence_length)
                for time_slice in slices:
                    self.replay_buffers[pid].add(time_slice)
        else:
            # Lockstep mode: whole multi-agent slices share a single buffer.
            slices = ma_batch.timeslices(self.replay_sequence_length)
            for time_slice in slices:
                self.replay_buffers[_ALL_POLICIES].add(time_slice)

    self.num_added += ma_batch.count
def add_batch(self, batch):
    """Split a batch into time slices and store them in the replay buffers.

    A ``SampleBatch`` is first wrapped into a ``MultiAgentBatch`` under
    ``DEFAULT_POLICY_ID`` so that both input kinds share one code path.
    Outside of lockstep mode, a slice that carries a ``"weights"`` entry is
    added with the mean of those weights; otherwise ``weight=None`` is passed.

    Args:
        batch: A ``SampleBatch``, or an object that already provides
            ``copy()``, ``timeslices()``, ``policy_batches`` and ``count``
            (i.e. a multi-agent batch).
    """
    # Work on a copy so the replay buffer doesn't pin plasma memory.
    copied = batch.copy()

    # Treat every input as multi-agent from here on.
    if isinstance(copied, SampleBatch):
        ma_batch = MultiAgentBatch({DEFAULT_POLICY_ID: copied}, copied.count)
    else:
        ma_batch = copied

    with self.add_batch_timer:
        if self.replay_mode != "lockstep":
            # Independent mode: each policy's data goes to its own buffer.
            for pid, policy_batch in ma_batch.policy_batches.items():
                slices = policy_batch.timeslices(self.replay_sequence_length)
                for time_slice in slices:
                    slice_weight = (
                        np.mean(time_slice["weights"])
                        if "weights" in time_slice
                        else None
                    )
                    self.replay_buffers[pid].add(
                        time_slice, weight=slice_weight
                    )
        else:
            # Note that prioritization is not supported in this mode.
            slices = ma_batch.timeslices(self.replay_sequence_length)
            for time_slice in slices:
                self.replay_buffers[_ALL_POLICIES].add(
                    time_slice, weight=None
                )

    self.num_added += ma_batch.count
def test_timeslices_partially_overlapping_experiences(self):
    """Tests if timeslices works as expected on a MultiAgentBatch
    consisting of two partially overlapping SampleBatches.

    Policy "0" covers timesteps 0-1 and policy "1" covers timesteps 1-2,
    so slicing with a length of 1 is expected to give three slices, the
    middle one containing data from both policies.
    """

    def _generate_data(agent_idx, t_start):
        # Two consecutive timesteps of one episode for a single agent.
        batch = SampleBatch({
            SampleBatch.T: [t_start, t_start + 1],
            SampleBatch.EPS_ID: [0, 0],
            SampleBatch.AGENT_INDEX: 2 * [agent_idx],
            SampleBatch.SEQ_LENS: [2],
        })
        return batch

    policy_batches = {
        str(idx): _generate_data(idx, idx) for idx in range(2)
    }
    ma_batch = MultiAgentBatch(policy_batches, 4)
    # Materialize the result so its length can be checked below.
    sliced_ma_batches = list(ma_batch.timeslices(1))

    expected_ma_batches = [
        MultiAgentBatch(
            {
                "0": SampleBatch({
                    SampleBatch.T: [0],
                    SampleBatch.EPS_ID: [0],
                    SampleBatch.AGENT_INDEX: [0],
                    SampleBatch.SEQ_LENS: [1],
                }),
            },
            1,
        ),
        MultiAgentBatch(
            {
                "0": SampleBatch({
                    SampleBatch.T: [1],
                    SampleBatch.EPS_ID: [0],
                    SampleBatch.AGENT_INDEX: [0],
                    SampleBatch.SEQ_LENS: [1],
                }),
                "1": SampleBatch({
                    SampleBatch.T: [1],
                    SampleBatch.EPS_ID: [0],
                    SampleBatch.AGENT_INDEX: [1],
                    SampleBatch.SEQ_LENS: [1],
                }),
            },
            1,
        ),
        MultiAgentBatch(
            {
                "1": SampleBatch({
                    SampleBatch.T: [2],
                    SampleBatch.EPS_ID: [0],
                    SampleBatch.AGENT_INDEX: [1],
                    SampleBatch.SEQ_LENS: [1],
                }),
            },
            1,
        ),
    ]

    # zip() stops at the shorter input, so without this check missing or
    # surplus slices would go unnoticed and the test could pass vacuously.
    assert len(sliced_ma_batches) == len(expected_ma_batches), (
        len(sliced_ma_batches),
        len(expected_ma_batches),
    )

    for actual, expected in zip(sliced_ma_batches, expected_ma_batches):
        check_same_batch(actual, expected)