Example #1
    def extract_batch(self, T_idxs, B_idxs, T):
        """Return full sequence of each field in `agent_inputs` (e.g. `observation`),
        including all timesteps for the main sequence and for the target sequence in
        one array; many timesteps will likely overlap, so the algorithm can make
        sub-sequences by slicing on device, for reduced memory usage.

        Enforces that input `T_idxs` align with RNN state interval.

        Uses helper function ``extract_sequences()`` to retrieve samples of
        length ``T`` starting at locations ``[T_idxs,B_idxs]``, so returned
        data batch has leading dimensions ``[T,len(B_idxs)]``."""
        s, rsi = self.samples, self.rnn_state_interval
        if rsi > 1:
            assert np.all(np.asarray(T_idxs) % rsi == 0)
            init_rnn_state = self.samples_prev_rnn_state[T_idxs // rsi, B_idxs]
        elif rsi == 1:
            init_rnn_state = self.samples.prev_rnn_state[T_idxs, B_idxs]
        else:  # rsi == 0
            init_rnn_state = None
        batch = SamplesFromReplay(
            all_observation=self.extract_observation(T_idxs, B_idxs,
                                                     T + self.n_step_return),
            all_action=buffer_func(
                s.action, extract_sequences, T_idxs - 1, B_idxs,
                T + self.n_step_return),  # Starts at prev_action.
            all_reward=extract_sequences(
                s.reward, T_idxs - 1, B_idxs,
                T + self.n_step_return),  # Only prev_reward (agent + target).
            return_=extract_sequences(self.samples_return_, T_idxs, B_idxs, T),
            done=extract_sequences(s.done, T_idxs, B_idxs, T),
            done_n=extract_sequences(self.samples_done_n, T_idxs, B_idxs, T),
            init_rnn_state=init_rnn_state,  # (Same state for agent and target.)
        )
        # NOTE: Algo might need to make zero prev_action/prev_reward depending on done.
        return torchify_buffer(batch)
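
The point of returning one long sequence per field is that the consumer can
carve the agent and target sub-sequences out as views on device, instead of
replaying two mostly overlapping copies. A minimal sketch of such a consumer
(the function name and slicing offsets are illustrative assumptions, not part
of this buffer's API):

    # Hypothetical consumer of the batch returned above.  Slicing a tensor
    # creates a view, so the overlapping timesteps are stored only once.
    def split_agent_target(batch, T, n_step_return):
        agent_obs = batch.all_observation[:T]          # steps [0, T)
        target_obs = batch.all_observation[n_step_return:T + n_step_return]
        return agent_obs, target_obs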
Example #2
    def extract_batch(self, T_idxs, B_idxs, T):
        """Return full sequence of each field which encompasses all subsequences
        to be used, so the algorithm can make sub-sequences by slicing on
        device, for reduced memory usage."""
        s, rsi = self.samples, self.rnn_state_interval
        if rsi > 1:
            assert np.all(np.asarray(T_idxs) % rsi == 0)
            init_rnn_state = self.samples_prev_rnn_state[T_idxs // rsi, B_idxs]
        elif rsi == 1:
            init_rnn_state = self.samples.prev_rnn_state[T_idxs, B_idxs]
        else:  # rsi == 0
            init_rnn_state = None
        batch = SamplesFromReplay(
            all_observation=self.extract_observation(T_idxs, B_idxs,
                                                     T + self.n_step_return),
            all_action=buffer_func(
                s.action, extract_sequences, T_idxs - 1, B_idxs,
                T + self.n_step_return),  # Starts at prev_action.
            all_reward=extract_sequences(
                s.reward, T_idxs - 1, B_idxs,
                T + self.n_step_return),  # Only prev_reward (agent + target).
            return_=extract_sequences(self.samples_return_, T_idxs, B_idxs, T),
            done=extract_sequences(s.done, T_idxs, B_idxs, T),
            done_n=extract_sequences(self.samples_done_n, T_idxs, B_idxs, T),
            init_rnn_state=init_rnn_state,  # (Same state for agent and target.)
        )
        # NOTE: Algo might need to make zero prev_action/prev_reward depending on done.
        return torchify_buffer(batch)
Example #3
    def sample_batch(self, batch_B):
        while True:
            sampled_indices = False
            try:
                self._async_pull()  # Updates from writers.
                batch_T = self.batch_T
                T_idxs, B_idxs = self.sample_idxs(batch_B, batch_T)
                sampled_indices = True
                if self.rnn_state_interval > 1:
                    T_idxs = T_idxs * self.rnn_state_interval

                batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)
                policies = torch.from_numpy(
                    extract_sequences(self.samples.policy_probs, T_idxs,
                                      B_idxs,
                                      self.batch_T + self.n_step_return + 1))
                values = torch.from_numpy(
                    extract_sequences(self.samples.value, T_idxs, B_idxs,
                                      self.batch_T + self.n_step_return + 1))
                batch = list(batch)
                batch = SamplesFromReplayExt(*batch,
                                             policy_probs=policies,
                                             values=values)
                return self.sanitize_batch(batch)
            except Exception:  # Bare except would also swallow KeyboardInterrupt.
                print("FAILED TO LOAD BATCH")
                traceback.print_exc()
                if sampled_indices:
                    print("B_idxs:", B_idxs, flush=True)
                    print("T_idxs:", T_idxs, flush=True)
                    print("Batch_T:", self.batch_T, flush=True)
                    print("Buffer T:", self.T, flush=True)
Example #4
    def extract_batch(self, T_idxs, B_idxs, T):
        """Extract sequences of length ``T`` for each field, starting at
        locations ``[T_idxs, B_idxs]``; leading dims are ``[T, len(B_idxs)]``."""
        s = self.samples
        batch = SamplesFromReplay(
            observation=self.extract_observation(T_idxs, B_idxs, T),
            action=buffer_func(s.action, extract_sequences, T_idxs, B_idxs, T),
            reward=extract_sequences(s.reward, T_idxs, B_idxs, T),
            done=extract_sequences(s.done, T_idxs, B_idxs, T),
        )
        return torchify_buffer(batch)
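
All of these methods hand the assembled namedtuple to ``torchify_buffer()``
before returning. A sketch of its assumed behavior (leaf-wise, zero-copy
conversion; the recursion below is illustrative, not the library's code):

    import numpy as np
    import torch

    def torchify_buffer_sketch(buffer_):
        # Walk a (possibly nested) namedtuple and wrap each numpy leaf with
        # torch.from_numpy(), which shares memory rather than copying.
        if buffer_ is None:
            return None
        if isinstance(buffer_, np.ndarray):
            return torch.from_numpy(buffer_)
        if isinstance(buffer_, tuple):  # namedtuples subclass tuple
            return type(buffer_)(*(torchify_buffer_sketch(f) for f in buffer_))
        return buffer_  # leave non-array leaves untouched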
Example #5
    def sample_batch(self, batch_B):
        while True:
            sampled_indices = False  # So the except block below can test it.
            try:
                self._async_pull()  # Updates from writers.
                (T_idxs, B_idxs), priorities = self.priority_tree.sample(
                    batch_B, unique=self.unique)
                sampled_indices = True
                if self.rnn_state_interval > 1:
                    T_idxs = T_idxs * self.rnn_state_interval

                batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)

            except Exception:
                print("FAILED TO LOAD BATCH")
                traceback.print_exc()
                if sampled_indices:
                    print("B_idxs:", B_idxs, flush=True)
                    print("T_idxs:", T_idxs, flush=True)
                    print("Batch_T:", self.batch_T, flush=True)
                    print("Buffer T:", self.T, flush=True)

            is_weights = (1. / (priorities + 1e-5)) ** self.beta
            is_weights /= max(is_weights)  # Normalize.
            is_weights = torchify_buffer(is_weights).float()

            elapsed_iters = self.t + self.T - (T_idxs % self.T)
            elapsed_samples = self.B * elapsed_iters
            values = torch.from_numpy(
                extract_sequences(self.samples.value, T_idxs, B_idxs,
                                  self.batch_T + self.n_step_return + 1))
            batch = SamplesFromReplayPriExt(*batch,
                                            values=values,
                                            is_weights=is_weights,
                                            age=elapsed_samples)
            if self.batch_T > 1:
                batch = self.sanitize_batch(batch)
            return batch
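
The importance-sampling weights above shrink the gradient contribution of
transitions that the priority tree samples often. A small numeric illustration
with made-up priorities and a hypothetical ``beta``:

    import numpy as np

    priorities = np.array([0.1, 1.0, 10.0])     # made-up sampled priorities
    beta = 0.5                                  # hypothetical exponent
    is_weights = (1. / (priorities + 1e-5)) ** beta
    is_weights /= is_weights.max()              # normalize to max weight 1.0
    # -> approximately [1.0, 0.316, 0.1]: rarely sampled (low-priority) items
    #    keep full weight; frequently sampled ones are scaled down.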
Example #6
    def sample_batch(self, batch_B):
        while True:
            sampled_indices = False  # So the except block below can test it.
            try:
                self._async_pull()  # Updates from writers.
                batch_T = self.batch_T
                T_idxs, B_idxs = self.sample_idxs(batch_B, batch_T)
                sampled_indices = True
                if self.rnn_state_interval > 1:
                    T_idxs = T_idxs * self.rnn_state_interval
                batch = self.extract_batch(T_idxs, B_idxs, self.batch_T)

            except Exception:
                print("FAILED TO LOAD BATCH")
                if sampled_indices:
                    print("B_idxs:", B_idxs, flush=True)
                    print("T_idxs:", T_idxs, flush=True)
                    print("Batch_T:", self.batch_T, flush=True)
                    print("Buffer T:", self.T, flush=True)

            elapsed_iters = self.t + self.T - (T_idxs % self.T)
            elapsed_samples = self.B * elapsed_iters
            values = torch.from_numpy(
                extract_sequences(self.samples.value, T_idxs, B_idxs,
                                  self.batch_T + self.n_step_return + 1))
            batch = SamplesFromReplayExt(*batch, values=values, age=elapsed_samples)
            if self.batch_T > 1:
                batch = self.sanitize_batch(batch)
            return batch
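
The age computation treats the time axis as circular: with write cursor
``self.t`` and buffer length ``self.T``, it measures how many iterations ago a
sampled index was written. A worked example with made-up sizes:

    # Hypothetical numbers: buffer length T = 100, cursor t = 10, index 95.
    # Writes wrapped at 100, so index 95 was filled (100 - 95) + 10 steps ago.
    t, T, B, T_idx = 10, 100, 32, 95
    elapsed_iters = t + T - (T_idx % T)   # -> 15 iterations old
    elapsed_samples = B * elapsed_iters   # -> 480 samples across all B envs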
Example #7
    def extract_batch(self, T_idxs, B_idxs):
        """Extract sequences of length ``self.replay_T`` at ``[T_idxs, B_idxs]``;
        actions and rewards are pulled one step early (length ``T + 1``) so the
        batch can carry aligned ``prev_action``/``prev_reward`` fields."""
        T = self.replay_T
        all_action = buffer_func(self.samples.action, extract_sequences,
                                 T_idxs - 1, B_idxs, T + 1)
        all_reward = extract_sequences(self.samples.reward, T_idxs - 1, B_idxs,
                                       T + 1)
        batch = SamplesFromReplay(
            observation=self.extract_observation(T_idxs, B_idxs),
            action=all_action[1:],
            reward=all_reward[1:],
            done=extract_sequences(self.samples.done, T_idxs, B_idxs, T),
            prev_action=all_action[:-1],
            prev_reward=all_reward[:-1],
        )
        if self.pixel_control_buffer is not None:
            pixctl_return = extract_sequences(
                self.pixel_control_buffer["return_"], T_idxs, B_idxs, T)
            batch = SamplesFromReplayPC(*batch, pixctl_return=pixctl_return)
        return torchify_buffer(batch)
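
Pulling ``T + 1`` steps starting one index early, then slicing opposite ends,
is what keeps ``action``/``prev_action`` (and the reward pair) aligned per
timestep. A toy illustration on a 1-D series (names are hypothetical):

    import numpy as np

    actions = np.arange(10)                # action[t] taken at step t
    t0, T = 4, 3                           # sequence starts at t0, length T
    all_action = actions[t0 - 1:t0 + T]    # length T + 1, starts at prev
    action = all_action[1:]                # steps t0 .. t0+T-1 -> [4, 5, 6]
    prev_action = all_action[:-1]          # one step earlier  -> [3, 4, 5]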