Example #1
    def extract_batch(self, T_idxs, B_idxs, T):
        """Return full sequence of each field in `agent_inputs` (e.g. `observation`),
        including all timesteps for the main sequence and for the target sequence in
        one array; many timesteps will likely overlap, so the algorithm can make
        sub-sequences by slicing on device, for reduced memory usage.

        Enforces that input `T_idxs` align with RNN state interval.

        Uses helper function ``extract_sequences()`` to retrieve samples of
        length ``T`` starting at locations ``[T_idxs,B_idxs]``, so returned
        data batch has leading dimensions ``[T,len(B_idxs)]``."""
        s, rsi = self.samples, self.rnn_state_interval
        if rsi > 1:
            assert np.all(np.asarray(T_idxs) % rsi == 0)
            init_rnn_state = self.samples_prev_rnn_state[T_idxs // rsi, B_idxs]
        elif rsi == 1:
            init_rnn_state = self.samples.prev_rnn_state[T_idxs, B_idxs]
        else:  # rsi == 0
            init_rnn_state = None
        batch = SamplesFromReplay(
            all_observation=self.extract_observation(T_idxs, B_idxs,
                                                     T + self.n_step_return),
            all_action=buffer_func(
                s.action, extract_sequences, T_idxs - 1, B_idxs,
                T + self.n_step_return),  # Starts at prev_action.
            all_reward=extract_sequences(
                s.reward, T_idxs - 1, B_idxs,
                T + self.n_step_return),  # Only prev_reward (agent + target).
            return_=extract_sequences(self.samples_return_, T_idxs, B_idxs, T),
            done=extract_sequences(s.done, T_idxs, B_idxs, T),
            done_n=extract_sequences(self.samples_done_n, T_idxs, B_idxs, T),
            init_rnn_state=init_rnn_state,  # (Same state for agent and target.)
        )
        # NOTE: Algo might need to make zero prev_action/prev_reward depending on done.
        return torchify_buffer(batch)
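A minimal standalone numpy sketch of the leading dimensions described in the docstring above: length-``T`` windows gathered at ``[T_idxs, B_idxs]`` come back with shape ``[T, len(B_idxs), ...]``. The helper here is a hypothetical simplified stand-in for ``extract_sequences``; the real rlpyt function also operates leaf-wise on namedarraytuples and handles circular-buffer wrap-around.

    import numpy as np

    def extract_sequences_sketch(array, T_idxs, B_idxs, T):
        # Gather a length-T window starting at each (t, b) pair.
        return np.stack([array[t:t + T, b] for t, b in zip(T_idxs, B_idxs)],
                        axis=1)

    buf = np.arange(20).reshape(10, 2)  # stand-in replay column, [T_total, B_total]
    out = extract_sequences_sketch(buf, T_idxs=[0, 3], B_idxs=[0, 1], T=4)
    print(out.shape)  # (4, 2) -> leading dims [T, len(B_idxs)]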
Example #2
    def forward(self,
                observation: torch.Tensor,
                prev_action: torch.Tensor = None,
                prev_state: RSSMState = None):
        if isinstance(observation, tuple):
            img_obs, state_obs = observation
            state_obs = state_obs.to(self.dtype).to(img_obs.device)
            if len(state_obs.shape) == 1:
                state_obs = state_obs.unsqueeze(0)
        else:
            img_obs = observation
            state_obs = None

        if prev_action is None:
            prev_action = torch.zeros(self.action_size,
                                      device=img_obs.device,
                                      dtype=img_obs.dtype)

        lead_dim, T, B, img_shape = infer_leading_dims(img_obs, 3)
        img_obs = img_obs.reshape(T * B, *img_shape).type(
            self.dtype) / 255.0 - 0.5
        prev_action = prev_action.reshape(T * B, -1).to(self.dtype)
        if prev_state is None:
            prev_state = self.representation.initial_state(
                prev_action.size(0),
                device=prev_action.device,
                dtype=self.dtype)
        state = self.get_state_representation(img_obs, state_obs, prev_action,
                                              prev_state)

        action, action_dist = self.policy(state)
        return_spec = ModelReturnSpec(action, state)
        return_spec = buffer_func(return_spec, restore_leading_dims, lead_dim,
                                  T, B)
        return return_spec
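A small torch sketch of the leading-dimension handling used above, with hypothetical inline stand-ins for ``infer_leading_dims``/``restore_leading_dims`` (the real rlpyt helpers also accept inputs with no leading dims or only a batch dim):

    import torch

    obs = torch.zeros(7, 4, 3, 64, 64)     # [T, B, C, H, W]
    T, B, img_shape = obs.shape[0], obs.shape[1], obs.shape[2:]
    flat = obs.reshape(T * B, *img_shape)  # fold T and B for the conv stack
    out = torch.zeros(T * B, 16)           # pretend network output
    restored = out.reshape(T, B, -1)       # unfold back to [T, B, ...]
    print(flat.shape, restored.shape)      # (28, 3, 64, 64) and (7, 4, 16)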
Example #3
    def extract_observation(self, T_idxs, B_idxs):
        T = self.replay_T
        if not self._is_frame_buffer:
            return buffer_func(self.samples.observation,
                extract_sequences, T_idxs, B_idxs, T)
        frames = self.samples_frames
        observation = np.empty(
            shape=(T, len(B_idxs), self.n_frames) + frames.shape[2:],  # [T,B,C,H,W]
            dtype=frames.dtype,
        )
        fm1 = self.n_frames - 1
        for i, (t, b) in enumerate(zip(T_idxs, B_idxs)):
            assert t + T <= self.T  # no wrapping allowed
            for f in range(self.n_frames):
                observation[:, i, f] = frames[t + f:t + f + T, b]

            # Populate empty (zero) frames after environment done.
            assert t - fm1 >= 0  # no wrapping allowed
            done_idxs = slice(t - fm1, t + T)
            done_fm1 = self.samples.done[done_idxs, b]
            if np.any(done_fm1):
                where_done_t = np.where(done_fm1)[0] - fm1  # Might be negative...
                for f in range(1, self.n_frames):
                    t_blanks = where_done_t + f  # ...might be > T...
                    t_blanks = t_blanks[(t_blanks >= 0) & (t_blanks < T)]  # ...don't let it wrap.
                    observation[t_blanks, i, :self.n_frames - f] = 0
        return observation
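A tiny numpy illustration of the frame-stacking slice in the inner loop above, using hypothetical toy shapes (a 1-D column stands in for the stored ``[T, B, H, W]`` frames):

    import numpy as np

    frames = np.arange(8)          # frame recorded at each timestep
    t, T, n_frames = 2, 4, 3
    obs = np.empty((T, n_frames), dtype=frames.dtype)
    for f in range(n_frames):
        obs[:, f] = frames[t + f:t + f + T]
    print(obs)
    # [[2 3 4]
    #  [3 4 5]
    #  [4 5 6]
    #  [5 6 7]]  -> each output step stacks n_frames consecutive frames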
Example #4
 def extract_batch(self, T_idxs, B_idxs, T):
     """Return full sequence of each field which encompasses all subsequences
     to be used, so the algorithm can make sub-sequences by slicing on device,
     for reduced memory usage."""
     s, rsi = self.samples, self.rnn_state_interval
     if rsi > 1:
         assert np.all(np.asarray(T_idxs) % rsi == 0)
         init_rnn_state = self.samples_prev_rnn_state[T_idxs // rsi, B_idxs]
     elif rsi == 1:
         init_rnn_state = self.samples.prev_rnn_state[T_idxs, B_idxs]
     else:  # rsi == 0
         init_rnn_state = None
     batch = SamplesFromReplay(
         all_observation=self.extract_observation(T_idxs, B_idxs,
                                                  T + self.n_step_return),
         all_action=buffer_func(
             s.action, extract_sequences, T_idxs - 1, B_idxs,
             T + self.n_step_return),  # Starts at prev_action.
         all_reward=extract_sequences(
             s.reward, T_idxs - 1, B_idxs,
             T + self.n_step_return),  # Only prev_reward (agent + target).
         return_=extract_sequences(self.samples_return_, T_idxs, B_idxs, T),
         done=extract_sequences(s.done, T_idxs, B_idxs, T),
         done_n=extract_sequences(self.samples_done_n, T_idxs, B_idxs, T),
         init_rnn_state=init_rnn_state,  # (Same state for agent and target.)
     )
     # NOTE: Algo might need to make zero prev_action/prev_reward depending on done.
     return torchify_buffer(batch)
Example #5
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, value, rnn_state = self.model(*agent_inputs,
                                                self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info,
                               value=value,
                               prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #6
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     probs, value, rnn_state = self.model(*agent_inputs,
                                          self.prev_rnn_state)
     dist_info = DistInfo(prob=probs)
     if self._mode == 'sample':
         action = self.distribution.sample(dist_info)
     elif self._mode == 'eval':
         action = torch.argmax(probs, dim=-1)
     # Model handles None, but Buffer does not, make zeros if needed:
     if self.prev_rnn_state is None:
         prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
     else:
         prev_rnn_state = self.prev_rnn_state
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info,
                               value=value,
                               prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #7
    def step(self, observation, prev_action, prev_reward):
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfo(prob=pi)

        if self.dual_model:
            int_pi, int_value, int_rnn_state = self.model_int(
                *agent_inputs, self.prev_int_rnn_state)
            dist_int_info = DistInfo(prob=int_pi)
            if self._mode == "eval":
                action = self.distribution.sample(dist_info)
            else:
                action = self.distribution.sample(dist_int_info)
        else:
            action = self.distribution.sample(dist_info)

        # Model handles None, but Buffer does not, make zeros if needed:
        prev_rnn_state = self.prev_rnn_state or buffer_func(
            rnn_state, torch.zeros_like)
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)

        if self.dual_model:
            prev_int_rnn_state = self.prev_int_rnn_state or buffer_func(
                int_rnn_state, torch.zeros_like)
            prev_int_rnn_state = buffer_method(prev_int_rnn_state, "transpose",
                                               0, 1)
            agent_info = AgentInfoRnnTwin(
                dist_info=dist_info,
                value=value,
                prev_rnn_state=prev_rnn_state,
                dist_int_info=dist_int_info,
                int_value=int_value,
                prev_int_rnn_state=prev_int_rnn_state)
        else:
            agent_info = AgentInfoRnn(dist_info=dist_info,
                                      value=value,
                                      prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)  # Keep on device.
        if self.dual_model:
            self.advance_int_rnn_state(int_rnn_state)
        return AgentStep(action=action, agent_info=agent_info)
Example #8
 def extract_batch(self, T_idxs, B_idxs, T):
     s = self.samples
     batch = SamplesFromReplay(
         observation=self.extract_observation(T_idxs, B_idxs, T),
         action=buffer_func(s.action, extract_sequences, T_idxs, B_idxs, T),
         reward=extract_sequences(s.reward, T_idxs, B_idxs, T),
         done=extract_sequences(s.done, T_idxs, B_idxs, T),
     )
     return torchify_buffer(batch)
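As I understand it, the final ``torchify_buffer`` call wraps each numpy leaf of the namedarraytuple with ``torch.from_numpy``, so the returned batch shares memory with the extracted arrays rather than copying them. A rough standalone sketch of that idea:

    import numpy as np
    import torch

    batch_np = np.zeros((4, 2), dtype=np.float32)   # one extracted leaf
    batch_torch = torch.from_numpy(batch_np)        # no copy; shares memory
    batch_np[0, 0] = 1.0
    print(batch_torch[0, 0])                        # tensor(1.)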
Example #9
    def forward(self, observation: torch.Tensor, prev_action: torch.Tensor = None, prev_state: RSSMState = None):
        lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
        observation = observation.reshape(T * B, *img_shape).type(self.dtype) / 255.0 - 0.5
        prev_action = prev_action.reshape(T * B, -1).to(self.dtype)
        if prev_state is None:
            prev_state = self.representation.initial_state(prev_action.size(0), device=prev_action.device, dtype=self.dtype)
        state = self.get_state_representation(observation, prev_action, prev_state)

        action, action_dist = self.policy(state)
        return_spec = ModelReturnSpec(action, state)
        return_spec = buffer_func(return_spec, restore_leading_dims, lead_dim, T, B)
        return return_spec
Example #10
 def to_agent_step(self, output):
     """Convert the output of the NN model into step info for the agent.
     """
     q, rnn_state = output
     # q = q.cpu()
     action = self.distribution.sample(q)
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case, model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     prev_rnn_state, action, q = buffer_to((prev_rnn_state, action, q), device="cpu")
     agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
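A small torch sketch of the rnn-state transpose for storage, with hypothetical shapes; ``buffer_method(..., "transpose", 0, 1)`` applies the same ``.transpose(0, 1)`` to every tensor leaf of the state namedarraytuple:

    import torch

    N, B, H = 2, 4, 8                   # layers, batch, hidden
    rnn_state = torch.zeros(N, B, H)    # model convention: [N, B, H]
    stored = rnn_state.transpose(0, 1)  # storage convention: [B, N, H]
    print(stored.shape)                 # torch.Size([4, 2, 8])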
Example #11
 def step(self, observation, prev_action, prev_reward):
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     mu, log_std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
         prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #12
 def step(self, observation, prev_action, prev_reward):
     """Computes Q-values for states/observations and selects actions by
     epsilon-greedy (no grad).  Advances RNN state."""
     prev_action = self.distribution.to_onehot(prev_action)
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     q, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)  # Model handles None.
     q = q.cpu()
     action = self.distribution.sample(q)
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case, model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     prev_rnn_state = buffer_to(prev_rnn_state, device="cpu")
     agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #13
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     q, rnn_state = self.model(*agent_inputs,
                               self.prev_rnn_state)  # Model handles None.
     q = q.cpu()
     action = self.distribution.sample(q)
     prev_rnn_state = self.prev_rnn_state or buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case, model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     prev_rnn_state = buffer_to(prev_rnn_state, device="cpu")
     agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
Example #14
 def step(self, observation, prev_action, prev_reward):
     """
     Compute an action from the model given the observation, previous
     action, and previous recurrent state, apply exploration noise, and
     advance the agent's recurrent state.  Moves inputs to device and
     returns outputs back to CPU, for the sampler.  (no grad)
     """
     model_inputs = buffer_to((observation, prev_action), device=self.device)
     action, state = self.model(*model_inputs, self.prev_rnn_state)
     action = self.exploration(action)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_state = self.prev_rnn_state or buffer_func(state, torch.zeros_like)
     self.advance_rnn_state(state)
     agent_info = DreamerAgentInfo(prev_state=prev_state)
     agent_step = AgentStep(action=action, agent_info=agent_info)
     return buffer_to(agent_step, device='cpu')
Example #15
 def extract_batch(self, T_idxs, B_idxs):
     T = self.replay_T
     all_action = buffer_func(self.samples.action, extract_sequences,
                              T_idxs - 1, B_idxs, T + 1)
     all_reward = extract_sequences(self.samples.reward, T_idxs - 1, B_idxs,
                                    T + 1)
     batch = SamplesFromReplay(
         observation=self.extract_observation(T_idxs, B_idxs),
         action=all_action[1:],
         reward=all_reward[1:],
         done=extract_sequences(self.samples.done, T_idxs, B_idxs, T),
         prev_action=all_action[:-1],
         prev_reward=all_reward[:-1],
     )
     if self.pixel_control_buffer is not None:
         pixctl_return = extract_sequences(
             self.pixel_control_buffer["return_"], T_idxs, B_idxs, T)
         batch = SamplesFromReplayPC(*batch, pixctl_return=pixctl_return)
     return torchify_buffer(batch)
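A minimal numpy sketch of the off-by-one alignment above: one extra step is extracted starting at ``T_idxs - 1`` so that ``all_action[1:]`` and ``all_action[:-1]`` give the current and previous actions for the same window (hypothetical 1-D stand-in for a replay column):

    import numpy as np

    actions = np.arange(20)              # stand-in replay column
    t0, T = 5, 4                         # chosen start index and length
    all_action = actions[t0 - 1:t0 + T]  # length T + 1, starts one step early
    action = all_action[1:]              # actions at t0 .. t0+T-1 -> [5 6 7 8]
    prev_action = all_action[:-1]        # actions at t0-1 .. t0+T-2 -> [4 5 6 7]
    print(action, prev_action)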
Example #16
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi, rnn_state = self.model(
         *agent_inputs, self.prev_rnn_state)
     terminations = torch.bernoulli(beta).bool()  # Sample terminations
     dist_info_omega = DistInfo(prob=pi)
     new_o = self.sample_option(terminations, dist_info_omega)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi * q).sum(-1),
                                 termination=terminations,
                                 inter_option_dist_info=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
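A hypothetical minimal sketch of the termination sampling above: each option's termination probability ``beta`` is treated as an independent Bernoulli draw:

    import torch

    beta = torch.tensor([[0.1, 0.9],
                         [0.5, 0.5]])            # [B, n_options]
    terminations = torch.bernoulli(beta).bool()  # one draw per option
    print(terminations.shape)                    # torch.Size([2, 2])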
Example #17
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     prev_option_input = self._prev_option
     if prev_option_input is None:  # Hack to extract previous option
         prev_option_input = torch.full_like(prev_action, -1)
     prev_action = self.distribution.to_onehot(prev_action)
     prev_option_input = self.distribution_omega.to_onehot_with_invalid(
         prev_option_input)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, prev_option_input),
         device=self.device)
     pi, beta, q, pi_omega, rnn_state = self.model(*model_inputs,
                                                   self.prev_rnn_state)
     dist_info_omega = DistInfo(prob=pi_omega)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     dist_info = DistInfo(prob=pi)
     dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi_omega * q).sum(-1),
                                 termination=terminations,
                                 dist_info_omega=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     self.advance_rnn_state(rnn_state)
     return AgentStep(action=action, agent_info=agent_info)
Example #18
File: mtgail.py Project: qxcv/mtil
 def append_samples(self, samples):
     """Append samples drawn from a sampler. Should be a namedarraytuple
     with leading dimensions `(time_steps, batch_size)`."""
     replay_samples = DiscrimReplaySamples(
         all_observation=samples.env.observation,
         all_action=samples.agent.action)
     T, B = get_leading_dims(replay_samples, n_dim=2)
     # if there's not enough room for a single full round of sampling then
     # the buffer is _probably_ too small.
     assert T * B <= self.total_n_samples, \
         f"There's not enough room in this buffer for a single full " \
         f"batch! T*B={T*B} > total_n_samples={self.total_n_samples}"
     flat_samples = buffer_func(
         replay_samples, lambda t: t.reshape((T * B, ) + t.shape[2:]))
     n_copied = 0
     while n_copied < T * B:
         # only copy to the end
         n_to_copy = min(T * B - n_copied, self.total_n_samples - self.ptr)
         self.circ_buf[self.ptr:self.ptr + n_to_copy] \
             = flat_samples[n_copied:n_copied + n_to_copy]
         n_copied += n_to_copy
         self.ptr = (self.ptr + n_to_copy) % self.total_n_samples
         self.samples_in_buffer = min(self.total_n_samples,
                                      self.samples_in_buffer + n_to_copy)
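A compact numpy sketch of the circular copy loop above (hypothetical standalone version using a plain array in place of the namedarraytuple ``circ_buf``):

    import numpy as np

    total_n_samples, ptr = 5, 3
    circ_buf = np.zeros(total_n_samples, dtype=int)
    flat_samples = np.array([10, 11, 12, 13])    # T * B = 4 new samples

    n_copied = 0
    while n_copied < len(flat_samples):
        # only copy up to the end of the buffer, then wrap the pointer
        n_to_copy = min(len(flat_samples) - n_copied, total_n_samples - ptr)
        circ_buf[ptr:ptr + n_to_copy] = flat_samples[n_copied:n_copied + n_to_copy]
        n_copied += n_to_copy
        ptr = (ptr + n_to_copy) % total_n_samples
    print(circ_buf, ptr)   # [12 13  0 10 11] 2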
Example #19
 def step(self, observation, prev_action, prev_reward):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward), device=self.device
     )
     pi, value, rnn_state, conv = self.model(*model_inputs, self.prev_rnn_state)
     if self._act_uniform:
         pi[:] = 1.0 / pi.shape[-1]  # uniform
     dist_info = DistInfo(prob=pi)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state or buffer_func(rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnnConv(
         dist_info=dist_info,
         value=value,
         prev_rnn_state=prev_rnn_state,
         conv=conv if self.store_latent else None,
     )  # Don't write the extra data.
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)
     return AgentStep(action=action, agent_info=agent_info)
Example #20
 def extract_observation(self, T_idxs, B_idxs, T):
     return buffer_func(self.samples.observation, extract_sequences, T_idxs,
                        B_idxs, T)
Example #21
 def extract_observation(self, T_idxs, B_idxs, T):
     """Generalization anticipating frame-buffer."""
     return buffer_func(self.samples.observation, extract_sequences, T_idxs,
                        B_idxs, T)
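Finally, a rough standalone sketch of the leaf-wise ``buffer_func`` pattern that the last two examples rely on (hypothetical simplified version using a plain namedtuple in place of a namedarraytuple):

    from collections import namedtuple
    import numpy as np

    Obs = namedtuple("Obs", ["image", "state"])

    def buffer_func_sketch(buf, func, *args):
        # Apply func to each leaf, preserving the container structure.
        return type(buf)(*(func(field, *args) for field in buf))

    obs = Obs(image=np.zeros((10, 2, 3)), state=np.zeros((10, 2)))
    windows = buffer_func_sketch(obs, lambda a, t, T: a[t:t + T], 4, 3)
    print(windows.image.shape, windows.state.shape)   # (3, 2, 3) (3, 2)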