def get_dist_and_value(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], Optional[torch.Tensor]]:
    """Run the actor and the critic; return distributions, values and memories.

    Actor and critic keep separate recurrent state packed into a single
    ``memories`` tensor. ``self._get_actor_critic_mem`` splits it along the
    last dimension (actor chunk first, critic chunk second). The updated
    chunks are concatenated back in that same order.

    Args:
        vec_inputs: Vector observations, forwarded to both actor and critic.
        vis_inputs: Visual observations, forwarded to both actor and critic.
        masks: Optional action masks, forwarded to ``self.get_dists``.
        memories: Packed actor/critic recurrent state. It is only split when
            ``self.use_lstm`` is set and a tensor is supplied.
        sequence_length: Sequence length forwarded to actor and critic.

    Returns:
        ``(dists, value_outputs, mem_out)``:
        - ``dists`` comes from ``self.get_dists``.
        - ``value_outputs`` comes from ``self.critic``.
        - ``mem_out`` is the actor and critic memory outputs concatenated on
          the last dimension, or ``None`` when ``self.use_lstm`` is false.
    """
    # Use the shared helper so the split lives in one place. For a missing
    # ``memories`` it returns ``None`` chunks instead of raising in torch.split.
    actor_mem, critic_mem = self._get_actor_critic_mem(memories)
    dists, actor_mem_outs = self.get_dists(
        vec_inputs,
        vis_inputs,
        memories=actor_mem,
        sequence_length=sequence_length,
        masks=masks,
    )
    value_outputs, critic_mem_outs = self.critic(
        vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
    )
    if self.use_lstm:
        # Re-pack in the same (actor, critic) order the split used.
        mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
    else:
        mem_out = None
    return dists, value_outputs, mem_out
def get_action_stats_and_value(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[
    AgentAction,
    ActionLogProbs,
    torch.Tensor,
    Dict[str, torch.Tensor],
    Optional[torch.Tensor],
]:
    """Sample actions from the actor and evaluate the critic in one pass.

    Actor and critic keep separate recurrent state packed into a single
    ``memories`` tensor. ``self._get_actor_critic_mem`` splits it along the
    last dimension (actor chunk first, critic chunk second). The updated
    chunks are concatenated back in that same order.

    Args:
        vec_inputs: Vector observations, forwarded to both actor and critic.
        vis_inputs: Visual observations, forwarded to both actor and critic.
        masks: Optional action masks, forwarded to ``self.action_model``.
        memories: Packed actor/critic recurrent state. It is only split when
            ``self.use_lstm`` is set and a tensor is supplied.
        sequence_length: Sequence length forwarded to the network body and
            the critic.

    Returns:
        ``(action, log_probs, entropies, value_outputs, mem_out)``:
        - The first three come from ``self.action_model``.
        - ``value_outputs`` comes from ``self.critic``.
        - ``mem_out`` is the actor and critic memory outputs concatenated on
          the last dimension, or ``None`` when ``self.use_lstm`` is false.
    """
    # Use the shared helper so the split lives in one place. For a missing
    # ``memories`` it returns ``None`` chunks instead of raising in torch.split.
    actor_mem, critic_mem = self._get_actor_critic_mem(memories)
    encoding, actor_mem_outs = self.network_body(
        vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
    )
    action, log_probs, entropies = self.action_model(encoding, masks)
    value_outputs, critic_mem_outs = self.critic(
        vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
    )
    if self.use_lstm:
        # Re-pack in the same (actor, critic) order the split used.
        mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
    else:
        mem_out = None
    return action, log_probs, entropies, value_outputs, mem_out
def _get_actor_critic_mem(
    self, memories: Optional[torch.Tensor] = None
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Split packed recurrent memories into ``(actor_mem, critic_mem)``.

    Returns ``(None, None)`` when ``self.use_lstm`` is false or no
    ``memories`` tensor is given.

    Otherwise ``memories`` is cut along its last dimension into chunks of
    ``self.memory_size // 2``. The first chunk goes to the actor and the
    second to the critic. Unpacking into two names requires the split to
    produce exactly two chunks; any other count raises ``ValueError``.
    """
    if not self.use_lstm or memories is None:
        return None, None
    chunk = self.memory_size // 2
    front, back = torch.split(memories, chunk, dim=-1)
    return front, back
def critic_pass(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
    """Evaluate only the critic, passing the actor's memory through unchanged.

    Args:
        vec_inputs: Vector observations forwarded to ``self.critic``.
        vis_inputs: Visual observations forwarded to ``self.critic``.
        memories: Packed actor/critic recurrent state. It is split by
            ``self._get_actor_critic_mem`` (actor chunk first, critic chunk
            second along the last dimension).
        sequence_length: Sequence length forwarded to ``self.critic``.

    Returns:
        ``(value_outputs, memories_out)``:
        - ``value_outputs`` comes from ``self.critic``.
        - ``memories_out`` is the untouched actor chunk concatenated with the
          critic's new memory on the last dimension, or ``None`` when there
          is no actor memory (``use_lstm`` false or ``memories`` is None).
    """
    # Shared helper: avoids calling torch.split on None when memories are absent.
    actor_mem, critic_mem = self._get_actor_critic_mem(memories)
    value_outputs, critic_mem_out = self.critic(
        vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
    )
    if actor_mem is not None:
        # Only the critic half advances; the actor half is copied through.
        memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
    else:
        memories_out = None
    return value_outputs, memories_out