Exemplo n.º 1
0
 def test_manual(self):
     """Check hand-computed subsets for item counts of 4, 3 and 2.

     Each expected value is the full list of subsets yielded for the given
     number of items; in these cases an episode that does not fit in the
     current subset is continued in the next one.
     """
     split = ppo._yield_subset_of_sequences_with_fixed_number_of_items
     episodes = [[1, 2, 3], [4, 5], [6, 7, 8], [9], [10, 11, 12]]

     # (number of items per subset, expected list of subsets)
     cases = (
         (
             4,
             [
                 [[1, 2, 3], [4]],
                 [[5], [6, 7, 8]],
                 [[9], [10, 11, 12]],
             ],
         ),
         (
             3,
             [
                 [[1, 2, 3]],
                 [[4, 5], [6]],
                 [[7, 8], [9]],
                 [[10, 11, 12]],
             ],
         ),
         (
             2,
             [
                 [[1, 2]],
                 [[3], [4]],
                 [[5], [6]],
                 [[7, 8]],
                 [[9], [10]],
                 [[11, 12]],
             ],
         ),
     )
     for n_items, expected in cases:
         self.assertEqual(list(split(episodes, n_items)), expected)
Exemplo n.º 2
0
    def _update_vf_recurrent(self, dataset):
        """Run ``self.vf_epochs`` passes of recurrent value-function updates.

        On every pass ``dataset`` is shuffled in place, then split into
        minibatches of ``self.vf_batch_size`` items by
        ``_yield_subset_of_sequences_with_fixed_number_of_items``; each
        minibatch is handed to ``self._update_vf_once_recurrent``.

        Args:
            dataset: Mutable sequence to draw minibatches from. Its order is
                modified by the shuffling.
        """
        for _ in range(self.vf_epochs):
            # Reshuffle on every pass so minibatch contents differ per epoch.
            random.shuffle(dataset)
            minibatches = _yield_subset_of_sequences_with_fixed_number_of_items(  # NOQA
                dataset, self.vf_batch_size
            )
            for minibatch in minibatches:
                self._update_vf_once_recurrent(minibatch)