def step(self, observation, prev_action, prev_reward):
    """Compute the categorical action distribution and advance the RNN state.

    Moves inputs to the agent's device, runs the recurrent model, then
    selects an action: sampled from the distribution in 'sample' mode,
    greedy argmax over probabilities in 'eval' mode.  Outputs are moved
    back to CPU for the sampler.

    Raises:
        ValueError: if the agent mode is neither 'sample' nor 'eval'.
            (Previously an unrecognized mode left ``action`` unbound and
            crashed later with an opaque UnboundLocalError.)
    """
    prev_action = self.distribution.to_onehot(prev_action)
    agent_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    probs, value, rnn_state = self.model(*agent_inputs,
                                         self.prev_rnn_state)
    dist_info = DistInfo(prob=probs)
    if self._mode == 'sample':
        action = self.distribution.sample(dist_info)
    elif self._mode == 'eval':
        action = torch.argmax(probs, dim=-1)
    else:
        raise ValueError(f"Invalid agent mode: {self._mode!r}; "
                         "expected 'sample' or 'eval'.")
    # Model handles None, but Buffer does not, make zeros if needed:
    if self.prev_rnn_state is None:
        prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
    else:
        prev_rnn_state = self.prev_rnn_state
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
    agent_info = AgentInfoRnn(dist_info=dist_info,
                              value=value,
                              prev_rnn_state=prev_rnn_state)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    self.advance_rnn_state(rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
# Example 2
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute the policy's Gaussian action distribution from inputs and
     sample an action.  Runs the model to obtain mean, log_std, a value
     estimate, and the next recurrent state.  Inputs are moved to the
     agent's device; outputs are returned on ``device`` (CPU by default,
     for the sampler).  Advances the agent's recurrent state.  (no grad)
     """
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mean, log_std, value, rnn_state = self.model(*model_inputs,
                                                  self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mean, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed:
     if self.prev_rnn_state is None:
         stored_rnn_state = buffer_func(rnn_state, torch.zeros_like)
     else:
         stored_rnn_state = self.prev_rnn_state
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     stored_rnn_state = buffer_method(stored_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info,
                               value=value,
                               prev_rnn_state=stored_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
# Example 3
 def step(self, observation, prev_action, prev_reward):
     """Compute the Gaussian action distribution, sample an action, and
     advance the recurrent state.  Moves inputs to device and returns
     outputs on CPU, for the sampler.
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
         device=self.device)
     mu, log_std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     action = self.distribution.sample(dist_info)
     # Model handles None, but Buffer does not, make zeros if needed.
     # Use an explicit None check: ``self.prev_rnn_state or ...`` evaluates
     # truthiness, which raises "Boolean value of Tensor ... is ambiguous"
     # when the state is a raw multi-element tensor.
     if self.prev_rnn_state is None:
         prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
     else:
         prev_rnn_state = self.prev_rnn_state
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
         prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device="cpu")
     self.advance_rnn_state(rnn_state)  # Keep on device.
     return AgentStep(action=action, agent_info=agent_info)
# Example 4
    def step(self, observation, prev_action, prev_reward):
        """Compute the categorical action distribution(s) and advance the
        recurrent state(s).

        With ``self.dual_model`` set, a second (intrinsic) model is also
        run; the action is sampled from the extrinsic distribution in
        'eval' mode and from the intrinsic distribution otherwise, and
        both recurrent states are stored and advanced.  Inputs are moved
        to device; outputs are returned on CPU for the sampler.
        """
        prev_action = self.distribution.to_onehot(prev_action)
        agent_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)

        pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfo(prob=pi)

        if self.dual_model:
            int_pi, int_value, int_rnn_state = self.model_int(
                *agent_inputs, self.prev_int_rnn_state)
            dist_int_info = DistInfo(prob=int_pi)
            if self._mode == "eval":
                action = self.distribution.sample(dist_info)
            else:
                action = self.distribution.sample(dist_int_info)
        else:
            action = self.distribution.sample(dist_info)

        # Model handles None, but Buffer does not, make zeros if needed.
        # Explicit None checks instead of ``state or fallback``: truthiness
        # of a raw multi-element tensor raises "Boolean value of Tensor
        # ... is ambiguous".
        if self.prev_rnn_state is None:
            prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
        else:
            prev_rnn_state = self.prev_rnn_state
        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)

        if self.dual_model:
            if self.prev_int_rnn_state is None:
                prev_int_rnn_state = buffer_func(int_rnn_state,
                                                 torch.zeros_like)
            else:
                prev_int_rnn_state = self.prev_int_rnn_state
            prev_int_rnn_state = buffer_method(prev_int_rnn_state, "transpose",
                                               0, 1)
            agent_info = AgentInfoRnnTwin(
                dist_info=dist_info,
                value=value,
                prev_rnn_state=prev_rnn_state,
                dist_int_info=dist_int_info,
                int_value=int_value,
                prev_int_rnn_state=prev_int_rnn_state)
        else:
            agent_info = AgentInfoRnn(dist_info=dist_info,
                                      value=value,
                                      prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)  # Keep on device.
        if self.dual_model:
            self.advance_int_rnn_state(int_rnn_state)
        return AgentStep(action=action, agent_info=agent_info)