def __init__(
    self,
    capacity=None,
    alpha=0.6,
    beta0=0.4,
    betasteps=2e5,
    eps=1e-8,
    normalize_by_max=True,
    default_priority_func=None,
    uniform_ratio=0,
    wait_priority_after_sampling=True,
    return_sample_weights=True,
    error_min=None,
    error_max=None,
):
    """Set up an episodic replay buffer with prioritized episode sampling.

    Whole episodes go to ``self.episodic_memory`` (a ``PrioritizedBuffer``
    built with ``capacity=None``); flat transitions go to ``self.memory``
    (a ``RandomAccessQueue`` bounded by ``capacity``). ``capacity_left``
    starts at ``capacity``. ``alpha``, ``beta0``, ``betasteps``, ``eps``,
    ``normalize_by_max``, ``error_min`` and ``error_max`` are forwarded to
    ``PriorityWeightError.__init__``.
    """
    # Transitions of the not-yet-finished episode, keyed by env id.
    self.current_episode = collections.defaultdict(list)
    self.episodic_memory = PrioritizedBuffer(
        capacity=None,
        wait_priority_after_sampling=wait_priority_after_sampling,
    )
    self.memory = RandomAccessQueue(maxlen=capacity)
    self.capacity_left = capacity
    (
        self.default_priority_func,
        self.uniform_ratio,
        self.return_sample_weights,
    ) = (default_priority_func, uniform_ratio, return_sample_weights)
    PriorityWeightError.__init__(
        self, alpha, beta0, betasteps, eps, normalize_by_max,
        error_min=error_min, error_max=error_max,
    )
Example #2
0
 def load(self, filename):
     """Replace ``self.memory`` with the object pickled in ``filename``.

     A legacy (v0.2) file contains a plain ``collections.deque``; it is
     converted into a ``RandomAccessQueue`` that keeps the deque's ``maxlen``.
     Note that unpickling can execute code from the file; only load
     trusted files.
     """
     with open(filename, "rb") as f:
         self.memory = pickle.load(f)
     legacy = self.memory
     if not isinstance(legacy, collections.deque):
         return
     # Load v0.2
     self.memory = RandomAccessQueue(legacy, maxlen=legacy.maxlen)
Example #3
0
 def __init__(self, capacity: Optional[int] = None, num_steps: int = 1):
     """Create an empty n-step transition buffer.

     Args:
         capacity: Max number of stored entries (``None`` means unbounded).
         num_steps: Length of the per-env sliding window of transitions;
             must be positive.
     """
     self.capacity = capacity
     assert num_steps > 0
     self.num_steps = num_steps
     self.memory = RandomAccessQueue(maxlen=capacity)

     def new_window():
         # Sliding window holding at most ``num_steps`` transitions.
         return collections.deque([], maxlen=num_steps)

     self.last_n_transitions: collections.defaultdict = collections.defaultdict(
         new_window)
Example #4
0
 def setUp(self, maxlen, init_seq):
     """Build the queue under test and a ``collections.deque`` twin.

     Both are constructed from the same optional seed sequence and
     ``maxlen`` so later operations can be compared one-to-one.
     """
     self.maxlen = maxlen
     self.init_seq = init_seq
     # An empty/falsy seed means "construct without an initial iterable".
     seed_args = (self.init_seq,) if self.init_seq else ()
     self.y_queue = RandomAccessQueue(*seed_args, maxlen=self.maxlen)
     self.t_queue = collections.deque(*seed_args, maxlen=self.maxlen)
Example #5
0
 def __init__(self, capacity, n_dim=256, n_neighbors=5, num_steps=1):
     """Create an empty buffer with a companion feature store.

     Args:
         capacity: Bound for both ``memory`` and the ``lkb`` feature store.
         n_dim: Feature size passed to ``lkb``.
         n_neighbors: Stored as ``self.n_neighbors``.
         num_steps: Per-env sliding-window length; must be positive.
     """
     self.capacity = capacity
     assert num_steps > 0
     self.num_steps = num_steps
     self.memory = RandomAccessQueue(maxlen=capacity)
     self.h_memory = lkb(capacity=capacity, n_dim=n_dim)
     has_cuda = torch.cuda.is_available()
     self.device = torch.device("cuda:0" if has_cuda else "cpu")

     def fresh_window():
         return collections.deque([], maxlen=num_steps)

     self.last_n_transitions = collections.defaultdict(fresh_window)
     self.n_neighbors = n_neighbors
Example #6
0
    def load(self, filename):
        """Restore ``memory`` and ``episodic_memory`` from a pickle file.

        Current files hold a ``(memory, episodic_memory)`` tuple. Anything
        else is treated as a v0.2 flat transition queue: episodes are
        rebuilt by cutting after each item whose ``"is_state_terminal"`` is
        truthy; trailing items without such a terminator are not put into
        any episode.
        """
        with open(filename, "rb") as f:
            loaded = pickle.load(f)
        if isinstance(loaded, tuple):
            self.memory, self.episodic_memory = loaded
            return

        # Load v0.2
        # FIXME: The code works with EpisodicReplayBuffer
        # but not with PrioritizedEpisodicReplayBuffer
        self.memory = RandomAccessQueue(loaded)
        self.episodic_memory = RandomAccessQueue()

        # Recover episodic_memory with best effort.
        pending = []
        for transition in self.memory:
            pending.append(transition)
            if transition["is_state_terminal"]:
                self.episodic_memory.append(pending)
                pending = []
    def __init__(self, basedir, maxlen, *, ancestor=None, logger=None):
        """Open or create a persistent queue rooted at ``basedir``.

        Args:
            basedir: Directory holding the meta file and the ``rank0`` data
                directory.
            maxlen: In-memory length bound (``None`` or a positive int).
            ancestor: Optional directory of an earlier buffer whose data is
                preloaded before this buffer's own data.
            logger: Optional logger for progress messages.
        """
        assert maxlen is None or maxlen > 0
        self.basedir = basedir
        self._setup_fs(None)
        self._setup_datadir()
        self.meta = None
        self.buffer = RandomAccessQueue(maxlen=maxlen)
        self.logger = logger

        self.ancestor_meta = None
        if ancestor is not None:
            # Load ancestor as preloaded data
            self.ancestor_meta = self._load_ancestor(ancestor, maxlen)

        # Load or create meta file and share the meta object
        self.meta_file = PersistentRandomAccessQueue._meta_file_name(self.basedir)
        self._load_meta(ancestor, maxlen)

        if not self.fs.exists(self.datadir):
            self.gen = 0
            self.fs.makedirs(self.datadir, exist_ok=True)
        else:
            self.gen = _ChunkReader(self.datadir, self.fs).read_chunks(
                maxlen, self.buffer)

        # Last chunk to be appended
        self.tail = _ChunkWriter(
            self.datadir, self.gen, self.chunk_size, self.fs, do_pickle=True)
        self.gen += 1

        if self.logger:
            self.logger.info("Initial buffer size=%d, next gen=%d",
                             len(self.buffer), self.gen)
Example #8
0
 def __init__(self, capacity=None):
     """Create empty transition and episode stores.

     Args:
         capacity: Transition-count bound, stored as ``self.capacity``
             (``None`` means unbounded).
     """
     # Per-env list of transitions for the episode still in progress.
     self.current_episode = collections.defaultdict(list)
     self.episodic_memory, self.memory = RandomAccessQueue(), RandomAccessQueue()
     self.capacity = capacity
Example #9
0
class EpisodicReplayBuffer(AbstractEpisodicReplayBuffer):
    """Replay buffer that stores both single transitions and whole episodes.

    Transitions are staged per ``env_id`` until the episode ends. The finished
    episode is then appended to ``episodic_memory``, and each of its
    transitions, wrapped in a one-element list, is appended to ``memory``.
    With a ``capacity``, the oldest whole episodes are evicted until
    ``len(memory) <= capacity``.
    """

    def __init__(self, capacity=None):
        self.current_episode = collections.defaultdict(list)
        self.episodic_memory = RandomAccessQueue()
        self.memory = RandomAccessQueue()
        self.capacity = capacity

    def append(self,
               state,
               action,
               reward,
               next_state=None,
               next_action=None,
               is_state_terminal=False,
               env_id=0,
               **kwargs):
        """Stage one transition for ``env_id``; a terminal one ends the episode."""
        staged = self.current_episode[env_id]
        staged.append(dict(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            next_action=next_action,
            is_state_terminal=is_state_terminal,
            **kwargs,
        ))
        if is_state_terminal:
            self.stop_current_episode(env_id=env_id)

    def sample(self, n):
        """Return ``n`` entries sampled from ``memory``."""
        assert len(self.memory) >= n
        return self.memory.sample(n)

    def sample_episodes(self, n_episodes, max_len=None):
        """Return ``n_episodes`` episodes, each cut by ``random_subseq`` if
        ``max_len`` is given."""
        assert len(self.episodic_memory) >= n_episodes
        sampled = self.episodic_memory.sample(n_episodes)
        if max_len is None:
            return sampled
        return [random_subseq(ep, max_len) for ep in sampled]

    def __len__(self):
        return len(self.memory)

    @property
    def n_episodes(self):
        return len(self.episodic_memory)

    def save(self, filename):
        """Pickle ``(memory, episodic_memory)`` into ``filename``."""
        payload = (self.memory, self.episodic_memory)
        with open(filename, "wb") as f:
            pickle.dump(payload, f)

    def load(self, filename):
        """Restore state written by ``save`` (or a v0.2 flat queue)."""
        with open(filename, "rb") as f:
            loaded = pickle.load(f)
        if isinstance(loaded, tuple):
            self.memory, self.episodic_memory = loaded
            return

        # Load v0.2
        # FIXME: The code works with EpisodicReplayBuffer
        # but not with PrioritizedEpisodicReplayBuffer
        self.memory = RandomAccessQueue(loaded)
        self.episodic_memory = RandomAccessQueue()

        # Recover episodic_memory with best effort.
        pending = []
        for transition in self.memory:
            pending.append(transition)
            if transition["is_state_terminal"]:
                self.episodic_memory.append(pending)
                pending = []

    def stop_current_episode(self, env_id=0):
        """Commit the staged episode of ``env_id`` and enforce ``capacity``."""
        episode = self.current_episode[env_id]
        if episode:
            self.episodic_memory.append(episode)
            for transition in episode:
                self.memory.append([transition])
            self.current_episode[env_id] = []
            if self.capacity is not None:
                # Evict whole episodes (oldest first) until within capacity.
                while len(self.memory) > self.capacity:
                    oldest = self.episodic_memory.popleft()
                    for _ in range(len(oldest)):
                        self.memory.popleft()
        assert not self.current_episode[env_id]
class PrioritizedEpisodicReplayBuffer(EpisodicReplayBuffer, PriorityWeightError):
    """Episodic replay buffer whose episodes are sampled by priority.

    Episodes are kept in a ``PrioritizedBuffer`` and transitions, flat, in a
    ``RandomAccessQueue`` bounded by ``capacity``. ``capacity_left`` counts
    the remaining room in transitions; when it drops below zero the oldest
    episodes are popped from ``episodic_memory``.
    """

    def __init__(
        self,
        capacity=None,
        alpha=0.6,
        beta0=0.4,
        betasteps=2e5,
        eps=1e-8,
        normalize_by_max=True,
        default_priority_func=None,
        uniform_ratio=0,
        wait_priority_after_sampling=True,
        return_sample_weights=True,
        error_min=None,
        error_max=None,
    ):
        self.current_episode = collections.defaultdict(list)
        self.episodic_memory = PrioritizedBuffer(
            capacity=None,
            wait_priority_after_sampling=wait_priority_after_sampling,
        )
        self.memory = RandomAccessQueue(maxlen=capacity)
        self.capacity_left = capacity
        self.default_priority_func = default_priority_func
        self.uniform_ratio = uniform_ratio
        self.return_sample_weights = return_sample_weights
        PriorityWeightError.__init__(
            self, alpha, beta0, betasteps, eps, normalize_by_max,
            error_min=error_min, error_max=error_max,
        )

    def sample_episodes(self, n_episodes, max_len=None):
        """Sample n unique samples from this replay buffer"""
        assert len(self.episodic_memory) >= n_episodes
        episodes, probabilities, min_prob = self.episodic_memory.sample(
            n_episodes, uniform_ratio=self.uniform_ratio)
        if max_len is not None:
            episodes = [random_subseq(ep, max_len) for ep in episodes]
        if not self.return_sample_weights:
            return episodes
        return episodes, self.weights_from_probabilities(probabilities, min_prob)

    def update_errors(self, errors):
        """Set priorities of the last sampled episodes from their errors."""
        new_priority = self.priority_from_errors(errors)
        self.episodic_memory.set_last_priority(new_priority)

    def stop_current_episode(self, env_id=0):
        """Commit the staged episode of ``env_id`` with its default priority."""
        episode = self.current_episode[env_id]
        if episode:
            scorer = self.default_priority_func
            priority = None if scorer is None else scorer(episode)
            self.memory.extend(episode)
            self.episodic_memory.append(episode, priority=priority)
            bounded = self.capacity_left is not None
            if bounded:
                self.capacity_left -= len(episode)
            self.current_episode[env_id] = []
            if bounded:
                while self.capacity_left < 0:
                    self.capacity_left += len(self.episodic_memory.popleft())
        assert not self.current_episode[env_id]
class PersistentRandomAccessQueue(object):
    """Persistent data structure for replay buffer

    Features:
    - Drop-in replacement for ``RandomAccessQueue``
    - Persist replay buffer data on storage to survive restart
      - [todo] remove popleft'd data from disk
    - Reuse replay buffer data to another training session
      - Track back the replay buffer lineage
    Non-goals:
    - Extend replay buffer by spilling to the disk

    TODOs
    - Optimize writes; buffered writes with threads or something

    Arguments:
        basedir (str): Path to the directory where replay buffer data is stored.
        maxlen (int or None): Max length of queue. Appended data beyond
            this limit is only removed from memory, not from storage.
            ``None`` means no limit.
        ancestor (str): Path to pre-generated replay buffer.
        logger: logger

    """

    def __init__(self, basedir, maxlen, *, ancestor=None, logger=None):
        assert maxlen is None or maxlen > 0
        self.basedir = basedir
        self._setup_fs(None)
        self._setup_datadir()
        self.meta = None
        self.buffer = RandomAccessQueue(maxlen=maxlen)
        self.logger = logger
        self.ancestor_meta = None
        if ancestor is not None:
            # Load ancestor as preloaded data
            meta = self._load_ancestor(ancestor, maxlen)
            self.ancestor_meta = meta

        # Load or create meta file and share the meta object
        self.meta_file = PersistentRandomAccessQueue._meta_file_name(
            self.basedir)
        self._load_meta(ancestor, maxlen)

        if self.fs.exists(self.datadir):
            # Restart: reload this buffer's own chunks into memory.
            reader = _ChunkReader(self.datadir, self.fs)
            self.gen = reader.read_chunks(maxlen, self.buffer)

        else:
            self.gen = 0
            self.fs.makedirs(self.datadir, exist_ok=True)

        self.tail = _ChunkWriter(self.datadir,
                                 self.gen,
                                 self.chunk_size,
                                 self.fs,
                                 do_pickle=True)  # Last chunk to be appended
        self.gen += 1

        if self.logger:
            self.logger.info("Initial buffer size=%d, next gen=%d",
                             len(self.buffer), self.gen)

    def _load_meta(self, ancestor, maxlen):
        """Load ``self.meta`` from ``meta_file``, or create and write it."""
        # This must be checked by single process to avoid race
        # condition where one creates and the other may detect it
        # as existing and process it differently.
        if self.fs.exists(self.meta_file):
            # Load existing meta
            with self.fs.open(self.meta_file, "rb") as fp:
                self.meta = pickle.load(fp)

            # TODO: update chunksize and other properties
            assert isinstance(self.meta, dict)

            # MPI world size must be the same when it's restart
            assert (self.meta["comm_size"] == self.comm_size
                    ), "Reloading same basedir requires same comm.size"

        else:
            # Create meta from scratch
            # Timestamp from pfrl.experiments.prepare_output_dir

            ts = datetime.strftime(datetime.today(), "%Y%m%dT%H%M%S.%f")
            self.meta = dict(
                basedir=self.basedir,
                maxlen=maxlen,
                comm_size=self.comm_size,
                ancestor=ancestor,
                timestamp=ts,
                chunksize=self.chunk_size,
                trim=False,  # `trim` is reserved for future extension.
            )

            # Note: If HDFS access fails at first open, make sure
            # no ``cv2`` import fail happening - failing
            # ``opencv-python`` due to lacking ``libSM.so`` may
            # break whole dynamic library loader and thus breaks
            # other dynamic library loading (e.g. libhdfs.so)
            # which may happen here. Solution for this is to let
            # the import success, e.g. installing the lacking
            # library correctly.
            self.fs.makedirs(self.basedir, exist_ok=True)
            with self.fs.open(self.meta_file, "wb") as fp:
                pickle.dump(self.meta, fp)

    def close(self):
        """Close the chunk currently being written.

        Calling it again after the first call is a no-op.
        """
        if self.tail is not None:
            self.tail.close()
            self.tail = None

    def _append(self, value):
        """Write ``value`` to the tail chunk, rotating to a new chunk when full."""
        if self.tail.is_full():
            self.tail = _ChunkWriter(self.datadir,
                                     self.gen,
                                     self.chunk_size,
                                     self.fs,
                                     do_pickle=True)
            if self.logger:
                self.logger.info("Chunk rotated. New gen=%d", self.gen)
            self.gen += 1

        self.tail.append(value)

    # RandomAccessQueue-compat methods
    def append(self, value):
        self._append(value)
        self.buffer.append(value)

    def extend(self, xs):
        # ``xs`` is traversed twice (storage, then memory). Materialize it so
        # a one-shot iterator (e.g. a generator) is not exhausted by the
        # first pass, which would leave the in-memory buffer without the data.
        xs = list(xs)
        for x in xs:
            self._append(x)
        self.buffer.extend(xs)

    def __iter__(self):
        return iter(self.buffer)

    def __repr__(self):
        return "PersistentRandomAccessQueue({})".format(str(self.buffer))

    def __setitem__(self, i, x):
        raise NotImplementedError()

    def __getitem__(self, i):
        return self.buffer[i]

    def sample(self, n):
        return self.buffer.sample(n)

    def popleft(self):
        """Remove and return the oldest in-memory item.

        The item stays on storage (see the class TODO).
        """
        return self.buffer.popleft()

    def __len__(self):
        return len(self.buffer)

    @property
    def maxlen(self):
        return self.meta["maxlen"]

    @property
    def comm_size(self):
        return 1  # Fixed to 1

    @property
    def comm_rank(self):
        return 0

    @property
    def chunk_size(self):
        return 16 * 128 * 1024 * 1024  # Fixed: 16 * 128MB

    @staticmethod
    def _meta_file_name(dirname):
        return os.path.join(dirname, "meta.pkl")

    def _setup_fs(self, fs):
        # In __init__() fs is fixed to None, but this is reserved for
        # future extension support non-posix file systems such as HDFS
        if fs is None:
            if _chainerio_available:
                # _chainerio_available must be None for now
                raise NotImplementedError(
                    "Internal Error: chainerio support is not yet implemented")
            else:
                # When chainerio is not installed
                self.fs = _VanillaFS(open=open,
                                     exists=os.path.exists,
                                     makedirs=os.makedirs)
        else:
            self.fs = fs

    def _setup_datadir(self):
        # the name "rank0" means that the process is the rank 0
        # in a parallel processing process group
        # It is fixed to 'rank0' and prepared for future extension.
        self.datadir = os.path.join(self.basedir, "rank0")

    def _load_ancestor(self, ancestor, num_data_needed):
        """Preload data from ``ancestor`` (and its own ancestors if needed).

        Older ancestors are loaded first, so ``self.buffer`` ends up ordered
        oldest to newest.

        Args:
            ancestor: Directory of a previously written buffer.
            num_data_needed: Max number of items to load, or ``None`` to load
                everything available.

        Returns:
            The meta dict read from ``ancestor``.
        """
        ancestor_metafile = PersistentRandomAccessQueue._meta_file_name(
            ancestor)
        with self.fs.open(ancestor_metafile, "rb") as fp:
            meta = pickle.load(fp)
            assert isinstance(meta, dict)
        if self.logger:
            self.logger.info("Loading buffer data from %s", ancestor)

        datadirs = []
        saved_comm_size = meta["comm_size"]
        n_data_dirs = (saved_comm_size + self.comm_size - 1) // self.comm_size
        data_dir_i = self.comm_rank
        for _ in range(n_data_dirs):
            data_dir_i = data_dir_i % saved_comm_size
            datadirs.append(os.path.join(ancestor,
                                         "rank{}".format(data_dir_i)))
            data_dir_i += self.comm_size

        unbounded = num_data_needed is None

        # Count items available in this ancestor's data dirs.
        length = 0
        for datadir in datadirs:
            reader = _ChunkReader(datadir, self.fs)
            gen = 0
            while True:
                filename = os.path.join(datadir,
                                        _INDEX_FILENAME_FORMAT.format(gen))
                if not self.fs.exists(filename):
                    break
                if self.logger:
                    self.logger.debug("read_chunk_index from %s, gen=%d",
                                      datadir, gen)
                length += len(list(reader.read_chunk_index(gen)))
                gen += 1

        # Not enough here (or no limit at all): pull older data first.
        if meta["ancestor"] is not None and (unbounded
                                             or length < num_data_needed):
            self._load_ancestor(
                meta["ancestor"],
                None if unbounded else num_data_needed - length)

        for datadir in datadirs:
            reader = _ChunkReader(datadir, self.fs)
            rank_data = []
            if unbounded:
                maxlen = None
            else:
                maxlen = num_data_needed - len(self.buffer)
                if maxlen <= 0:
                    break
            _ = reader.read_chunks(maxlen, rank_data)
            if self.logger:
                self.logger.info("%d data loaded to buffer (rank=%d)",
                                 len(rank_data), self.comm_rank)
            self.buffer.extend(rank_data)
        return meta
Example #12
0
class TestRandomAccessQueue:
    """Differential test: every operation on a ``RandomAccessQueue`` must
    give the same result as the same operation on a ``collections.deque``."""

    @pytest.fixture(autouse=True)
    def setUp(self, maxlen, init_seq):
        self.maxlen = maxlen
        self.init_seq = init_seq
        # A falsy seed means "construct without an initial iterable".
        seed_args = (self.init_seq,) if self.init_seq else ()
        self.y_queue = RandomAccessQueue(*seed_args, maxlen=self.maxlen)
        self.t_queue = collections.deque(*seed_args, maxlen=self.maxlen)

    def test1(self):
        self.check_all()

        self.check_popleft()
        self.do_append(10)
        self.check_all()

        for _ in range(2):
            self.check_popleft()
        self.do_append(11)
        self.check_all()

        # test negative indices
        for neg in range(-len(self.t_queue), 0):
            self.check_getitem(neg)

        for width in range(4):
            self.do_extend(range(width))
            self.check_all()

        for width in range(4):
            self.check_popleft()
            self.do_extend(range(width))
            self.check_all()

        for value in range(20, 30):
            self.do_append(value)
            self.check_popleft()
            self.check_popleft()
            self.check_all()

        for _ in range(100):
            self.check_popleft()

    def check_all(self):
        """Compare lengths and every element by positive index."""
        self.check_len()
        for idx in range(len(self.t_queue)):
            self.check_getitem(idx)

    def check_len(self):
        assert len(self.y_queue) == len(self.t_queue)

    def check_getitem(self, i):
        assert self.y_queue[i] == self.t_queue[i]

    def do_setitem(self, i, x):
        self.y_queue[i] = x
        self.t_queue[i] = x

    def do_append(self, x):
        self.y_queue.append(x)
        self.t_queue.append(x)

    def do_extend(self, xs):
        self.y_queue.extend(xs)
        self.t_queue.extend(xs)

    def check_popleft(self):
        """Pop from both; an empty reference queue means both must raise."""
        if self.t_queue:
            expected = self.t_queue.popleft()
            assert self.y_queue.popleft() == expected
        else:
            with pytest.raises(IndexError):
                self.y_queue.popleft()
Example #13
0
class ReplayBuffer(replay_buffer.AbstractReplayBuffer):
    """Experience Replay Buffer

    As described in
    https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf.

    Each stored entry is a list of up to ``num_steps`` consecutive
    transitions, collected through a per-env sliding window.

    Args:
        capacity (int): capacity in terms of number of transitions
        num_steps (int): Number of timesteps per stored transition
            (for N-step updates)
    """

    # Implements AbstractReplayBuffer.capacity
    capacity: Optional[int] = None

    def __init__(self, capacity: Optional[int] = None, num_steps: int = 1):
        self.capacity = capacity
        assert num_steps > 0
        self.num_steps = num_steps
        self.memory = RandomAccessQueue(maxlen=capacity)

        def new_window():
            return collections.deque([], maxlen=num_steps)

        self.last_n_transitions: collections.defaultdict = collections.defaultdict(
            new_window)

    def append(self,
               state,
               action,
               reward,
               next_state=None,
               next_action=None,
               is_state_terminal=False,
               env_id=0,
               **kwargs):
        """Push one transition into the window of ``env_id``.

        A full window is stored as one entry. On a terminal transition every
        suffix of the window is stored and the window is emptied.
        """
        window = self.last_n_transitions[env_id]
        window.append(dict(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            next_action=next_action,
            is_state_terminal=is_state_terminal,
            **kwargs,
        ))
        if not is_state_terminal:
            if len(window) == self.num_steps:
                self.memory.append(list(window))
            return
        while window:
            self.memory.append(list(window))
            del window[0]
        assert len(window) == 0

    def stop_current_episode(self, env_id=0):
        """Flush the window of ``env_id`` without duplicating stored entries."""
        window = self.last_n_transitions[env_id]
        size = len(window)
        # A non-full window has not been stored yet; a full one already was.
        if 0 < size < self.num_steps:
            self.memory.append(list(window))
        # avoid duplicate entry
        if 0 < size <= self.num_steps:
            del window[0]
        while window:
            self.memory.append(list(window))
            del window[0]
        assert len(window) == 0

    def sample(self, num_experiences):
        assert len(self.memory) >= num_experiences
        return self.memory.sample(num_experiences)

    def __len__(self):
        return len(self.memory)

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.memory, f)

    def load(self, filename):
        """Restore ``memory``; a v0.2 deque is wrapped in a RandomAccessQueue."""
        with open(filename, "rb") as f:
            self.memory = pickle.load(f)
        legacy = self.memory
        if isinstance(legacy, collections.deque):
            # Load v0.2
            self.memory = RandomAccessQueue(legacy, maxlen=legacy.maxlen)
Example #14
0
class EVAReplayBuffer(replay_buffer.AbstractReplayBuffer):
    """Experience Replay Buffer

    As described in
    https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf.

    In addition to the normal replay buffer, storing features: every memory
    entry gets one feature row in ``h_memory`` (the feature of the entry's
    first transition), so entries can be found again by nearest-neighbour
    search in ``lookup``.

    Args:
        capacity (int): capacity in terms of number of transitions
        n_dim (int): size of each feature vector kept in ``h_memory``
        n_neighbors (int): number of neighbours requested by ``lookup``
        num_steps (int): Number of timesteps per stored transition
            (for N-step updates)
    """

    # Implements AbstractReplayBuffer.capacity
    capacity: Optional[int] = None

    def __init__(self, capacity, n_dim=256, n_neighbors=5, num_steps=1):
        self.capacity = capacity
        assert num_steps > 0
        self.num_steps = num_steps
        self.memory = RandomAccessQueue(maxlen=capacity)
        self.h_memory = lkb(capacity=capacity, n_dim=n_dim)
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.last_n_transitions = collections.defaultdict(
            lambda: collections.deque([], maxlen=num_steps))
        self.n_neighbors = n_neighbors

    def _store(self, transitions):
        """Append one n-step entry to ``memory`` and its key to ``h_memory``.

        ``lookup`` starts each trajectory at ``entry[0]``, so the entry is
        keyed by the feature of its first transition only. Adding a single
        feature row per entry keeps ``h_memory`` aligned with ``memory`` for
        any ``num_steps`` (assuming ``h_memory.append`` adds one entry per
        row — TODO confirm against ``lkb``).
        """
        entry = list(transitions)
        self.memory.append(entry)
        self.update_feature_arr([entry[0]['feature']])

    def append(self,
               state,
               action,
               reward,
               feature: torch.tensor,
               next_state=None,
               next_action=None,
               is_state_terminal=False,
               env_id=0,
               **kwargs):
        """Push one transition (with its feature) into the window of ``env_id``.

        A full window is stored as one entry. On a terminal transition every
        suffix of the window is stored and the window is emptied.
        """
        last_n_transitions = self.last_n_transitions[env_id]
        experience = dict(state=state,
                          action=action,
                          reward=reward,
                          feature=feature,
                          next_state=next_state,
                          next_action=next_action,
                          is_state_terminal=is_state_terminal,
                          **kwargs)
        last_n_transitions.append(experience)
        if is_state_terminal:
            while last_n_transitions:
                self._store(last_n_transitions)
                del last_n_transitions[0]
            assert len(last_n_transitions) == 0
        elif len(last_n_transitions) == self.num_steps:
            self._store(last_n_transitions)
        assert len(self.h_memory) == len(self)

    def stop_current_episode(self, env_id=0):
        """Flush the window of ``env_id`` without duplicating stored entries."""
        last_n_transitions = self.last_n_transitions[env_id]
        # if n-step transition hist is not full, add transition;
        # if n-step hist is indeed full, transition has already been added;
        if 0 < len(last_n_transitions) < self.num_steps:
            self._store(last_n_transitions)
        # avoid duplicate entry
        if 0 < len(last_n_transitions) <= self.num_steps:
            del last_n_transitions[0]
        while last_n_transitions:
            self._store(last_n_transitions)
            del last_n_transitions[0]
        assert len(last_n_transitions) == 0
        assert len(self.h_memory) == len(self)

    def sample(self, num_experiences):
        assert len(self.memory) >= num_experiences
        return self.memory.sample(num_experiences)

    def __len__(self):
        return len(self.memory)

    def save(self, filename):
        # NOTE(review): only ``memory`` is saved; ``h_memory`` is not, so
        # features are not restored by ``load``.
        with open(filename, "wb") as f:
            pickle.dump(self.memory, f)

    def load(self, filename):
        with open(filename, "rb") as f:
            self.memory = pickle.load(f)
        if isinstance(self.memory, collections.deque):
            # Load v0.2
            self.memory = RandomAccessQueue(self.memory,
                                            maxlen=self.memory.maxlen)

    def update_feature_arr(self, embeddings: List[np.ndarray]):
        """Append ``embeddings`` (stacked as a float32 tensor) to ``h_memory``."""
        if len(embeddings) > 0:
            # list -> numpy
            added = np.asarray(embeddings, dtype=np.float32)
            # numpy -> Tensor
            added = torch.from_numpy(added)
            self.h_memory.append(added)

        assert len(self.h_memory) == len(self)

    def lookup(self, target_h, max_len):
        """Return trajectories starting at the entries nearest to ``target_h``.

        ``target_h`` is converted with ``torch.from_numpy`` (so a numpy array
        is expected). For each index returned by ``h_memory.search`` the
        first transitions of up to ``max_len`` consecutive entries are
        collected, stopping early at a terminal transition or at the newest
        entry.
        """
        assert len(self.h_memory) == len(self)

        target_h = torch.from_numpy(target_h).clone()
        start_indices = self.h_memory.search(target_h.reshape(1, -1),
                                             self.n_neighbors)

        trajectory_list = []
        last_index = len(self.memory) - 1
        for start_index in start_indices:
            trajectory = []
            for offset in range(max_len):
                step = self.memory[start_index + offset]
                trajectory.append(step[0])
                if step[0]["is_state_terminal"]:
                    break
                if (start_index + offset) == last_index:
                    break
            trajectory_list.append(trajectory)

        return trajectory_list