Example 1
 def __init__(
     self,
     capacity=None,
     alpha=0.6,
     beta0=0.4,
     betasteps=2e5,
     eps=1e-8,
     normalize_by_max=True,
     default_priority_func=None,
     uniform_ratio=0,
     wait_priority_after_sampling=True,
     return_sample_weights=True,
     error_min=None,
     error_max=None,
 ):
     self.current_episode = []
     self.episodic_memory = PrioritizedBuffer(
         capacity=None,
         wait_priority_after_sampling=wait_priority_after_sampling)
     self.memory = RandomAccessQueue(maxlen=capacity)
     self.capacity_left = capacity
     self.default_priority_func = default_priority_func
     self.uniform_ratio = uniform_ratio
     self.return_sample_weights = return_sample_weights
     PriorityWeightError.__init__(self,
                                  alpha,
                                  beta0,
                                  betasteps,
                                  eps,
                                  normalize_by_max,
                                  error_min=error_min,
                                  error_max=error_max)
Example 2
 def setUp(self):
     if self.init_seq:
         self.y_queue = RandomAccessQueue(self.init_seq, maxlen=self.maxlen)
         self.t_queue = collections.deque(self.init_seq, maxlen=self.maxlen)
     else:
         self.y_queue = RandomAccessQueue(maxlen=self.maxlen)
         self.t_queue = collections.deque(maxlen=self.maxlen)
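All of these snippets revolve around RandomAccessQueue, which the test above treats as a drop-in replacement for collections.deque that additionally supports O(1) indexing and uniform sampling. A minimal sketch of such a structure is below; it is a hypothetical stand-in written for illustration, not chainerrl's actual implementation.

import random


class MinimalRandomAccessQueue(object):
    """Illustrative queue with O(1) __getitem__ and amortized O(1) popleft."""

    def __init__(self, iterable=(), maxlen=None):
        self.maxlen = maxlen
        self._items = list(iterable)
        self._head = 0  # index of the logical front element
        self._evict()

    def __len__(self):
        return len(self._items) - self._head

    def __getitem__(self, i):
        if i < 0:
            i += len(self)
        if not 0 <= i < len(self):
            raise IndexError('queue index out of range')
        return self._items[self._head + i]

    def append(self, x):
        self._items.append(x)
        self._evict()

    def extend(self, xs):
        self._items.extend(xs)
        self._evict()

    def popleft(self):
        if len(self) == 0:
            raise IndexError('pop from an empty queue')
        x = self._items[self._head]
        self._head += 1
        # Compact once more than half of the backing list is dead space,
        # which keeps popleft amortized O(1).
        if self._head > len(self._items) // 2:
            del self._items[:self._head]
            self._head = 0
        return x

    def sample(self, n):
        # n unique elements, as ReplayBuffer.sample expects.
        return [self[i] for i in random.sample(range(len(self)), n)]

    def _evict(self):
        if self.maxlen is not None:
            while len(self) > self.maxlen:
                self.popleft()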
Example 3
class ReplayBuffer(AbstractReplayBuffer):

    def __init__(self, capacity=None):
        self.memory = RandomAccessQueue(maxlen=capacity)

    def append(self, state, action, reward, next_state=None, next_action=None,
               is_state_terminal=False):
        experience = dict(state=state, action=action, reward=reward,
                          next_state=next_state, next_action=next_action,
                          is_state_terminal=is_state_terminal)
        self.memory.append(experience)

    def sample(self, n):
        assert len(self.memory) >= n
        return self.memory.sample(n)

    def __len__(self):
        return len(self.memory)

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump(self.memory, f)

    def load(self, filename):
        with open(filename, 'rb') as f:
            self.memory = pickle.load(f)
        if isinstance(self.memory, collections.deque):
            # Load v0.2
            self.memory = RandomAccessQueue(
                self.memory, maxlen=self.memory.maxlen)

    def stop_current_episode(self):
        pass
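A minimal usage round trip for the ReplayBuffer above. This is a hypothetical driver, assuming the class as defined here and some RandomAccessQueue (e.g. the sketch after Example 2) are in scope.

buf = ReplayBuffer(capacity=1000)
for t in range(100):
    # Toy transitions: ints stand in for real states/actions.
    buf.append(state=t, action=0, reward=1.0,
               next_state=t + 1, is_state_terminal=(t == 99))
assert len(buf) == 100
batch = buf.sample(32)  # 32 unique transition dicts
print(batch[0]['state'], batch[0]['reward'])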
Example 4
 def load(self, filename):
     with open(filename, 'rb') as f:
         self.memory = pickle.load(f)
     if isinstance(self.memory, collections.deque):
         # Load v0.2
         self.memory = RandomAccessQueue(
             self.memory, maxlen=self.memory.maxlen)
Example 5
 def __init__(self, capacity=None, num_steps=1):
     self.capacity = capacity
     assert num_steps > 0
     self.num_steps = num_steps
     self.memory = RandomAccessQueue(maxlen=capacity)
     self.last_n_transitions = collections.defaultdict(
         lambda: collections.deque([], maxlen=num_steps))
Example 6
class PrioritizedEpisodicReplayBuffer(
        EpisodicReplayBuffer, PriorityWeightError):

    def __init__(self, capacity=None,
                 alpha=0.6, beta0=0.4, betasteps=2e5, eps=1e-8,
                 normalize_by_max=True,
                 default_priority_func=None,
                 uniform_ratio=0,
                 wait_priority_after_sampling=True,
                 return_sample_weights=True,
                 error_min=None,
                 error_max=None,
                 ):
        self.current_episode = []
        self.episodic_memory = PrioritizedBuffer(
            capacity=None,
            wait_priority_after_sampling=wait_priority_after_sampling)
        self.memory = RandomAccessQueue(maxlen=capacity)
        self.capacity_left = capacity
        self.default_priority_func = default_priority_func
        self.uniform_ratio = uniform_ratio
        self.return_sample_weights = return_sample_weights
        PriorityWeightError.__init__(
            self, alpha, beta0, betasteps, eps, normalize_by_max,
            error_min=error_min, error_max=error_max)

    def sample_episodes(self, n_episodes, max_len=None):
        """Sample n unique samples from this replay buffer"""
        assert len(self.episodic_memory) >= n_episodes
        episodes, probabilities, min_prob = self.episodic_memory.sample(
            n_episodes, uniform_ratio=self.uniform_ratio)
        if max_len is not None:
            episodes = [random_subseq(ep, max_len) for ep in episodes]
        if self.return_sample_weights:
            weights = self.weights_from_probabilities(probabilities, min_prob)
            return episodes, weights
        else:
            return episodes

    def update_errors(self, errors):
        self.episodic_memory.set_last_priority(
            self.priority_from_errors(errors))

    def stop_current_episode(self):
        if self.current_episode:
            if self.default_priority_func is not None:
                priority = self.default_priority_func(self.current_episode)
            else:
                priority = None
            self.memory.extend(self.current_episode)
            self.episodic_memory.append(self.current_episode,
                                        priority=priority)
            if self.capacity_left is not None:
                self.capacity_left -= len(self.current_episode)
            self.current_episode = []
            while self.capacity_left is not None and self.capacity_left < 0:
                discarded_episode = self.episodic_memory.popleft()
                self.capacity_left += len(discarded_episode)
        assert not self.current_episode
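The PriorityWeightError mixin used above (its weights_from_probabilities and priority_from_errors methods) is not shown in these snippets. Under the standard prioritized experience replay scheme (Schaul et al., 2015) with normalize_by_max=True, they would look roughly like the sketch below. This illustrates the interface only and is not chainerrl's actual code; in the real mixin beta is annealed from beta0 towards 1 over betasteps.

import numpy as np


def weights_from_probabilities(probabilities, min_probability, beta=0.4):
    # Importance-sampling weights w_i = (N * p_i)^-beta, normalized by
    # the largest weight (the one belonging to min_probability) so that
    # all weights lie in (0, 1].
    probs = np.asarray(probabilities, dtype=np.float64)
    return (probs / min_probability) ** -beta


def priority_from_errors(errors, alpha=0.6, eps=1e-8):
    # Proportional prioritization: p_i = (|delta_i| + eps)^alpha.
    return [(abs(e) + eps) ** alpha for e in errors]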
Example 7
    def stop_current_episode(self):
        for ac in self.action_base_experience.keys():
            self.action_memory[ac] = RandomAccessQueue()
            self.action_memory[ac].extend(self.action_base_experience[ac])

        if self.current_episode:
            self.current_episode = []
Example 8
 def __init__(self, capacity=2000, lookup_k=5, n_action=None,
              key_size=256, xp=np):
        
        self.capacity = capacity
        self.memory = RandomAccessQueue(maxlen=capacity)
        self.lookup_k = lookup_k
        self.xp = xp
        self.num_action = n_action
        self.key_size = key_size
        assert self.num_action

     self.tmp_emb_arr = self.xp.empty((0, self.key_size),
                                      dtype='float32')

     self.knn = knn.ArgsortKnn(capacity=self.capacity,
                               dimension=key_size, xp=self.xp)
Example 9
class ReplayBuffer(object):
    def __init__(self, capacity=None):
        self.memory = RandomAccessQueue(maxlen=capacity)

    def append(self,
               state,
               action,
               reward,
               next_state=None,
               next_action=None,
               is_state_terminal=False):
        """Append a transition to this replay buffer

        Args:
            state: s_t
            action: a_t
            reward: r_t
            next_state: s_{t+1} (can be None if terminal)
            next_action: a_{t+1} (can be None for off-policy algorithms)
            is_state_terminal (bool)
        """
        experience = dict(state=state,
                          action=action,
                          reward=reward,
                          next_state=next_state,
                          next_action=next_action,
                          is_state_terminal=is_state_terminal)
        self.memory.append(experience)

    def sample(self, n):
        """Sample n unique samples from this replay buffer"""
        assert len(self.memory) >= n
        return self.memory.sample(n)

    def __len__(self):
        return len(self.memory)

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump(self.memory, f)

    def load(self, filename):
        with open(filename, 'rb') as f:
            self.memory = pickle.load(f)

    def stop_current_episode(self):
        pass
Example 10
    def load(self, filename):
        with open(filename, 'rb') as f:
            memory = pickle.load(f)
        if isinstance(memory, tuple):
            self.memory, self.episodic_memory = memory
        else:
            # Load v0.2
            # FIXME: The code works with EpisodicReplayBuffer
            # but not with PrioritizedEpisodicReplayBuffer
            self.memory = RandomAccessQueue(memory)
            self.episodic_memory = RandomAccessQueue()

            # Recover episodic_memory with best effort.
            episode = []
            for item in self.memory:
                episode.append(item)
                if item['is_state_terminal']:
                    self.episodic_memory.append(episode)
                    episode = []
Example 11
    def stop_current_episode(self):
        if self.current_episode:
            new_normal_episode = None
            if len(self.current_episode) > 1:
                if (len(self.good_episodic_memory)
                        >= self.good_episodic_memory_capacity):
                    new_normal_episode = heapq.heappushpop(
                        self.good_episodic_memory,
                        (copy.copy(self.current_episode_R),
                         copy.copy(self.episode_count), self.current_episode))
                else:
                    heapq.heappush(
                        self.good_episodic_memory,
                        (copy.copy(self.current_episode_R),
                         copy.copy(self.episode_count), self.current_episode))

            self.current_episode = []
            self.episode_count += 1

            new_bad_episode = None
            if new_normal_episode is not None:
                if (len(self.normal_episodic_memory)
                        >= self.normal_episodic_memory_capacity):
                    new_bad_episode = heapq.heappushpop(
                        self.normal_episodic_memory, new_normal_episode)
                else:
                    heapq.heappush(self.normal_episodic_memory,
                                   new_normal_episode)

            if new_bad_episode is not None:
                if (len(self.bad_episodic_memory)
                        >= self.bad_episodic_memory_capacity):
                    drop_episode = heapq.heappushpop(self.bad_episodic_memory,
                                                     new_bad_episode)
                    self.all_step_count -= len(drop_episode[2])
                else:
                    heapq.heappush(self.bad_episodic_memory, new_bad_episode)

            self.good_memory = RandomAccessQueue()
            for e in self.good_episodic_memory:
                self.good_memory.extend(e[2])

            self.normal_memory = RandomAccessQueue()
            for e in self.normal_episodic_memory:
                self.normal_memory.extend(e[2])

            self.bad_memory = RandomAccessQueue()
            for e in self.bad_episodic_memory:
                self.bad_memory.extend(e[2])

        assert not self.current_episode

        self.current_episode_R = 0.0
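The tiering above relies on heapq keeping the lowest-return episode at heap[0]: heappushpop on a full tier evicts its worst episode, which then cascades down to the next tier. A small self-contained illustration with made-up returns:

import heapq

good, normal, cap = [], [], 2
for R, idx in [(5.0, 0), (1.0, 1), (9.0, 2), (3.0, 3)]:
    # Same tuple layout as above: (return, episode index, transitions).
    episode = (R, idx, ['transitions...'])
    if len(good) >= cap:
        demoted = heapq.heappushpop(good, episode)  # evicts lowest return
        heapq.heappush(normal, demoted)
    else:
        heapq.heappush(good, episode)

print(sorted(e[0] for e in good))    # [5.0, 9.0]: best returns stay "good"
print(sorted(e[0] for e in normal))  # [1.0, 3.0]: evicted episodes cascade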
Example 12
    def __init__(self, capacity=None):
        self.current_episode = []
        self.current_episode_R = 0.0

        self.good_episodic_memory = []
        self.good_episodic_memory_capacity = 20
        self.good_memory = RandomAccessQueue()

        self.normal_episodic_memory = []
        self.normal_episodic_memory_capacity = 50
        self.normal_memory = RandomAccessQueue()

        self.bad_episodic_memory = []
        self.bad_episodic_memory_capacity = 10
        self.bad_memory = RandomAccessQueue()

        self.capacity = capacity
        self.all_step_count = 0
        self.episode_count = 0
Example 13
    def load(self, filename):
        with open(filename, 'rb') as f:
            memory = pickle.load(f)
        if isinstance(memory, tuple):
            (self.good_episodic_memory, self.normal_episodic_memory,
             self.bad_episodic_memory, self.all_step_count,
             self.episode_count) = memory

            self.good_memory = RandomAccessQueue()
            for e in self.good_episodic_memory:
                self.good_memory.extend(e[2])

            self.normal_memory = RandomAccessQueue()
            for e in self.normal_episodic_memory:
                self.normal_memory.extend(e[2])

            self.bad_memory = RandomAccessQueue()
            for e in self.bad_episodic_memory:
                self.bad_memory.extend(e[2])

            self.current_episode = []
            self.current_episode_R = 0.0
        else:
            print("bad replay file")
Example 14
 def __init__(self, capacity=None):
     self.current_episode = []
     self.episodic_memory = RandomAccessQueue()
     self.memory = RandomAccessQueue()
     self.capacity = capacity
Example 15
class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):

    def __init__(self, capacity=None):
        self.current_episode = []
        self.episodic_memory = RandomAccessQueue()
        self.memory = RandomAccessQueue()
        self.capacity = capacity

    def append(self, state, action, reward, next_state=None, next_action=None,
               is_state_terminal=False, **kwargs):
        experience = dict(state=state, action=action, reward=reward,
                          next_state=next_state, next_action=next_action,
                          is_state_terminal=is_state_terminal,
                          **kwargs)
        self.current_episode.append(experience)
        if is_state_terminal:
            self.stop_current_episode()

    def sample(self, n):
        assert len(self.memory) >= n
        return self.memory.sample(n)

    def sample_episodes(self, n_episodes, max_len=None):
        assert len(self.episodic_memory) >= n_episodes
        episodes = self.episodic_memory.sample(n_episodes)
        if max_len is not None:
            return [random_subseq(ep, max_len) for ep in episodes]
        else:
            return episodes

    def __len__(self):
        return len(self.memory)

    @property
    def n_episodes(self):
        return len(self.episodic_memory)

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump((self.memory, self.episodic_memory), f)

    def load(self, filename):
        with open(filename, 'rb') as f:
            memory = pickle.load(f)
        if isinstance(memory, tuple):
            self.memory, self.episodic_memory = memory
        else:
            # Load v0.2
            # FIXME: The code works with EpisodicReplayBuffer
            # but not with PrioritizedEpisodicReplayBuffer
            self.memory = RandomAccessQueue(memory)
            self.episodic_memory = RandomAccessQueue()

            # Recover episodic_memory with best effort.
            episode = []
            for item in self.memory:
                episode.append(item)
                if item['is_state_terminal']:
                    self.episodic_memory.append(episode)
                    episode = []

    def stop_current_episode(self):
        if self.current_episode:
            self.episodic_memory.append(self.current_episode)
            self.memory.extend(self.current_episode)
            self.current_episode = []
            while self.capacity is not None and \
                    len(self.memory) > self.capacity:
                discarded_episode = self.episodic_memory.popleft()
                for _ in range(len(discarded_episode)):
                    self.memory.popleft()
        assert not self.current_episode
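A quick hypothetical driver for the EpisodicReplayBuffer above: transitions accumulate in current_episode until a terminal flag (or an explicit stop_current_episode() call) moves the episode into episodic_memory. Sampling with max_len would additionally need the random_subseq helper assumed by these snippets.

buf = EpisodicReplayBuffer(capacity=10000)
for _ in range(5):
    for t in range(20):
        # Toy transitions; the last step of each episode is terminal.
        buf.append(state=t, action=0, reward=0.0,
                   next_state=t + 1, is_state_terminal=(t == 19))

assert buf.n_episodes == 5 and len(buf) == 100
episodes = buf.sample_episodes(3)  # 3 unique full episodes
batch = buf.sample(16)             # 16 individual transitions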
Example 16
 def __init__(self, capacity=None):
     self.memory = RandomAccessQueue(maxlen=capacity)
Example 17
 def __init__(self, capacity=None):
     self.current_episode = collections.defaultdict(list)
     self.episodic_memory = RandomAccessQueue()
     self.memory = RandomAccessQueue()
     self.capacity = capacity
Example 18
class EpisodicReplayBuffer(object):
    def __init__(self, capacity=None):
        self.current_episode = []
        self.episodic_memory = RandomAccessQueue()
        self.memory = RandomAccessQueue()
        self.capacity = capacity

    def append(self,
               state,
               action,
               reward,
               next_state=None,
               next_action=None,
               is_state_terminal=False,
               **kwargs):
        """Append a transition to this replay buffer

        Args:
            state: s_t
            action: a_t
            reward: r_t
            next_state: s_{t+1} (can be None if terminal)
            next_action: a_{t+1} (can be None for off-policy algorithms)
            is_state_terminal (bool)
        """
        experience = dict(state=state,
                          action=action,
                          reward=reward,
                          next_state=next_state,
                          next_action=next_action,
                          is_state_terminal=is_state_terminal,
                          **kwargs)
        self.current_episode.append(experience)
        if is_state_terminal:
            self.stop_current_episode()

    def sample(self, n):
        """Sample n unique samples from this replay buffer"""
        assert len(self.memory) >= n
        return self.memory.sample(n)

    def sample_episodes(self, n_episodes, max_len=None):
        """Sample n unique samples from this replay buffer"""
        assert len(self.episodic_memory) >= n_episodes
        episodes = self.episodic_memory.sample(n_episodes)
        if max_len is not None:
            return [random_subseq(ep, max_len) for ep in episodes]
        else:
            return episodes

    def __len__(self):
        return len(self.episodic_memory)

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump((self.memory, self.episodic_memory), f)

    def load(self, filename):
        with open(filename, 'rb') as f:
            self.memory, self.episodic_memory = pickle.load(f)

    def stop_current_episode(self):
        if self.current_episode:
            self.episodic_memory.append(self.current_episode)
            self.memory.extend(self.current_episode)
            self.current_episode = []
            while self.capacity is not None and \
                    len(self.memory) > self.capacity:
                discarded_episode = self.episodic_memory.popleft()
                for _ in range(len(discarded_episode)):
                    self.memory.popleft()
        assert not self.current_episode
Example 19
class TestRandomAccessQueue(unittest.TestCase):
    def setUp(self):
        if self.init_seq:
            self.y_queue = RandomAccessQueue(self.init_seq, maxlen=self.maxlen)
            self.t_queue = collections.deque(self.init_seq, maxlen=self.maxlen)
        else:
            self.y_queue = RandomAccessQueue(maxlen=self.maxlen)
            self.t_queue = collections.deque(maxlen=self.maxlen)

    def test1(self):
        self.check_all()

        self.check_popleft()
        self.do_append(10)
        self.check_all()

        self.check_popleft()
        self.check_popleft()
        self.do_append(11)
        self.check_all()

        # test negative indices
        n = len(self.t_queue)
        for i in range(-n, 0):
            self.check_getitem(i)

        for k in range(4):
            self.do_extend(range(k))
            self.check_all()

        for k in range(4):
            self.check_popleft()
            self.do_extend(range(k))
            self.check_all()

        for k in range(10):
            self.do_append(20 + k)
            self.check_popleft()
            self.check_popleft()
            self.check_all()

        for _ in range(100):
            self.check_popleft()

    def check_all(self):
        self.check_len()
        n = len(self.t_queue)
        for i in range(n):
            self.check_getitem(i)

    def check_len(self):
        self.assertEqual(len(self.y_queue), len(self.t_queue))

    def check_getitem(self, i):
        self.assertEqual(self.y_queue[i], self.t_queue[i])

    def do_setitem(self, i, x):
        self.y_queue[i] = x
        self.t_queue[i] = x

    def do_append(self, x):
        self.y_queue.append(x)
        self.t_queue.append(x)

    def do_extend(self, xs):
        self.y_queue.extend(xs)
        self.t_queue.extend(xs)

    def check_popleft(self):
        try:
            t = self.t_queue.popleft()
        except IndexError:
            with self.assertRaises(IndexError):
                self.y_queue.popleft()
        else:
            self.assertEqual(self.y_queue.popleft(), t)
Example 20
class SuccessPrioReplayBuffer(chainerrl.replay_buffer.AbstractReplayBuffer):
    def __init__(self, capacity=None):
        self.current_episode = []
        self.current_episode_R = 0.0

        self.good_episodic_memory = []
        self.good_episodic_memory_capacity = 20
        self.good_memory = RandomAccessQueue()

        self.normal_episodic_memory = []
        self.normal_episodic_memory_capacity = 50
        self.normal_memory = RandomAccessQueue()

        self.bad_episodic_memory = []
        self.bad_episodic_memory_capacity = 10
        self.bad_memory = RandomAccessQueue()

        self.capacity = capacity
        self.all_step_count = 0
        self.episode_count = 0

    def append(self,
               state,
               action,
               reward,
               next_state=None,
               next_action=None,
               is_state_terminal=False,
               **kwargs):
        experience = dict(state=state,
                          action=action,
                          reward=reward,
                          next_state=next_state,
                          next_action=next_action,
                          is_state_terminal=is_state_terminal,
                          **kwargs)
        self.current_episode.append(experience)

        self.current_episode_R += reward
        self.all_step_count += 1

        if is_state_terminal:
            self.stop_current_episode()

    def sample(self, n):
        count_sample = 0
        ans = []
        if len(self.bad_memory) > 0:
            n_s = min((len(self.bad_memory), n // 4))
            ans.extend(self.bad_memory.sample(n_s))
            count_sample += n_s

        if len(self.normal_memory) > 0:
            n_s = min((len(self.normal_memory), (n // 4) * 2 - count_sample))
            ans.extend(self.normal_memory.sample(n_s))
            count_sample += n_s

        if len(self.good_memory) > 0:
            n_s = min((len(self.good_memory), (n // 4) * 3 - count_sample))
            ans.extend(self.good_memory.sample(n_s))
            count_sample += n_s

        if (count_sample < n) and (len(self.current_episode) > 0):
            n_s = min(len(self.current_episode), n - count_sample)
            # Fall back to the ongoing episode: take its n_s most recent
            # transitions, excluding the very latest one.
            # ans.extend(random.sample(self.current_episode, n_s))
            ans.extend(self.current_episode[len(self.current_episode) - 1 -
                                            n_s:len(self.current_episode) - 1])

        return ans

    def __len__(self):
        return self.all_step_count

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump((self.good_episodic_memory,
                         self.normal_episodic_memory, self.bad_episodic_memory,
                         self.all_step_count, self.episode_count), f)

    def load(self, filename):
        with open(filename, 'rb') as f:
            memory = pickle.load(f)
        if isinstance(memory, tuple):
            (self.good_episodic_memory, self.normal_episodic_memory,
             self.bad_episodic_memory, self.all_step_count,
             self.episode_count) = memory

            self.good_memory = RandomAccessQueue()
            for e in self.good_episodic_memory:
                self.good_memory.extend(e[2])

            self.normal_memory = RandomAccessQueue()
            for e in self.normal_episodic_memory:
                self.normal_memory.extend(e[2])

            self.bad_memory = RandomAccessQueue()
            for e in self.bad_episodic_memory:
                self.bad_memory.extend(e[2])

            self.current_episode = []
            self.current_episode_R = 0.0
        else:
            print("bad replay file")

    def stop_current_episode(self):
        if self.current_episode:
            new_normal_episode = None
            if len(self.current_episode) > 1:
                if (len(self.good_episodic_memory)
                        >= self.good_episodic_memory_capacity):
                    new_normal_episode = heapq.heappushpop(
                        self.good_episodic_memory,
                        (copy.copy(self.current_episode_R),
                         copy.copy(self.episode_count), self.current_episode))
                else:
                    heapq.heappush(
                        self.good_episodic_memory,
                        (copy.copy(self.current_episode_R),
                         copy.copy(self.episode_count), self.current_episode))

            self.current_episode = []
            self.episode_count += 1

            new_bad_episode = None
            if new_normal_episode is not None:
                if (len(self.normal_episodic_memory)
                        >= self.normal_episodic_memory_capacity):
                    new_bad_episode = heapq.heappushpop(
                        self.normal_episodic_memory, new_normal_episode)
                else:
                    heapq.heappush(self.normal_episodic_memory,
                                   new_normal_episode)

            if new_bad_episode is not None:
                if (len(self.bad_episodic_memory)
                        >= self.bad_episodic_memory_capacity):
                    drop_episode = heapq.heappushpop(self.bad_episodic_memory,
                                                     new_bad_episode)
                    self.all_step_count -= len(drop_episode[2])
                else:
                    heapq.heappush(self.bad_episodic_memory, new_bad_episode)

            self.good_memory = RandomAccessQueue()
            for e in self.good_episodic_memory:
                self.good_memory.extend(e[2])

            self.normal_memory = RandomAccessQueue()
            for e in self.normal_episodic_memory:
                self.normal_memory.extend(e[2])

            self.bad_memory = RandomAccessQueue()
            for e in self.bad_episodic_memory:
                self.bad_memory.extend(e[2])

        assert not self.current_episode

        self.current_episode_R = 0.0
Example 21
class ValueBuffer(with_metaclass(ABCMeta, object)):
    """Buffer for producing non-parametric Q-values."""

    def __init__(self, capacity=2000, lookup_k=5, n_action=None,
                 key_size=256, xp=np):

        self.capacity = capacity
        self.memory = RandomAccessQueue(maxlen=capacity)
        self.lookup_k = lookup_k
        self.xp = xp
        self.num_action = n_action
        self.key_size = key_size
        assert self.num_action

        self.tmp_emb_arr = self.xp.empty((0, self.key_size),
                                         dtype='float32')

        self.knn = knn.ArgsortKnn(capacity=self.capacity,
                                  dimension=key_size, xp=self.xp)

    def __len__(self):
        return len(self.memory)

    def store(self, embedding, q_np):
        # Save the (embedding, action values) pair in the value buffer.
        self._store(dict(embedding=embedding, action_value=q_np))
        # Register the embedding with the knn index.
        self.knn.add(embedding)

        assert len(self.knn) == len(self.memory)
        assert self.memory[0]['embedding'][0, 0] == self.knn.head_emb()
        if len(self.memory) == self.capacity:
            assert self.memory[-1]['embedding'][-1, 0] == self.knn.end_emb()

        # No return value (add one here if it becomes necessary).
        return


    def _store(self, dictionaries):
        # Store the entry, evicting from the front while over capacity.
        self.memory.append(dictionaries)
        while self.capacity is not None and \
                len(self.memory) > self.capacity:
            self.memory.popleft()


    def compute_q(self, embedding):
        # Non-parametric Q-estimate: average the stored action values of
        # the lookup_k nearest neighbours of the query embedding.
        # (A disabled variant here clamped k to len(self.memory) when the
        # buffer held fewer than lookup_k entries.)
        index_list = self.knn.search(embedding, self.lookup_k)
        tmp_vbuf = self.xp.asarray(
            [self.memory[i]['action_value'] for i in index_list],
            dtype=self.xp.float32)
        q_np = self.xp.average(tmp_vbuf, axis=0)
        return q_np
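Hypothetical usage of the ValueBuffer in Example 21: store (embedding, per-action Q) pairs, then estimate Q for a query as the average over its lookup_k nearest stored neighbours. This assumes the knn.ArgsortKnn module from the same codebase is importable; shapes are inferred from the assertions in store.

import numpy as np

vbuf = ValueBuffer(capacity=2000, lookup_k=5, n_action=4, key_size=256)
for _ in range(50):
    emb = np.random.randn(1, 256).astype(np.float32)   # one key vector
    q_np = np.random.randn(4).astype(np.float32)       # Q per action
    vbuf.store(emb, q_np)

q_est = vbuf.compute_q(np.random.randn(1, 256).astype(np.float32))
print(q_est.shape)  # (4,): one non-parametric Q-value per action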
Example 22
class ReplayBuffer(AbstractReplayBuffer):
    def __init__(self, capacity=None, num_steps=1):
        self.capacity = capacity
        assert num_steps > 0
        self.num_steps = num_steps
        self.memory = RandomAccessQueue(maxlen=capacity)
        self.last_n_transitions = collections.deque([], maxlen=num_steps)

    def append(self,
               state,
               action,
               reward,
               next_state=None,
               next_action=None,
               is_state_terminal=False):
        experience = dict(state=state,
                          action=action,
                          reward=reward,
                          next_state=next_state,
                          next_action=next_action,
                          is_state_terminal=is_state_terminal)
        self.last_n_transitions.append(experience)
        if is_state_terminal:
            while self.last_n_transitions:
                self.memory.append(list(self.last_n_transitions))
                del self.last_n_transitions[0]
            assert len(self.last_n_transitions) == 0
        else:
            if len(self.last_n_transitions) == self.num_steps:
                self.memory.append(list(self.last_n_transitions))

    def stop_current_episode(self):
        # If the n-step history is not yet full, this final transition has
        # not been stored yet, so add it; if it is full, append() already has.
        if 0 < len(self.last_n_transitions) < self.num_steps:
            self.memory.append(list(self.last_n_transitions))
        # avoid duplicate entry
        if 0 < len(self.last_n_transitions) <= self.num_steps:
            del self.last_n_transitions[0]
        while self.last_n_transitions:
            self.memory.append(list(self.last_n_transitions))
            del self.last_n_transitions[0]
        assert len(self.last_n_transitions) == 0

    def sample(self, num_experiences):
        assert len(self.memory) >= num_experiences
        return self.memory.sample(num_experiences)

    def __len__(self):
        return len(self.memory)

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump(self.memory, f)

    def load(self, filename):
        with open(filename, 'rb') as f:
            self.memory = pickle.load(f)
        if isinstance(self.memory, collections.deque):
            # Load v0.2
            self.memory = RandomAccessQueue(self.memory,
                                            maxlen=self.memory.maxlen)
Example 23
class ReplayBuffer(replay_buffer.AbstractReplayBuffer):
    """Experience Replay Buffer

    As described in
    https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf.

    Args:
        capacity (int): capacity in terms of number of transitions
        num_steps (int): Number of timesteps per stored transition
            (for N-step updates)
    """
    def __init__(self, capacity=None, num_steps=1):
        self.capacity = capacity
        assert num_steps > 0
        self.num_steps = num_steps
        self.memory = RandomAccessQueue(maxlen=capacity)
        self.last_n_transitions = collections.defaultdict(
            lambda: collections.deque([], maxlen=num_steps))

    def append(self,
               state,
               action,
               reward,
               next_state=None,
               next_action=None,
               is_state_terminal=False,
               env_id=0,
               **kwargs):
        last_n_transitions = self.last_n_transitions[env_id]
        experience = dict(state=state,
                          action=action,
                          reward=reward,
                          next_state=next_state,
                          next_action=next_action,
                          is_state_terminal=is_state_terminal,
                          **kwargs)
        last_n_transitions.append(experience)
        if is_state_terminal:
            while last_n_transitions:
                self.memory.append(list(last_n_transitions))
                del last_n_transitions[0]
            assert len(last_n_transitions) == 0
        else:
            if len(last_n_transitions) == self.num_steps:
                self.memory.append(list(last_n_transitions))

    def stop_current_episode(self, env_id=0):
        last_n_transitions = self.last_n_transitions[env_id]
        # If the n-step history is not yet full, this final transition has
        # not been stored yet, so add it; if it is full, append() already has.
        if 0 < len(last_n_transitions) < self.num_steps:
            self.memory.append(list(last_n_transitions))
        # avoid duplicate entry
        if 0 < len(last_n_transitions) <= self.num_steps:
            del last_n_transitions[0]
        while last_n_transitions:
            self.memory.append(list(last_n_transitions))
            del last_n_transitions[0]
        assert len(last_n_transitions) == 0

    def sample(self, num_experiences):
        assert len(self.memory) >= num_experiences
        return self.memory.sample(num_experiences)

    def __len__(self):
        return len(self.memory)

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump(self.memory, f)

    def load(self, filename):
        with open(filename, 'rb') as f:
            self.memory = pickle.load(f)
        if isinstance(self.memory, collections.deque):
            # Load v0.2
            self.memory = RandomAccessQueue(self.memory,
                                            maxlen=self.memory.maxlen)
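To see what the num_steps windowing in Examples 22 and 23 actually stores, here is a small self-contained trace of the same append/flush logic, with plain ints standing in for transition dicts (illustration only):

import collections

num_steps = 3
memory = []
window = collections.deque([], maxlen=num_steps)

for t, terminal in [(0, False), (1, False), (2, False), (3, True)]:
    window.append(t)
    if terminal:
        # Flush the remaining suffixes so every transition is stored.
        while window:
            memory.append(list(window))
            del window[0]
    elif len(window) == num_steps:
        memory.append(list(window))

print(memory)  # [[0, 1, 2], [1, 2, 3], [2, 3], [3]]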