Ejemplo n.º 1
0
 def add_batch(self, batch):
     """Add a SampleBatch or MultiAgentBatch to the replay buffers.

     The batch is copied, normalized to a MultiAgentBatch (a plain
     SampleBatch is wrapped under DEFAULT_POLICY_ID) and cut into slices via
     ``timeslices(self.replay_sequence_length)``. In "lockstep" replay mode
     the slices of the whole multi-agent batch go into the shared
     ``_ALL_POLICIES`` buffer; otherwise each policy's slices go into that
     policy's own buffer. Slicing and adding run under
     ``self.add_batch_timer``; afterwards ``self.num_added`` grows by the
     batch's count.

     Args:
         batch: The SampleBatch or MultiAgentBatch to store.
     """
     # Work on a copy so the replay buffer doesn't pin plasma memory.
     ma_batch = batch.copy()
     # From here on, treat everything as multi-agent.
     if isinstance(ma_batch, SampleBatch):
         ma_batch = MultiAgentBatch(
             {DEFAULT_POLICY_ID: ma_batch}, ma_batch.count)
     with self.add_batch_timer:
         # (buffer key, batch to slice) pairs, processed in order.
         if self.replay_mode == "lockstep":
             targets = [(_ALL_POLICIES, ma_batch)]
         else:
             targets = ma_batch.policy_batches.items()
         for buffer_key, source in targets:
             for time_slice in source.timeslices(
                     self.replay_sequence_length):
                 # Look the buffer up per slice (not hoisted) so no buffer
                 # entry is touched when a source yields no slices.
                 self.replay_buffers[buffer_key].add(time_slice)
     self.num_added += ma_batch.count
Ejemplo n.º 2
0
 def add_batch(self, batch):
     """Add a SampleBatch or MultiAgentBatch to the replay buffers.

     The batch is copied, normalized to a MultiAgentBatch (a plain
     SampleBatch is wrapped under DEFAULT_POLICY_ID) and cut into slices via
     ``timeslices(self.replay_sequence_length)``.

     * "lockstep" replay mode: slices of the whole multi-agent batch are
       added to the shared ``_ALL_POLICIES`` buffer with ``weight=None``
       (prioritization is not supported in this mode).
     * Any other mode: each policy's slices are added to that policy's own
       buffer. If a slice carries a non-empty "weights" column, the mean of
       those weights is passed as the slice's weight; otherwise
       ``weight=None`` is passed.

     Slicing and adding run under ``self.add_batch_timer``; afterwards
     ``self.num_added`` grows by the batch's count.

     Args:
         batch: The SampleBatch or MultiAgentBatch to store.
     """
     # Make a copy so the replay buffer doesn't pin plasma memory.
     batch = batch.copy()
     # Handle everything as if multiagent.
     if isinstance(batch, SampleBatch):
         batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)
     with self.add_batch_timer:
         if self.replay_mode == "lockstep":
             # Note that prioritization is not supported in this mode.
             for s in batch.timeslices(self.replay_sequence_length):
                 self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
         else:
             for policy_id, b in batch.policy_batches.items():
                 for s in b.timeslices(self.replay_sequence_length):
                     weights = s["weights"] if "weights" in s else None
                     # np.mean() of an empty array returns NaN (plus a
                     # RuntimeWarning), which would be passed to the buffer
                     # as the slice's weight; treat an empty column like a
                     # missing one instead.
                     if weights is not None and np.size(weights) > 0:
                         weight = np.mean(weights)
                     else:
                         weight = None
                     self.replay_buffers[policy_id].add(s, weight=weight)
     self.num_added += batch.count
Ejemplo n.º 3
0
    def test_timeslices_partially_overlapping_experiences(self):
        """Tests if timeslices works as expected on a MultiAgentBatch
        consisting of two partially overlapping SampleBatches.

        Agent "0" acts at t=0,1 and agent "1" at t=1,2 (same episode), so
        slicing with k=1 is expected to yield exactly three slices: t=0 with
        only agent "0", t=1 with both agents, and t=2 with only agent "1".
        """
        def _generate_data(agent_idx, t_start):
            # Two consecutive timesteps of episode 0 for a single agent.
            return SampleBatch({
                SampleBatch.T: [t_start, t_start + 1],
                SampleBatch.EPS_ID: [0, 0],
                SampleBatch.AGENT_INDEX: 2 * [agent_idx],
                SampleBatch.SEQ_LENS: [2],
            })

        def _single_step(agent_idx, t):
            # Expected one-timestep slice of episode 0 for `agent_idx` at `t`.
            return SampleBatch({
                SampleBatch.T: [t],
                SampleBatch.EPS_ID: [0],
                SampleBatch.AGENT_INDEX: [agent_idx],
                SampleBatch.SEQ_LENS: [1],
            })

        policy_batches = {
            str(idx): _generate_data(idx, idx)
            for idx in range(2)
        }
        ma_batch = MultiAgentBatch(policy_batches, 4)
        sliced_ma_batches = list(ma_batch.timeslices(1))

        expected_ma_batches = [
            MultiAgentBatch({"0": _single_step(0, 0)}, 1),
            MultiAgentBatch(
                {
                    "0": _single_step(0, 1),
                    "1": _single_step(1, 1),
                },
                1,
            ),
            MultiAgentBatch({"1": _single_step(1, 2)}, 1),
        ]

        # `zip` silently truncates to the shorter input, so check the slice
        # count explicitly; otherwise missing or extra slices (even an empty
        # result) would let the test pass.
        assert len(sliced_ma_batches) == len(expected_ma_batches), (
            len(sliced_ma_batches), len(expected_ma_batches))
        for actual, expected in zip(sliced_ma_batches, expected_ma_batches):
            check_same_batch(actual, expected)