def _dataloader(self) -> DataLoader:
        """Create a fresh replay buffer, warm it up, and return a loader over ``train_batch``.

        Side effects:
            * ``self.buffer`` is replaced by ``MultiStepBuffer(self.replay_size, self.n_steps)``.
            * ``self.populate(self.warm_start_size)`` is called once the buffer exists.
            * ``self.dataset`` is set to ``ExperienceSourceDataset(self.train_batch)``.

        Returns:
            A ``DataLoader`` over ``self.dataset`` using ``batch_size=self.batch_size``.
        """
        capacity, steps = self.replay_size, self.n_steps
        self.buffer = MultiStepBuffer(capacity, steps)
        # Warm start: presumably pre-fills the new buffer with ``warm_start_size``
        # experiences before training draws from it — confirm in ``populate``.
        self.populate(self.warm_start_size)

        self.dataset = ExperienceSourceDataset(self.train_batch)
        loader = DataLoader(dataset=self.dataset, batch_size=self.batch_size)
        return loader
Example #2
0
    def dataloader(self, batch_iterator, pre_steps, *, memory_capacity=500000,
                   base_batch_size=32):
        """Build the training dataset and a fresh replay memory, run warm-up steps, return a loader.

        Args:
            batch_iterator: Passed unchanged to ``ExperienceSourceDataset``; presumably a
                callable/generator producing training samples — TODO confirm.
            pre_steps: How many times ``self.play_step(self.epsilon)`` is called before
                the loader is returned. Values <= 0 skip the warm-up entirely.
            memory_capacity: Argument passed to ``Memory(...)``. Defaults to 500000, the
                value that used to be hard-coded here.
            base_batch_size: Multiplied by ``self.steps_to_train`` to get the loader's
                batch size. Defaults to 32, the value that used to be hard-coded here.

        Returns:
            A ``DataLoader`` over ``self.dataset``.

        Side effects:
            Sets ``self.dataset`` and ``self.memory`` (in that order) before the warm-up
            loop runs, so ``play_step`` sees the new memory.
        """
        self.dataset = ExperienceSourceDataset(batch_iterator)
        self.memory = Memory(memory_capacity)
        # ``self.epsilon`` is re-read on every step in case ``play_step`` updates it.
        for _ in range(pre_steps):
            self.play_step(self.epsilon)

        return DataLoader(dataset=self.dataset,
                          batch_size=base_batch_size * self.steps_to_train)
    def test_iterator(self):
        """Check that the first batch has ``batch_size`` items and holds 0 and 5 at indices 0 and 5."""
        # NOTE(review): another ``test_iterator`` is defined right after this one; if both
        # live in the same class, the later definition shadows this one and it never runs.
        source = ExperienceSourceDataset(self.train_batch)
        batch_size = 10
        data_loader = DataLoader(source, batch_size=batch_size)

        # Take only the first batch. Using ``next(..., None)`` plus an explicit assertion
        # makes an empty loader fail the test instead of silently skipping every check.
        batch = next(iter(data_loader), None)
        self.assertIsNotNone(batch, "DataLoader yielded no batches")

        # Assumes ``self.train_batch`` yields 0, 1, 2, ... in order — set up elsewhere; TODO confirm.
        self.assertEqual(len(batch), batch_size)
        self.assertEqual(batch[0], 0)
        self.assertEqual(batch[5], 5)
    def test_iterator(self):
        """Tests that the iterator returns batches correctly.

        Checks that the first batch produced by the loader has ``batch_size`` items and
        holds the values 0 and 5 at indices 0 and 5.
        """
        source = ExperienceSourceDataset(self.train_batch)
        batch_size = 10
        data_loader = DataLoader(source, batch_size=batch_size)

        # Take only the first batch. Using ``next(..., None)`` plus an explicit assertion
        # makes an empty loader fail the test instead of silently skipping every check.
        batch = next(iter(data_loader), None)
        self.assertIsNotNone(batch, "DataLoader yielded no batches")

        # Assumes ``self.train_batch`` yields 0, 1, 2, ... in order — set up elsewhere; TODO confirm.
        self.assertEqual(len(batch), batch_size)
        self.assertEqual(batch[0], 0)
        self.assertEqual(batch[5], 5)
 def _dataloader(self) -> DataLoader:
     """Return a ``DataLoader`` over ``ExperienceSourceDataset(self.train_batch)``.

     No replay buffer is created or populated here; this only pairs the dataset with a
     loader. The batch size is taken from ``self.batch_size``. Unlike the variant that
     builds a buffer, the dataset is not stored on ``self``.
     """
     return DataLoader(
         dataset=ExperienceSourceDataset(self.train_batch),
         batch_size=self.batch_size,
     )