コード例 #1
0
    def test_sample_loaded_memories(self):
        """Check seeded sampling through a second store opened on a saved file.

        Seven memories are appended and saved through one ``ReplayMemories``
        object. A second object created on the same file is then asked for a
        sample of five with ``seed=3``, and the result is compared against
        the expected memories without regard to order.
        """
        with TempDir() as tmp_dir:
            storage_path = os.path.join(tmp_dir, 'replay_memories.dat')
            writer = ReplayMemories(storage_path, max_current_memories_in_ram=100)

            # Create every memory first, then append them in creation order.
            memories = [
                self._create_replay_memory(action_index=action_index)
                for action_index in (11, 22, 32, 42, 52, 62, 72)
            ]
            for memory in memories:
                writer.append(memory)
            writer.save()

            reader = ReplayMemories(storage_path, max_current_memories_in_ram=100)
            sample = reader.sample(5, seed=3)

            # The 2nd, 4th, 7th, 3rd and 5th memories (zero-based 1, 3, 6, 2, 4).
            expected = [memories[position] for position in (1, 3, 6, 2, 4)]
            self.assertItemsEqual(sample, expected)
コード例 #2
0
    def test_sample_recent_memories_after_appending_to_loaded_memories(self):
        """Sample recent memories spanning saved and newly appended entries.

        Three memories are appended and saved through one ``ReplayMemories``
        object; four more are appended through a second object created on the
        same file. A sample of five with ``recent_memories_span=5`` is then
        expected to contain memories 3 through 7.
        """
        with TempDir() as temp_directory:
            replay_memories_file_name = os.path.join(temp_directory, 'replay_memories.dat')
            replay_memories1 = ReplayMemories(replay_memories_file_name, max_current_memories_in_ram=100)
            replay_memory1 = self._create_replay_memory(action_index=11)
            replay_memory2 = self._create_replay_memory(action_index=22)
            replay_memory3 = self._create_replay_memory(action_index=32)
            replay_memory4 = self._create_replay_memory(action_index=42)
            replay_memory5 = self._create_replay_memory(action_index=52)
            replay_memory6 = self._create_replay_memory(action_index=62)
            replay_memory7 = self._create_replay_memory(action_index=72)

            # First batch: written to the file via save().
            for saved_memory in (replay_memory1, replay_memory2, replay_memory3):
                replay_memories1.append(saved_memory)
            replay_memories1.save()

            # Second batch: appended to a fresh instance on the same file.
            replay_memories2 = ReplayMemories(replay_memories_file_name, max_current_memories_in_ram=100)
            for appended_memory in (replay_memory4, replay_memory5, replay_memory6, replay_memory7):
                replay_memories2.append(appended_memory)

            sampled_replay_memories = replay_memories2.sample(5, recent_memories_span=5)

            # assertIn reports the missing member and the container on failure,
            # unlike assertTrue(x in y) which only says "False is not true".
            self.assertIn(replay_memory3, sampled_replay_memories)
            self.assertIn(replay_memory4, sampled_replay_memories)
            self.assertIn(replay_memory5, sampled_replay_memories)
            self.assertIn(replay_memory6, sampled_replay_memories)
            self.assertIn(replay_memory7, sampled_replay_memories)
コード例 #3
0
    def test_sample_recent_memories(self):
        """Sample two memories restricted to the two most recently appended.

        Seven memories are appended to a single ``ReplayMemories`` object
        (without saving). A sample of two with ``recent_memories_span=2`` is
        expected to contain memories 6 and 7.
        """
        with TempDir() as temp_directory:
            replay_memories_file_name = os.path.join(temp_directory, 'replay_memories.dat')
            replay_memories = ReplayMemories(replay_memories_file_name, max_current_memories_in_ram=100)
            replay_memory1 = self._create_replay_memory(action_index=11)
            replay_memory2 = self._create_replay_memory(action_index=22)
            replay_memory3 = self._create_replay_memory(action_index=32)
            replay_memory4 = self._create_replay_memory(action_index=42)
            replay_memory5 = self._create_replay_memory(action_index=52)
            replay_memory6 = self._create_replay_memory(action_index=62)
            replay_memory7 = self._create_replay_memory(action_index=72)
            for memory in (replay_memory1, replay_memory2, replay_memory3, replay_memory4,
                           replay_memory5, replay_memory6, replay_memory7):
                replay_memories.append(memory)

            sampled_replay_memories = replay_memories.sample(2, recent_memories_span=2)

            # No debug print here: assertIn already includes the sampled
            # container in its failure message, so test output stays clean.
            self.assertIn(replay_memory6, sampled_replay_memories)
            self.assertIn(replay_memory7, sampled_replay_memories)
コード例 #4
0
    def test_sample_size_smaller_than_recent_memories(self):
        """Check seeded sampling when the sample size is below the recent span.

        Three memories are appended and saved through one ``ReplayMemories``
        object; a fourth is appended through a second object created on the
        same file. A sample of two with ``recent_memories_span=3`` and
        ``seed=4`` is compared, ignoring order, against the second and fourth
        memories.
        """
        with TempDir() as tmp_dir:
            storage_path = os.path.join(tmp_dir, 'replay_memories.dat')
            first_store = ReplayMemories(storage_path, max_current_memories_in_ram=100)

            # All four memories are created before any of them is appended.
            memory_a, memory_b, memory_c, memory_d = [
                self._create_replay_memory(action_index=action_index)
                for action_index in (11, 22, 32, 42)
            ]
            for saved_memory in (memory_a, memory_b, memory_c):
                first_store.append(saved_memory)
            first_store.save()

            second_store = ReplayMemories(storage_path, max_current_memories_in_ram=100)
            second_store.append(memory_d)

            sample = second_store.sample(2, recent_memories_span=3, seed=4)

            self.assertItemsEqual(sample, [memory_b, memory_d])