def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    target_encoding = self.get_target_coordinates_encoding(observations)
    x: Union[torch.Tensor, List[torch.Tensor]]
    x = [target_encoding]

    # if observations["rgb"].shape[0] != 1:
    #     print("rgb", (observations["rgb"][..., 0, 0, :].unsqueeze(-2).unsqueeze(-2) == observations["rgb"][..., 0, 0, :]).float().mean())
    #     if "depth" in observations:
    #         print("depth", (observations["depth"][..., 0, 0, :].unsqueeze(-2).unsqueeze(-2) == observations["depth"][..., 0, 0, :]).float().mean())

    if not self.is_blind:
        perception_embed = self.visual_encoder(observations)
        if self.sensor_fusion:
            perception_embed = self.sensor_fuser(perception_embed)
        x = [perception_embed] + x

    x = torch.cat(x, dim=-1)
    x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)

    ac_output = ActorCriticOutput(
        distributions=self.actor(x), values=self.critic(x), extras={}
    )

    return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    if not self.is_blind:
        perception_embed = self.visual_encoder(observations)
    else:
        # TODO manage blindness for all agents simultaneously or separate?
        raise NotImplementedError()

    # TODO alternative where all agents consume all observations
    x, rnn_hidden_states = self.state_encoder(
        perception_embed, memory.tensor("rnn"), masks
    )
    dists, vals = self.actor_critic(x)

    return (
        ActorCriticOutput(distributions=dists, values=vals, extras={}),
        memory.set_tensor("rnn", rnn_hidden_states),
    )
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    out = self.linear(cast(torch.Tensor, observations[self.key]))

    assert len(out.shape) in [
        3,
        4,
    ], "observations must be [step, sampler, data] or [step, sampler, agent, data]"

    if len(out.shape) == 3:
        # [step, sampler, data] -> [step, sampler, agent, data]
        out = out.unsqueeze(-2)

    main_logits = out[..., : self.num_actions]
    aux_logits = out[..., self.num_actions : -1]
    values = out[..., -1:]

    # noinspection PyArgumentList
    return (
        ActorCriticOutput(
            distributions=CategoricalDistr(logits=main_logits),
            values=typing.cast(torch.FloatTensor, values),
            extras={
                "auxiliary_distributions": CategoricalDistr(logits=aux_logits)
            },
        ),
        None,
    )
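# Illustrative sketch (not part of the model): how the single linear head's output
# above is sliced into main action logits, auxiliary logits, and a value estimate.
# The sizes used here (4 main actions, 3 auxiliary actions, batch shape
# [step=5, sampler=2, agent=1]) are hypothetical; the real widths come from the
# module's configuration.
import torch

num_actions = 4
out = torch.randn(5, 2, 1, num_actions + 3 + 1)  # [step, sampler, agent, data]

main_logits = out[..., :num_actions]   # [5, 2, 1, 4] -> policy logits
aux_logits = out[..., num_actions:-1]  # [5, 2, 1, 3] -> auxiliary logits
values = out[..., -1:]                 # [5, 2, 1, 1] -> value head
print(main_logits.shape, aux_logits.shape, values.shape)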
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    target_encoding = self.get_target_coordinates_encoding(observations)
    x: Union[torch.Tensor, List[torch.Tensor]]
    x = [target_encoding]

    if not self.is_blind:
        perception_embed = self.visual_encoder(observations)
        if self.sensor_fusion:
            perception_embed = self.sensor_fuser(perception_embed)
        x = [perception_embed] + x

    x = torch.cat(x, dim=-1)
    x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)

    ac_output = ActorCriticOutput(
        distributions=self.actor(x), values=self.critic(x), extras={}
    )

    return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    x = self.goal_visual_encoder(observations)
    x, rnn_hidden_states = self.state_encoder(x, memory.tensor("rnn"), masks)

    return (
        ActorCriticOutput(
            distributions=self.actor(x), values=self.critic(x), extras={}
        ),
        memory.set_tensor("rnn", rnn_hidden_states),
    )
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Processes input batched observations to produce new actor and critic
    values.

    Processes input batched observations (along with prior hidden
    states, previous actions, and masks denoting which recurrent hidden
    states should be masked) and returns an `ActorCriticOutput` object
    containing the model's policy (distribution over actions) and
    evaluation of the current state (value).

    # Parameters
    observations : Batched input observations.
    memory : `Memory` containing the recurrent hidden states from initial timepoints.
    prev_actions : Tensor of previous actions taken.
    masks : Masks applied to hidden states. See `RNNStateEncoder`.

    # Returns
    Tuple of the `ActorCriticOutput` and recurrent hidden state.
    """
    target_encoding = self.get_object_type_encoding(
        cast(Dict[str, torch.FloatTensor], observations)
    )
    x = [target_encoding]

    if not self.is_blind:
        perception_embed = self.visual_encoder(observations)
        x = [perception_embed] + x

    x_cat = torch.cat(x, dim=-1)  # type: ignore
    x_out, rnn_hidden_states = self.state_encoder(
        x_cat, memory.tensor("rnn"), masks
    )

    return (
        ActorCriticOutput(
            distributions=self.actor(x_out), values=self.critic(x_out), extras={}
        ),
        memory.set_tensor("rnn", rnn_hidden_states),
    )
def adapt_result(
    ac_output,
    hidden_states,
    num_steps,
    num_samplers,
    num_agents,
    num_layers,
    observations,
):  # type: ignore
    distributions = CategoricalDistr(
        logits=ac_output.distributions.logits.view(
            num_steps, num_samplers, num_agents, -1
        ),
    )
    values = ac_output.values.view(num_steps, num_samplers, num_agents, 1)
    extras = ac_output.extras  # ignore shape
    # TODO confirm the shape of the auxiliary distribution is the same as the actor's
    if "auxiliary_distributions" in extras:
        extras["auxiliary_distributions"] = CategoricalDistr(
            logits=extras["auxiliary_distributions"].logits.view(
                num_steps, num_samplers, num_agents, -1
            ),
        )

    hidden_states = hidden_states.view(num_layers, num_samplers * num_agents, -1)

    # Unflatten all observation batch dims
    def recursively_adapt_observations(obs, num_steps, num_samplers, num_agents):
        for entry in obs:
            if isinstance(obs[entry], Dict):
                recursively_adapt_observations(
                    obs[entry], num_steps, num_samplers, num_agents
                )
            else:
                assert isinstance(obs[entry], torch.Tensor)
                if entry in ["minigrid_ego_image", "minigrid_mission"]:
                    final_dims = obs[entry].shape[
                        1:
                    ]  # assumes no agents dim in observations!
                    obs[entry] = obs[entry].view(
                        num_steps, num_samplers * num_agents, *final_dims
                    )

    recursively_adapt_observations(observations, num_steps, num_samplers, num_agents)

    return (
        ActorCriticOutput(
            distributions=distributions, values=values, extras=extras
        ),
        hidden_states,
    )
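# Illustrative sketch (not part of the model): the reshaping adapt_result performs,
# taking outputs computed over a flattened rollout batch back to the
# [step, sampler, agent, ...] layout expected downstream. The sizes here
# (3 steps, 2 samplers, 1 agent, 7 actions, 1 RNN layer, hidden size 128) are
# hypothetical.
import torch

num_steps, num_samplers, num_agents, num_actions = 3, 2, 1, 7
num_layers, hidden_size = 1, 128

flat_logits = torch.randn(num_steps * num_samplers * num_agents, num_actions)
flat_values = torch.randn(num_steps * num_samplers * num_agents, 1)
flat_hidden = torch.randn(num_layers, num_samplers * num_agents, hidden_size)

logits = flat_logits.view(num_steps, num_samplers, num_agents, -1)
values = flat_values.view(num_steps, num_samplers, num_agents, 1)
hidden = flat_hidden.view(num_layers, num_samplers * num_agents, -1)
print(logits.shape, values.shape, hidden.shape)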
def forward(self, observations, memory, prev_actions, masks):
    out = self.linear(observations[self.input_uuid])

    assert len(out.shape) in [
        3,
        4,
    ], "observations must be [step, sampler, data] or [step, sampler, agent, data]"

    if len(out.shape) == 3:
        # [step, sampler, data] -> [step, sampler, agent, data]
        out = out.unsqueeze(-2)

    # noinspection PyArgumentList
    return (
        ActorCriticOutput(
            distributions=CategoricalDistr(logits=out[..., :-1]),
            values=cast(torch.FloatTensor, out[..., -1:]),
            extras={},
        ),
        None,
    )
def forward(
    self,
    observations: ObservationType,
    recurrent_hidden_states: torch.FloatTensor,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
):
    (
        observations,
        recurrent_hidden_states,
        prev_actions,
        masks,
        num_steps,
        num_samplers,
        num_agents,
        num_layers,
    ) = self.adapt_inputs(observations, recurrent_hidden_states, prev_actions, masks)

    if self.lang_model != "gru":
        ac_output, hidden_states = self.forward_loop(
            observations=observations,
            recurrent_hidden_states=recurrent_hidden_states,
            prev_actions=prev_actions,
            masks=masks,
        )

        return self.adapt_result(
            ac_output,
            hidden_states[-1:],
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
            observations,
        )

    assert recurrent_hidden_states.shape[0] == 1

    images = cast(torch.FloatTensor, observations["minigrid_ego_image"])
    if self.use_cnn2:
        # Offset each of the three ego-image channels into disjoint index ranges
        # before looking them up in the shared semantic embedding table.
        images_shape = images.shape
        images = images + torch.LongTensor([0, 11, 22]).view(  # type:ignore
            1, 1, 1, 3
        ).to(images.device)
        images = self.semantic_embedding(images).view(  # type:ignore
            *images_shape[:3], 24
        )
    images = images.permute(0, 3, 1, 2).float()  # type:ignore

    _, nsamplers, _ = recurrent_hidden_states.shape
    rollouts_len = images.shape[0] // nsamplers

    masks = typing.cast(
        torch.FloatTensor, masks.view(rollouts_len, nsamplers, *masks.shape[1:])
    )

    instrs: Optional[torch.Tensor] = None
    if "minigrid_mission" in observations and self.use_instr:
        instrs = cast(torch.FloatTensor, observations["minigrid_mission"])
        instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

    needs_instr_reset_mask = masks != 1.0
    needs_instr_reset_mask[0] = 1
    needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)

    # Timesteps at which at least one sampler starts a new episode (plus the end index).
    blocking_inds: List[int] = np.where(
        needs_instr_reset_mask.view(rollouts_len, -1).any(-1).cpu().numpy()
    )[0].tolist()
    blocking_inds.append(rollouts_len)

    instr_embeddings: Optional[torch.Tensor] = None
    if self.use_instr:
        instr_reset_multi_inds = list(
            (int(a), int(b))
            for a, b in zip(*np.where(needs_instr_reset_mask.cpu().numpy()))
        )
        time_ind_to_which_need_instr_reset: List[List] = [
            [] for _ in range(rollouts_len)
        ]
        reset_multi_ind_to_index = {
            mi: i for i, mi in enumerate(instr_reset_multi_inds)
        }

        for a, b in instr_reset_multi_inds:
            time_ind_to_which_need_instr_reset[a].append(b)

        unique_instr_embeddings = self._get_instr_embedding(
            instrs[needs_instr_reset_mask]
        )

        instr_embeddings_list = []
        instr_embeddings_list.append(unique_instr_embeddings[:nsamplers])
        current_instr_embeddings_list = list(instr_embeddings_list[-1])

        for time_ind in range(1, rollouts_len):
            if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                instr_embeddings_list.append(instr_embeddings_list[-1])
            else:
                for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                    time_ind
                ]:
                    current_instr_embeddings_list[
                        sampler_needing_reset_ind
                    ] = unique_instr_embeddings[
                        reset_multi_ind_to_index[
                            (time_ind, sampler_needing_reset_ind)
                        ]
                    ]

                instr_embeddings_list.append(
                    torch.stack(current_instr_embeddings_list, dim=0)
                )

        instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

    # The following code can be used to compute the instr_embeddings in another way
    # and thus verify that the above logic is (more likely to be) correct:
    #
    # needs_instr_reset_mask = (masks != 1.0)
    # needs_instr_reset_mask[0] *= 0
    # needs_instr_reset_inds = needs_instr_reset_mask.view(nrollouts, -1).any(-1).cpu().numpy()
    #
    # # Get inds where a new task has started
    # blocking_inds: List[int] = np.where(needs_instr_reset_inds)[0].tolist()
    # blocking_inds.append(needs_instr_reset_inds.shape[0])
    # if nrollouts != 1:
    #     pdb.set_trace()
    # if blocking_inds[0] != 0:
    #     blocking_inds.insert(0, 0)
    # if self.use_instr:
    #     instr_embeddings_list = []
    #     for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
    #         instr_embeddings_list.append(
    #             self._get_instr_embedding(instrs[ind0])
    #             .unsqueeze(0)
    #             .repeat(ind1 - ind0, 1, 1)
    #         )
    #     tmp_instr_embeddings = torch.cat(instr_embeddings_list, dim=0)
    #     assert (instr_embeddings - tmp_instr_embeddings).abs().max().item() < 1e-6

    # Embed images
    # images = images.view(nrollouts, nsamplers, *images.shape[1:])
    image_embeddings = self.image_conv(images)

    if self.arch.startswith("expert_filmcnn"):
        instr_embeddings_flatter = instr_embeddings.view(
            -1, *instr_embeddings.shape[2:]
        )
        for controller in self.controllers:
            image_embeddings = controller(
                image_embeddings, instr_embeddings_flatter
            )
        image_embeddings = F.relu(self.film_pool(image_embeddings))

    image_embeddings = image_embeddings.view(rollouts_len, nsamplers, -1)

    if self.use_instr and self.lang_model == "attgru":
        raise NotImplementedError("Currently attgru is not implemented.")

    memory = None
    if self.use_memory:
        assert recurrent_hidden_states.shape[0] == 1
        # Split the flat recurrent state into the two halves expected by the memory RNN.
        hidden = (
            recurrent_hidden_states[:, :, : self.semi_memory_size],
            recurrent_hidden_states[:, :, self.semi_memory_size :],
        )
        # Run the memory RNN block-by-block, resetting the hidden state at episode boundaries.
        embeddings_list = []
        for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
            hidden = (hidden[0] * masks[ind0], hidden[1] * masks[ind0])
            rnn_out, hidden = self.memory_rnn(image_embeddings[ind0:ind1], hidden)
            embeddings_list.append(rnn_out)

        # embedding = hidden[0]
        embedding = torch.cat(embeddings_list, dim=0)
        memory = torch.cat(hidden, dim=-1)
    else:
        embedding = image_embeddings

    if self.use_instr and "filmcnn" not in self.arch:
        embedding = torch.cat((embedding, instr_embeddings), dim=-1)

    if hasattr(self, "aux_info") and self.aux_info:
        extra_predictions = {
            info: self.extra_heads[info](embedding) for info in self.extra_heads
        }
    else:
        extra_predictions = dict()

    embedding = embedding.view(rollouts_len * nsamplers, -1)

    ac_output = ActorCriticOutput(
        distributions=CategoricalDistr(logits=self.actor(embedding)),
        values=self.critic(embedding),
        extras=extra_predictions
        if not self.include_auxiliary_head
        else {
            **extra_predictions,
            "auxiliary_distributions": CategoricalDistr(logits=self.aux(embedding)),
        },
    )
    hidden_states = memory

    return self.adapt_result(
        ac_output,
        hidden_states,
        num_steps,
        num_samplers,
        num_agents,
        num_layers,
        observations,
    )
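# Illustrative sketch (not part of the model): the offset-then-embed pattern used in
# the use_cnn2 branch above. Each of the three MiniGrid ego-image channels is shifted
# into a disjoint index range so a single embedding table can embed all of them.
# The table size (33), embedding dim (8, giving 3 * 8 = 24 features per cell), and
# input sizes below are assumptions for this sketch.
import torch
import torch.nn as nn

semantic_embedding = nn.Embedding(33, 8)

images = torch.randint(0, 11, (5, 7, 7, 3))  # [batch, H, W, channel], integer ids
shifted = images + torch.tensor([0, 11, 22]).view(1, 1, 1, 3)
features = semantic_embedding(shifted).view(*images.shape[:3], 24)
print(features.shape)  # torch.Size([5, 7, 7, 24])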
def forward_loop(
    self,
    observations: ObservationType,
    recurrent_hidden_states: torch.FloatTensor,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
):
    results = []
    images = cast(torch.FloatTensor, observations["minigrid_ego_image"]).float()
    instrs: Optional[torch.Tensor] = None
    if "minigrid_mission" in observations:
        instrs = cast(torch.Tensor, observations["minigrid_mission"])

    _, nsamplers, _ = recurrent_hidden_states.shape
    rollouts_len = images.shape[0] // nsamplers

    obs = babyai.rl.DictList()
    images = images.view(rollouts_len, nsamplers, *images.shape[1:])
    masks = masks.view(rollouts_len, nsamplers, *masks.shape[1:])  # type:ignore

    # needs_reset = (masks != 1.0).view(nrollouts, -1).any(-1)
    if instrs is not None:
        instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

    needs_instr_reset_mask = masks != 1.0
    needs_instr_reset_mask[0] = 1
    needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)

    instr_embeddings: Optional[torch.Tensor] = None
    if self.use_instr:
        instr_reset_multi_inds = list(
            (int(a), int(b))
            for a, b in zip(*np.where(needs_instr_reset_mask.cpu().numpy()))
        )
        time_ind_to_which_need_instr_reset: List[List] = [
            [] for _ in range(rollouts_len)
        ]
        reset_multi_ind_to_index = {
            mi: i for i, mi in enumerate(instr_reset_multi_inds)
        }

        for a, b in instr_reset_multi_inds:
            time_ind_to_which_need_instr_reset[a].append(b)

        unique_instr_embeddings = self._get_instr_embedding(
            instrs[needs_instr_reset_mask]
        )

        instr_embeddings_list = []
        instr_embeddings_list.append(unique_instr_embeddings[:nsamplers])
        current_instr_embeddings_list = list(instr_embeddings_list[-1])

        for time_ind in range(1, rollouts_len):
            if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                instr_embeddings_list.append(instr_embeddings_list[-1])
            else:
                for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                    time_ind
                ]:
                    current_instr_embeddings_list[
                        sampler_needing_reset_ind
                    ] = unique_instr_embeddings[
                        reset_multi_ind_to_index[
                            (time_ind, sampler_needing_reset_ind)
                        ]
                    ]

                instr_embeddings_list.append(
                    torch.stack(current_instr_embeddings_list, dim=0)
                )

        instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

    assert recurrent_hidden_states.shape[0] == 1
    memory = recurrent_hidden_states[0]

    # instr_embedding: Optional[torch.Tensor] = None
    for i in range(rollouts_len):
        obs.image = images[i]
        if "minigrid_mission" in observations:
            obs.instr = instrs[i]

        # reset = needs_reset[i].item()
        # if self.baby_ai_model.use_instr and (reset or i == 0):
        #     instr_embedding = self.baby_ai_model._get_instr_embedding(obs.instr)

        results.append(
            self.forward_once(
                obs, memory=memory * masks[i], instr_embedding=instr_embeddings[i]
            )
        )
        memory = results[-1]["memory"]

    embedding = torch.cat([r["embedding"] for r in results], dim=0)
    extra_predictions_list = [r["extra_predictions"] for r in results]
    extra_predictions = {
        key: torch.cat([ep[key] for ep in extra_predictions_list], dim=0)
        for key in extra_predictions_list[0]
    }

    return (
        ActorCriticOutput(
            distributions=CategoricalDistr(logits=self.actor(embedding)),
            values=self.critic(embedding),
            extras=extra_predictions
            if not self.include_auxiliary_head
            else {
                **extra_predictions,
                "auxiliary_distributions": CategoricalDistr(logits=self.aux(embedding)),
            },
        ),
        torch.stack([r["memory"] for r in results], dim=0),
    )
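# Illustrative sketch (not part of the model): how needs_instr_reset_mask, used by
# both the gru-path forward and forward_loop above, marks (step, sampler) pairs where
# a sampler starts a new episode and the instruction must be re-embedded. Shapes are
# hypothetical: 4 rollout steps, 2 samplers.
import numpy as np
import torch

masks = torch.tensor(
    [[1.0, 1.0],
     [1.0, 0.0],  # sampler 1 starts a new episode at step 1
     [1.0, 1.0],
     [0.0, 1.0]]  # sampler 0 starts a new episode at step 3
).unsqueeze(-1)

needs_instr_reset_mask = masks != 1.0
needs_instr_reset_mask[0] = 1  # always (re)embed instructions at step 0
needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)

# (step, sampler) pairs whose instruction embedding must be refreshed
reset_inds = [
    (int(a), int(b)) for a, b in zip(*np.where(needs_instr_reset_mask.cpu().numpy()))
]
print(reset_inds)  # [(0, 0), (0, 1), (1, 1), (3, 0)]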