def append(self, observation, action, reward, terminal, training=True):
    """ Add new sample.

    Parameters
    ----------
    observation (dict): Observation returned by environment
    action (int): Action taken to obtain this observation
    reward (float): Reward obtained by taking this action
    terminal (boolean): Is the state terminal
    """
    super(PrioritizedExperience, self).append(observation, action, reward,
                                              terminal, training=training)
    if training:
        self.recent_experiences.append([observation, action, reward, terminal])

        # Build the window from the earlier recent observations and sum their rewards.
        state0 = []
        state1 = []
        reward = 0
        for i in range(len(self.recent_experiences) - 1):
            state0.append(self.recent_experiences[i][0])
            reward += self.recent_experiences[i][2]
        while len(state0) < self.window_length:
            state0.insert(0, zeroed_observation(observation))
        state1 = [np.copy(x) for x in state0[1:]]
        state1.append(observation)

        # Use the action and terminal flag of the second-to-last recent experience;
        # fall back to the current ones when the window is still too short.
        if len(self.recent_experiences) < 2:
            act = action
            terminal = False
        else:
            act = self.recent_experiences[-2][1]
            terminal = self.recent_experiences[-2][3]

        assert len(state0) == self.window_length, "{} =/= {}".format(
            len(state0), self.window_length)
        assert len(state1) == len(state0)

        # Insert with the current maximum priority raised to alpha so that new
        # experiences are likely to be sampled soon after insertion.
        experience = Experience(state0=state0, action=act, reward=reward,
                                state1=state1, terminal1=terminal)
        self.tree.add(experience, self.max_priority ** self.alpha)

        if len(self.recent_terminals) > 2:
            if self.recent_terminals[-2] is True:
                # previous step ended the episode
                self.recent_experiences.clear()
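
# `append` above stores each new experience in the priority tree with max_priority ** alpha.
# A minimal, self-contained sketch of how that alpha exponent shapes sampling probabilities in
# proportional prioritized replay (illustration only; `_example_sampling_probs` is not part of
# this class and does not touch the tree):
def _example_sampling_probs(priorities, alpha=0.6):
    # P(i) = p_i ** alpha / sum_k p_k ** alpha; alpha=0 is uniform, alpha=1 fully proportional.
    scaled = np.asarray(priorities, dtype=float) ** alpha
    return scaled / scaled.sum()
# e.g. _example_sampling_probs([1.0, 2.0, 4.0]) ~= [0.21, 0.31, 0.48]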

def get_recent_state(self, current_observation):
    state0 = []
    state1 = []
    for i in range(len(self.recent_experiences) - 1):
        state0.append(self.recent_experiences[i][0])
    while len(state0) < self.window_length:
        state0.insert(0, zeroed_observation(current_observation))
    state1 = [np.copy(x) for x in state0[1:]]
    state1.append(current_observation)
    return state1
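
# `get_recent_state` pads the front of the window with zeroed observations when an episode has
# produced fewer than `window_length` frames. A standalone illustration of that padding for
# array observations (84x84 frames are an assumed shape; this helper is not used by the class):
def _example_padded_window(frames, window_length=4):
    # Keep the most recent frames and fill missing slots with zeros of the same shape,
    # which is roughly what zeroed_observation does for array observations.
    state = list(frames[-window_length:])
    while len(state) < window_length:
        state.insert(0, np.zeros_like(state[0]))
    return state
# e.g. _example_padded_window([np.ones((84, 84))]) -> [zeros, zeros, zeros, ones]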

def sample(self, net, batch_size):
    """ Returns a randomized batch of experiences for an ensemble member

    Args:
        net (int): Index of ensemble member
        batch_size (int): Size of the batch

    Returns:
        A list of random experiences
    """
    # It is not possible to tell whether the first state in the memory is terminal, because it
    # would require access to the "terminal" flag associated to the previous state. As a result
    # we will never return this first state (only using `self.terminals[0]` to know whether the
    # second state is terminal).
    # In addition we need enough entries to fill the desired window length.
    assert self.nb_entries >= self.window_length + 2, 'not enough entries in the memory'

    # Sample random indexes for the specified ensemble member
    batch_idxs = self.sample_batch_idxs(net, batch_size)
    assert np.min(batch_idxs) >= self.window_length + 1
    assert np.max(batch_idxs) < self.nb_entries
    assert len(batch_idxs) == batch_size

    # Create experiences
    experiences = []
    for idx in batch_idxs:
        terminal0 = self.terminals[idx - 2]
        while terminal0:
            # Skip this transition because the environment was reset here. Select a new, random
            # transition and use this instead. This may cause the batch to contain the same
            # transition twice.
            # idx = sample_batch_indexes(self.window_length + 1, self.nb_entries, size=1)[0]
            idx = self.sample_batch_idxs(net, 1)[0]
            terminal0 = self.terminals[idx - 2]
        assert self.window_length + 1 <= idx < self.nb_entries

        # This code is slightly complicated by the fact that subsequent observations might be
        # from different episodes. We ensure that an experience never spans multiple episodes.
        # This is probably not that important in practice but it seems cleaner.
        state0 = [self.observations[idx - 1]]
        for offset in range(0, self.window_length - 1):
            current_idx = idx - 2 - offset
            assert current_idx >= 1
            current_terminal = self.terminals[current_idx - 1]
            if current_terminal and not self.ignore_episode_boundaries:
                # The previously handled observation was terminal, don't add the current one.
                # Otherwise we would leak into a different episode.
                break
            state0.insert(0, self.observations[current_idx])
        while len(state0) < self.window_length:
            state0.insert(0, zeroed_observation(state0[0]))
        action = self.actions[idx - 1]
        reward = self.rewards[idx - 1]
        terminal1 = self.terminals[idx - 1]

        # Okay, now we need to create the follow-up state. This is state0 shifted one timestep
        # to the right. Again, we need to be careful to not include an observation from the next
        # episode if the last state is terminal.
        state1 = [np.copy(x) for x in state0[1:]]
        state1.append(self.observations[idx])

        assert len(state0) == self.window_length
        assert len(state1) == len(state0)
        experiences.append(Experience(state0=state0, action=action, reward=reward,
                                      state1=state1, terminal1=terminal1))
    assert len(experiences) == batch_size
    return experiences
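
# Typical use for an ensemble (the names `memory` and `nb_nets` are stand-ins, not defined in
# this module): each member draws its own batch, e.g.
#
#     for net in range(nb_nets):
#         experiences = memory.sample(net=net, batch_size=32)
#         # ...build targets from `experiences` and fit the `net`-th head...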

def sample(self, batch_size, batch_idxs=None):
    """ Returns a randomized batch of experiences.

    When `batch_idxs` is None, indexes are drawn via `self.sample_batch_indexes`, which also
    returns the matching priority indexes (`pri_idx`) so that priorities can be updated later.
    """
    if batch_idxs is None:
        # Draw random indexes such that we have at least a single entry before each
        # index.
        batch_idxs, pri_idxs = self.sample_batch_indexes(0, self.nb_entries - 1,
                                                         size=batch_size)
    else:
        pri_idxs = [None] * batch_size
    batch_idxs = np.array(batch_idxs) + 1
    assert np.min(batch_idxs) >= 1
    assert np.max(batch_idxs) < self.nb_entries, f"{np.max(batch_idxs)} < {self.nb_entries}"
    assert len(batch_idxs) == batch_size
    assert len(pri_idxs) == batch_size

    # Create experiences
    experiences = []
    for idx, pri_idx in zip(batch_idxs, pri_idxs):
        terminal0 = self.terminals[idx - 2] if idx >= 2 else False
        while terminal0:
            # Skip this transition because the environment was reset here. Select a new, random
            # transition and use this instead. This may cause the batch to contain the same
            # transition twice.
            idx_, pri_idx_ = self.sample_batch_indexes(0, self.nb_entries - 1, size=1)
            idx, pri_idx = idx_[0] + 1, pri_idx_[0]
            terminal0 = self.terminals[idx - 2] if idx >= 2 else False
        assert 1 <= idx < self.nb_entries

        # This code is slightly complicated by the fact that subsequent observations might be
        # from different episodes. We ensure that an experience never spans multiple episodes.
        # This is probably not that important in practice but it seems cleaner.
        state0 = [self.observations[idx - 1]]
        for offset in range(0, self.window_length - 1):
            current_idx = idx - 2 - offset
            current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False
            if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal):
                # The previously handled observation was terminal, don't add the current one.
                # Otherwise we would leak into a different episode.
                break
            state0.insert(0, self.observations[current_idx])
        while len(state0) < self.window_length:
            state0.insert(0, zeroed_observation(state0[0]))
        action = self.actions[idx - 1]
        reward = self.rewards[idx - 1]
        terminal1 = self.terminals[idx - 1]

        # Okay, now we need to create the follow-up state. This is state0 shifted one timestep
        # to the right. Again, we need to be careful to not include an observation from the next
        # episode if the last state is terminal.
        state1 = [np.copy(x) for x in state0[1:]]
        state1.append(self.observations[idx])

        assert len(state0) == self.window_length
        assert len(state1) == len(state0)
        experiences.append(Experience(state0=state0, action=action, reward=reward,
                                      state1=state1, terminal1=terminal1, pri_idx=pri_idx))
    assert len(experiences) == batch_size
    return experiences
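
# The `pri_idx` carried by each Experience identifies its slot in the priority structure, so a
# caller can push back new |TD error|-based priorities after a training step. Sketch only; the
# names `memory`, `td_errors` and `update_priority` are assumptions, not definitions from this
# module:
#
#     experiences = memory.sample(batch_size=32)
#     # ...forward pass; compute td_errors aligned with `experiences`...
#     for e, td in zip(experiences, td_errors):
#         if e.pri_idx is not None:                       # caller-supplied idxs have no priority slot
#             memory.update_priority(e.pri_idx, abs(td))  # hypothetical helper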