コード例 #1
0
ファイル: agent.py プロジェクト: wangyang59/flare
class OnlineHelper(AgentHelper):
    """
    Helper for online learning.

    It calls `learn()` every `sample_interval` steps. The calling `Agent`
    is blocked while it waits for learning to return.
    """

    def __init__(self, name, communicator, sample_interval=5):
        base = super(OnlineHelper, self)
        base.__init__(name, communicator, sample_interval)
        # Past experiences are held in a NoReplacementQueue until sampled.
        self.exp_queue = NoReplacementQueue()

    @staticmethod
    def exp_replay():
        """Return False: this helper does not use experience replay."""
        return False

    def add_experience(self, e):
        """Store the experience `e` via the internal queue's `add()`."""
        queue = self.exp_queue
        queue.add(e)

    def sample_experiences(self):
        """Return whatever the internal queue's `sample()` produces."""
        queue = self.exp_queue
        return queue.sample()
コード例 #2
0
ファイル: agent.py プロジェクト: wangyang59/flare
 def __init__(self, name, communicator, sample_interval=5):
     """Initialise the base helper, then create an empty experience queue.

     `name`, `communicator` and `sample_interval` are forwarded unchanged,
     in that order, to the parent class initialiser.
     """
     base = super(OnlineHelper, self)
     base.__init__(name, communicator, sample_interval)
     # Past experiences are held in a NoReplacementQueue until sampled.
     self.exp_queue = NoReplacementQueue()
コード例 #3
0
ファイル: test_replay_buffer.py プロジェクト: ziyuli/flare
 def test_sampling(self):
     """Check how NoReplacementQueue.sample groups experiences.

     ``self.is_episode_end`` is passed to ``sample``. Each stored
     experience is an (obs, reward, action, episode_end) tuple, and the
     assertions identify experiences by their action (index 2).
     """
     exp_q = NoReplacementQueue()
     #          obs           r    a    e
     exp_q.add((np.zeros(10), [1], [1], [0]))
     exp_q.add((np.zeros(10), [0], [-1], [1]))  # 1st episode end
     exp_q.add((np.zeros(10), [1], [2], [0]))
     exp_q.add((np.zeros(10), [1], [3], [0]))
     exp_q.add((np.zeros(10), [1], [4], [0]))
     exp_seqs = exp_q.sample(self.is_episode_end)
     # The 2nd episode has not ended yet, so its last experience
     # (action [4]) is expected to stay in the queue.
     self.assertEqual(len(exp_q), 1)
     self.assertEqual(len(exp_seqs), 2)
     # Sequence 0: the complete 1st episode.
     self.assertEqual(len(exp_seqs[0]), 2)
     self.assertEqual(exp_seqs[0][0][2], [1])
     self.assertEqual(exp_seqs[0][1][2], [-1])
     # Sequence 1: the unfinished 2nd episode seen so far.
     self.assertEqual(len(exp_seqs[1]), 3)
     self.assertEqual(exp_seqs[1][0][2], [2])
     self.assertEqual(exp_seqs[1][1][2], [3])
     self.assertEqual(exp_seqs[1][2][2], [4])
     #          obs           r    a    e
     exp_q.add((np.zeros(10), [0], [-2], [1]))  # 2nd episode end
     exp_seqs = exp_q.sample(self.is_episode_end)
     # The episode has now ended, so nothing is left in the queue.
     self.assertEqual(len(exp_q), 0)
     self.assertEqual(len(exp_seqs), 1)
     # The new sequence starts from the retained action-[4] experience.
     self.assertEqual(len(exp_seqs[0]), 2)
     self.assertEqual(exp_seqs[0][0][2], [4])
     self.assertEqual(exp_seqs[0][1][2], [-2])