Пример #1
0
 def get_dist_and_value(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
     sequence_length: int = 1,
 ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], Optional[torch.Tensor]]:
     """
     Run both the actor (action distributions) and the critic (value
     estimates) over the given observations.

     :param vec_inputs: List of vector observation tensors.
     :param vis_inputs: List of visual observation tensors.
     :param masks: Optional action masks passed to the actor.
     :param memories: Optional LSTM memories; when ``self.use_lstm`` is
         True, the actor half occupies the first half of the last dim and
         the critic half the second half. NOTE(review): assumed required
         (non-None) whenever ``use_lstm`` is True — caller must provide it.
     :param sequence_length: LSTM sequence length (1 when no LSTM is used).
     :return: Tuple of (action distributions, value estimates per reward
         signal, concatenated updated memories — or None when no LSTM is
         used).
     """
     if self.use_lstm:
         # The first half of the memory tensor belongs to the actor, the
         # second half to the critic.
         actor_mem, critic_mem = torch.split(
             memories, self.memory_size // 2, dim=-1
         )
     else:
         critic_mem = None
         actor_mem = None
     dists, actor_mem_outs = self.get_dists(
         vec_inputs,
         vis_inputs,
         memories=actor_mem,
         sequence_length=sequence_length,
         masks=masks,
     )
     value_outputs, critic_mem_outs = self.critic(
         vec_inputs,
         vis_inputs,
         memories=critic_mem,
         sequence_length=sequence_length,
     )
     if self.use_lstm:
         # Re-pack the updated actor and critic memories in the same
         # (actor first, critic second) layout they were received in.
         mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
     else:
         mem_out = None
     return dists, value_outputs, mem_out
Пример #2
0
 def get_action_stats_and_value(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
     sequence_length: int = 1,
 ) -> Tuple[
     AgentAction,
     ActionLogProbs,
     torch.Tensor,
     Dict[str, torch.Tensor],
     Optional[torch.Tensor],
 ]:
     """
     Sample actions (with log-probs and entropies) from the actor and
     compute value estimates from the critic for the given observations.

     :param vec_inputs: List of vector observation tensors.
     :param vis_inputs: List of visual observation tensors.
     :param masks: Optional action masks passed to the action model.
     :param memories: Optional LSTM memories; when ``self.use_lstm`` is
         True, the actor half occupies the first half of the last dim and
         the critic half the second half. NOTE(review): assumed required
         (non-None) whenever ``use_lstm`` is True — caller must provide it.
     :param sequence_length: LSTM sequence length (1 when no LSTM is used).
     :return: Tuple of (actions, action log-probs, entropies, value
         estimates per reward signal, concatenated updated memories — or
         None when no LSTM is used).
     """
     if self.use_lstm:
         # The first half of the memory tensor belongs to the actor, the
         # second half to the critic.
         actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
     else:
         critic_mem = None
         actor_mem = None
     encoding, actor_mem_outs = self.network_body(
         vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
     )
     action, log_probs, entropies = self.action_model(encoding, masks)
     value_outputs, critic_mem_outs = self.critic(
         vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
     )
     if self.use_lstm:
         # Re-pack the updated actor and critic memories in the same
         # (actor first, critic second) layout they were received in.
         mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
     else:
         mem_out = None
     return action, log_probs, entropies, value_outputs, mem_out
Пример #3
0
 def _get_actor_critic_mem(
     self,
     memories: Optional[torch.Tensor] = None
 ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
     if self.use_lstm and memories is not None:
         # Use only the back half of memories for critic and actor
         actor_mem, critic_mem = torch.split(memories,
                                             self.memory_size // 2,
                                             dim=-1)
     else:
         critic_mem = None
         actor_mem = None
     return actor_mem, critic_mem
Пример #4
0
 def critic_pass(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     memories: Optional[torch.Tensor] = None,
     sequence_length: int = 1,
 ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
     """
     Run only the critic over the observations, passing the actor's
     memories through unchanged.

     :param vec_inputs: List of vector observation tensors.
     :param vis_inputs: List of visual observation tensors.
     :param memories: Optional LSTM memories laid out as (actor half,
         critic half) along the last dimension.
     :param sequence_length: LSTM sequence length (1 when no LSTM is used).
     :return: Tuple of (value estimates per reward signal, re-assembled
         memories — or None when no LSTM is used or no memories given).
     """
     actor_mem, critic_mem = None, None
     # Guard against memories=None so we do not crash in torch.split
     # (matches the None check used elsewhere for splitting memories).
     if self.use_lstm and memories is not None:
         # Only the second half of the memories belongs to the critic;
         # keep the actor half so it can be returned unchanged.
         actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
     value_outputs, critic_mem_out = self.critic(
         vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
     )
     if actor_mem is not None:
         # Re-assemble full memories with the actor half untouched.
         memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
     else:
         memories_out = None
     return value_outputs, memories_out