def extract_batch(self, T_idxs, B_idxs, T):
    """Return the full sequence of each field in ``agent_inputs`` (e.g.
    ``observation``), covering all timesteps for both the main sequence and
    the n-step target sequence in one array; since those ranges largely
    overlap, the algorithm can make sub-sequences by slicing on device, for
    reduced memory usage.  Enforces that input ``T_idxs`` align with the RNN
    state interval.  Uses the helper ``extract_sequences()`` to pull samples
    of length ``T`` starting at ``[T_idxs, B_idxs]``, so the returned batch
    has leading dimensions ``[T, len(B_idxs)]``.
    """
    samples = self.samples
    interval = self.rnn_state_interval
    if interval > 1:
        # Stored states are subsampled every `interval` steps; indices must align.
        assert np.all(np.asarray(T_idxs) % interval == 0)
        init_rnn_state = self.samples_prev_rnn_state[T_idxs // interval, B_idxs]
    elif interval == 1:
        init_rnn_state = self.samples.prev_rnn_state[T_idxs, B_idxs]
    else:  # interval == 0: no recurrent state stored.
        init_rnn_state = None
    target_T = T + self.n_step_return  # Extra steps for the n-step target.
    batch = SamplesFromReplay(
        all_observation=self.extract_observation(T_idxs, B_idxs, target_T),
        all_action=buffer_func(samples.action, extract_sequences,
            T_idxs - 1, B_idxs, target_T),  # Starts at prev_action.
        all_reward=extract_sequences(samples.reward, T_idxs - 1, B_idxs,
            target_T),  # Only prev_reward (agent + target).
        return_=extract_sequences(self.samples_return_, T_idxs, B_idxs, T),
        done=extract_sequences(samples.done, T_idxs, B_idxs, T),
        done_n=extract_sequences(self.samples_done_n, T_idxs, B_idxs, T),
        init_rnn_state=init_rnn_state,  # (Same state for agent and target.)
    )
    # NOTE: Algo might need to make zero prev_action/prev_reward depending on done.
    return torchify_buffer(batch)
def extract_batch(self, T_idxs, B_idxs, T):
    """Return, for each replay field, one full sequence that encompasses
    every sub-sequence the algorithm will use, so sub-sequences can be made
    by slicing on device, for reduced memory usage.
    """
    samples = self.samples
    interval = self.rnn_state_interval
    if interval > 1:
        # Stored states are subsampled every `interval` steps; indices must align.
        assert np.all(np.asarray(T_idxs) % interval == 0)
        init_rnn_state = self.samples_prev_rnn_state[T_idxs // interval, B_idxs]
    elif interval == 1:
        init_rnn_state = self.samples.prev_rnn_state[T_idxs, B_idxs]
    else:  # interval == 0: no recurrent state stored.
        init_rnn_state = None
    target_T = T + self.n_step_return  # Extra steps for the n-step target.
    batch = SamplesFromReplay(
        all_observation=self.extract_observation(T_idxs, B_idxs, target_T),
        all_action=buffer_func(samples.action, extract_sequences,
            T_idxs - 1, B_idxs, target_T),  # Starts at prev_action.
        all_reward=extract_sequences(samples.reward, T_idxs - 1, B_idxs,
            target_T),  # Only prev_reward (agent + target).
        return_=extract_sequences(self.samples_return_, T_idxs, B_idxs, T),
        done=extract_sequences(samples.done, T_idxs, B_idxs, T),
        done_n=extract_sequences(self.samples_done_n, T_idxs, B_idxs, T),
        init_rnn_state=init_rnn_state,  # (Same state for agent and target.)
    )
    # NOTE: Algo might need to make zero prev_action/prev_reward depending on done.
    return torchify_buffer(batch)
def sample_batch(self, batch_B):
    """Sample a uniform training batch of ``batch_B`` sequences, augmented
    with stored policy probabilities and values, retrying forever on failure.

    Returns a sanitized ``SamplesFromReplayExt`` batch.
    """
    while True:
        sampled_indices = False  # For diagnostics if extraction fails.
        try:
            self._async_pull()  # Updates from writers.
            batch_T = self.batch_T
            T_idxs, B_idxs = self.sample_idxs(batch_B, batch_T)
            sampled_indices = True
            if self.rnn_state_interval > 1:
                # Indices were drawn on the subsampled grid; scale to raw steps.
                T_idxs = T_idxs * self.rnn_state_interval
            batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)
            # +1 to include the bootstrap step beyond the n-step horizon.
            policies = torch.from_numpy(
                extract_sequences(self.samples.policy_probs, T_idxs, B_idxs,
                    self.batch_T + self.n_step_return + 1))
            values = torch.from_numpy(
                extract_sequences(self.samples.value, T_idxs, B_idxs,
                    self.batch_T + self.n_step_return + 1))
            # Namedtuple unpacks directly; no intermediate list needed.
            batch = SamplesFromReplayExt(*batch, policy_probs=policies,
                values=values)
            return self.sanitize_batch(batch)
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit and hid the error; catch Exception and print the trace
        # (consistent with the prioritized sampler).
        except Exception:
            print("FAILED TO LOAD BATCH")
            traceback.print_exc()
            if sampled_indices:
                print("B_idxs:", B_idxs, flush=True)
                print("T_idxs:", T_idxs, flush=True)
                print("Batch_T:", self.batch_T, flush=True)
                print("Buffer T:", self.T, flush=True)
def extract_batch(self, T_idxs, B_idxs, T):
    """Extract sequences of length ``T`` starting at ``[T_idxs, B_idxs]``
    for each replay field and return them as a torchified
    ``SamplesFromReplay``.
    """
    samples = self.samples
    batch = SamplesFromReplay(
        observation=self.extract_observation(T_idxs, B_idxs, T),
        action=buffer_func(samples.action, extract_sequences, T_idxs, B_idxs, T),
        reward=extract_sequences(samples.reward, T_idxs, B_idxs, T),
        done=extract_sequences(samples.done, T_idxs, B_idxs, T),
    )
    return torchify_buffer(batch)
def sample_batch(self, batch_B):
    """Sample a prioritized training batch of ``batch_B`` sequences with
    importance-sampling weights, stored values, and sample age, retrying
    forever on failure.

    Returns a ``SamplesFromReplayPriExt`` batch (sanitized when
    ``batch_T > 1``).
    """
    while True:
        # Bug fix: `sampled_indices` was never initialized before the try,
        # so a failure in `_async_pull`/`sample` raised NameError inside the
        # handler instead of reporting the real error.
        sampled_indices = False
        try:
            self._async_pull()  # Updates from writers.
            (T_idxs, B_idxs), priorities = self.priority_tree.sample(
                batch_B, unique=self.unique)
            sampled_indices = True
            if self.rnn_state_interval > 1:
                # Indices were drawn on the subsampled grid; scale to raw steps.
                T_idxs = T_idxs * self.rnn_state_interval
            batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)
        except Exception:
            print("FAILED TO LOAD BATCH")
            traceback.print_exc()
            if sampled_indices:
                print("B_idxs:", B_idxs, flush=True)
                print("T_idxs:", T_idxs, flush=True)
                print("Batch_T:", self.batch_T, flush=True)
                print("Buffer T:", self.T, flush=True)
            # Bug fix: without this, the loop fell through to the code below
            # with `priorities`/`T_idxs`/`batch` undefined (or stale).
            continue
        is_weights = (1. / (priorities + 1e-5)) ** self.beta
        is_weights /= max(is_weights)  # Normalize.
        is_weights = torchify_buffer(is_weights).float()
        # Age of each sample in gathered environment steps.
        elapsed_iters = self.t + self.T - T_idxs % self.T
        elapsed_samples = self.B * elapsed_iters
        # +1 to include the bootstrap step beyond the n-step horizon.
        values = torch.from_numpy(extract_sequences(self.samples.value,
            T_idxs, B_idxs, self.batch_T + self.n_step_return + 1))
        batch = SamplesFromReplayPriExt(*batch, values=values,
            is_weights=is_weights, age=elapsed_samples)
        if self.batch_T > 1:
            batch = self.sanitize_batch(batch)
        return batch
def sample_batch(self, batch_B):
    """Sample a uniform training batch of ``batch_B`` sequences, augmented
    with stored values and sample age, retrying forever on failure.

    Returns a ``SamplesFromReplayExt`` batch (sanitized when
    ``batch_T > 1``).
    """
    while True:
        # Bug fix: `sampled_indices` was never initialized before the try,
        # so a failure in `_async_pull`/`sample_idxs` raised NameError inside
        # the handler instead of reporting the real error.
        sampled_indices = False
        try:
            self._async_pull()  # Updates from writers.
            batch_T = self.batch_T
            T_idxs, B_idxs = self.sample_idxs(batch_B, batch_T)
            sampled_indices = True
            if self.rnn_state_interval > 1:
                # Indices were drawn on the subsampled grid; scale to raw steps.
                T_idxs = T_idxs * self.rnn_state_interval
            batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)
        except Exception:
            print("FAILED TO LOAD BATCH")
            traceback.print_exc()
            if sampled_indices:
                print("B_idxs:", B_idxs, flush=True)
                print("T_idxs:", T_idxs, flush=True)
                print("Batch_T:", self.batch_T, flush=True)
                print("Buffer T:", self.T, flush=True)
            # Bug fix: without this, the loop fell through to the code below
            # with `T_idxs`/`batch` undefined (or stale).
            continue
        # Age of each sample in gathered environment steps.
        elapsed_iters = self.t + self.T - T_idxs % self.T
        elapsed_samples = self.B * elapsed_iters
        # +1 to include the bootstrap step beyond the n-step horizon.
        values = torch.from_numpy(extract_sequences(self.samples.value,
            T_idxs, B_idxs, self.batch_T + self.n_step_return + 1))
        batch = SamplesFromReplayExt(*batch, values=values,
            age=elapsed_samples)
        if self.batch_T > 1:
            batch = self.sanitize_batch(batch)
        return batch
def extract_batch(self, T_idxs, B_idxs):
    """Extract sequences of length ``self.replay_T`` starting at
    ``[T_idxs, B_idxs]``.  Actions and rewards are pulled with one extra
    leading step so that ``prev_action``/``prev_reward`` come from the same
    read, offset by one.  Appends pixel-control returns when that buffer is
    configured.  Returns a torchified batch.
    """
    T = self.replay_T
    # One extended read: index 0 holds the "previous" step.
    actions_ext = buffer_func(self.samples.action, extract_sequences,
                              T_idxs - 1, B_idxs, T + 1)
    rewards_ext = extract_sequences(self.samples.reward, T_idxs - 1,
                                    B_idxs, T + 1)
    batch = SamplesFromReplay(
        observation=self.extract_observation(T_idxs, B_idxs),
        action=actions_ext[1:],
        reward=rewards_ext[1:],
        done=extract_sequences(self.samples.done, T_idxs, B_idxs, T),
        prev_action=actions_ext[:-1],
        prev_reward=rewards_ext[:-1],
    )
    if self.pixel_control_buffer is not None:
        pixctl_return = extract_sequences(
            self.pixel_control_buffer["return_"], T_idxs, B_idxs, T)
        batch = SamplesFromReplayPC(*batch, pixctl_return=pixctl_return)
    return torchify_buffer(batch)