def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Encode the target (and vision, unless blind), run the recurrent state
    encoder and produce actor/critic outputs.

    # Parameters

    observations : Passed to `get_target_coordinates_encoding` and, when the
        model is not blind, to `visual_encoder`.
    memory : `Memory` whose `"rnn"` tensor holds the recurrent hidden state.
    prev_actions : Not used by this method.
    masks : Forwarded unchanged to `state_encoder`.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the result of
    storing the new hidden state under `"rnn"` in `memory`.
    """
    features: List[torch.Tensor] = [
        self.get_target_coordinates_encoding(observations)
    ]

    if not self.is_blind:
        visual_features = self.visual_encoder(observations)
        if self.sensor_fusion:
            visual_features = self.sensor_fuser(visual_features)
        # Visual features go first, the target encoding stays last.
        features.insert(0, visual_features)

    encoder_input = torch.cat(features, dim=-1)
    beliefs, new_hidden_states = self.state_encoder(
        encoder_input, memory.tensor("rnn"), masks
    )

    return (
        ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras={},
        ),
        memory.set_tensor("rnn", new_hidden_states),
    )
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Apply the linear layer to one observation and split its output into
    main logits, auxiliary logits and a value.

    # Parameters

    observations : Mapping from which `observations[self.input_uuid]` is read.
    memory : Not used by this method.
    prev_actions : Not used by this method.
    masks : Not used by this method.

    # Returns

    Tuple of the `ActorCriticOutput` (auxiliary distribution stored in
    `extras["auxiliary_distributions"]`) and `None` in place of a memory.
    """
    linear_out = self.linear(cast(torch.Tensor, observations[self.input_uuid]))

    # Last-dimension layout: [main logits | auxiliary logits | value].
    n_main = self.num_actions
    policy_logits = linear_out[..., :n_main]
    auxiliary_logits = linear_out[..., n_main:-1]
    state_values = linear_out[..., -1:]

    # Keep the first two dims (step x sampler per the original comments) and
    # flatten whatever follows.
    flat_values = state_values.view(state_values.shape[:2] + (-1,))

    # noinspection PyArgumentList
    output = ActorCriticOutput(
        distributions=cast(DistributionType, CategoricalDistr(logits=policy_logits)),
        values=cast(torch.FloatTensor, flat_values),
        extras={"auxiliary_distributions": CategoricalDistr(logits=auxiliary_logits)},
    )
    return output, None
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Encode observations visually, run the recurrent state encoder and the
    joint actor-critic head.

    # Parameters

    observations : Passed to `visual_encoder`.
    memory : `Memory` whose `"rnn"` tensor holds the recurrent hidden state.
    prev_actions : Not used by this method.
    masks : Forwarded unchanged to `state_encoder`.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the result of
    storing the new hidden state under `"rnn"` in `memory`.

    # Raises

    NotImplementedError : If `self.is_blind` is truthy.
    """
    if self.is_blind:
        # TODO manage blindness for all agents simultaneously or separate?
        raise NotImplementedError()

    # TODO alternative where all agents consume all observations
    visual_features = self.visual_encoder(observations)
    beliefs, new_hidden_states = self.state_encoder(
        visual_features, memory.tensor("rnn"), masks
    )

    action_dists, state_values = self.actor_critic(beliefs)
    output = ActorCriticOutput(
        distributions=action_dists, values=state_values, extras={},
    )
    return output, memory.set_tensor("rnn", new_hidden_states)
def forward(self, observations, memory, prev_actions, masks):
    """Run `self.head` on `observations[self.input_uuid]`.

    `memory`, `prev_actions` and `masks` are accepted but not used.

    # Returns

    Tuple of the `ActorCriticOutput` built from the head's two outputs
    (with empty `extras`) and `None` in place of a memory.
    """
    action_dists, state_values = self.head(observations[self.input_uuid])

    # noinspection PyArgumentList
    output = ActorCriticOutput(
        distributions=action_dists, values=state_values, extras={},
    )
    return output, None
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Processes input batched observations to produce new actor and critic
    values.

    Embeds the two relative-distance observations, keeps the arm-to-object
    embedding where `pickedup_object != 1` and the object-to-goal embedding
    where `pickedup_object == 1`, concatenates the result with the visual
    embedding and feeds it through the recurrent state encoder.

    # Parameters

    observations : Batched input observations. Reads the keys
        `"relative_agent_arm_to_obj"`, `"relative_obj_to_goal"` and
        `"pickedup_object"`; the whole mapping is passed to `visual_encoder`.
    memory : `Memory` whose `"rnn"` tensor holds the recurrent hidden state.
    prev_actions : Tensor of previous actions taken (not used here).
    masks : Masks forwarded to `state_encoder`. See `RNNStateEncoder`.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the result of
    storing the new hidden state under `"rnn"` in `memory`.
    """
    arm2obj_dist = self.get_relative_distance_embedding(
        observations["relative_agent_arm_to_obj"]
    )
    obj2goal_dist = self.get_relative_distance_embedding(
        observations["relative_obj_to_goal"]
    )

    perception_embed = self.visual_encoder(observations)

    after_pickup = observations["pickedup_object"] == 1

    # `distances` is the same tensor object as `arm2obj_dist`, so the masked
    # assignment overwrites the arm-to-object embedding in place wherever the
    # object has been picked up.
    distances = arm2obj_dist
    distances[after_pickup] = obj2goal_dist[after_pickup]

    x_cat = torch.cat([distances, perception_embed], dim=-1)
    x_out, rnn_hidden_states = self.state_encoder(
        x_cat, memory.tensor("rnn"), masks
    )

    actor_critic_output = ActorCriticOutput(
        distributions=self.actor(x_out),
        values=self.critic(x_out),
        extras={},
    )
    updated_memory = memory.set_tensor("rnn", rnn_hidden_states)

    return actor_critic_output, updated_memory
def forward(self, observations, memory, prev_actions, masks):
    """Apply the linear layer to `observations[self.input_uuid]` and split the
    result into action logits (all but the last column) and a value (last
    column).

    `memory`, `prev_actions` and `masks` are accepted but not used.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and `None` in
    place of a memory.
    """
    linear_out = self.linear(observations[self.input_uuid])

    action_logits = linear_out[..., :-1]
    # Keep the first two dims and flatten the rest:
    # [steps, samplers, flattened] per the original comments.
    state_values = linear_out[..., -1:].view(*linear_out.shape[:2], -1)

    # noinspection PyArgumentList
    output = ActorCriticOutput(
        distributions=CategoricalDistr(logits=action_logits),
        values=cast(torch.FloatTensor, state_values),
        extras={},
    )
    return output, None
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Encode observations with `goal_visual_encoder`, run the recurrent state
    encoder and produce actor/critic outputs.

    # Parameters

    observations : Passed to `goal_visual_encoder`.
    memory : `Memory` whose tensor under `self.memory_key` holds the
        recurrent hidden state.
    prev_actions : Not used by this method.
    masks : Forwarded unchanged to `state_encoder`.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the result of
    storing the new hidden state under `self.memory_key` in `memory`.
    """
    encoded = self.goal_visual_encoder(observations)
    beliefs, new_hidden_states = self.state_encoder(
        encoded, memory.tensor(self.memory_key), masks
    )

    output = ActorCriticOutput(
        distributions=self.actor(beliefs),
        values=self.critic(beliefs),
        extras={},
    )
    return output, memory.set_tensor(self.memory_key, new_hidden_states)
def forward(  # type:ignore
    self,
    observations: Dict[str, Union[torch.FloatTensor, Dict[str, Any]]],
    memory: Memory,
    prev_actions: Any,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Build a Gaussian policy and a value estimate from a single observation.

    # Parameters

    observations : Mapping from which `observations[self.input_uuid]` is read.
    memory : Not used by this method.
    prev_actions : Not used by this method.
    masks : Not used by this method.

    # Returns

    Tuple of the `ActorCriticOutput` and `None` (no memory). The distribution
    is a `GaussianDistr` whose `loc` is the actor output and whose `scale`
    is `self.action_std`.
    """
    action_means = self.actor(observations[self.input_uuid])
    state_values = self.critic(observations[self.input_uuid])

    policy = cast(
        DistributionType,
        GaussianDistr(loc=action_means, scale=self.action_std),
    )
    output = ActorCriticOutput(policy, state_values, {})
    return output, None  # no Memory
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Encode observations with `goal_visual_encoder`, run the recurrent state
    encoder and produce actor/critic outputs, optionally with an auxiliary
    head.

    # Parameters

    observations : Passed to `goal_visual_encoder`.
    memory : `Memory` whose `"rnn"` tensor holds the recurrent hidden state.
    prev_actions : Not used by this method.
    masks : Forwarded unchanged to `state_encoder`.

    # Returns

    Tuple of the `ActorCriticOutput` and the result of storing the new hidden
    state under `"rnn"` in `memory`. `extras` holds
    `"auxiliary_distributions"` only when `self.include_auxiliary_head` is
    truthy; otherwise it is empty.
    """
    encoded = self.goal_visual_encoder(observations)
    beliefs, new_hidden_states = self.state_encoder(
        encoded, memory.tensor("rnn"), masks
    )

    action_dists = self.actor(beliefs)
    state_values = self.critic(beliefs)
    if self.include_auxiliary_head:
        extras = {"auxiliary_distributions": self.auxiliary_actor(beliefs)}
    else:
        extras = {}

    output = ActorCriticOutput(
        distributions=action_dists, values=state_values, extras=extras,
    )
    return output, memory.set_tensor("rnn", new_hidden_states)
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Compute the policy and value for a batch of observations.

    The object-type encoding is concatenated (after the visual embedding,
    when the model is not blind) along the last dimension and fed through
    the recurrent state encoder; its output drives the actor and critic.

    # Parameters

    observations : Batched input observations.
    memory : `Memory` whose `"rnn"` tensor holds the recurrent hidden state.
    prev_actions : Tensor of previous actions taken (not used here).
    masks : Masks forwarded to `state_encoder`. See `RNNStateEncoder`.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the result of
    storing the new hidden state under `"rnn"` in `memory`.
    """
    features = [
        self.get_object_type_encoding(
            cast(Dict[str, torch.FloatTensor], observations)
        )
    ]

    if not self.is_blind:
        # Visual embedding goes first, the target encoding stays last.
        features.insert(0, self.visual_encoder(observations))

    encoder_input = torch.cat(features, dim=-1)  # type: ignore
    beliefs, new_hidden_states = self.state_encoder(
        encoder_input, memory.tensor("rnn"), masks
    )

    output = ActorCriticOutput(
        distributions=self.actor(beliefs),
        values=self.critic(beliefs),
        extras={},
    )
    return output, memory.set_tensor("rnn", new_hidden_states)
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Concatenate the current and unshuffled images along the last dimension,
    encode them, run the recurrent state encoder and produce actor/critic
    outputs.

    # Parameters

    observations : Mapping from which `self.rgb_uuid` and
        `self.unshuffled_rgb_uuid` are read.
    memory : `Memory` whose `"rnn"` tensor holds the recurrent hidden state.
    prev_actions : Not used by this method.
    masks : Forwarded unchanged to `state_encoder`.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the result of
    storing the new hidden state under `"rnn"` in `memory`.
    """
    image_pair = torch.cat(
        (observations[self.rgb_uuid], observations[self.unshuffled_rgb_uuid]),
        dim=-1,
    )

    # The visual encoder is handed a one-entry mapping keyed by
    # `self.concat_rgb_uuid`.
    encoded = self.visual_encoder({self.concat_rgb_uuid: image_pair})
    beliefs, new_hidden_states = self.state_encoder(
        encoded, memory.tensor("rnn"), masks
    )

    return (
        ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras={},
        ),
        memory.set_tensor("rnn", new_hidden_states),
    )
def adapt_result(ac_output, hidden_states, num_steps, num_samplers, num_agents, num_layers, observations):  # type: ignore
    """Reshape flattened model outputs back to step/sampler-major form.

    # Parameters

    ac_output : Object exposing `distributions.logits`, `values` and `extras`.
    hidden_states : Tensor viewed as
        `(num_layers, num_samplers * num_agents, -1)`.
    num_steps, num_samplers, num_agents, num_layers : Target view sizes.
    observations : Possibly nested mapping of tensors. The entries
        `"minigrid_ego_image"` and `"minigrid_mission"` are re-viewed
        **in place** to `(num_steps, num_samplers * num_agents, ...)`.

    # Returns

    Tuple of a new `ActorCriticOutput` and the reshaped hidden states. Note
    that `ac_output.extras` is reused (and, when it holds
    `"auxiliary_distributions"`, updated in place) rather than copied.
    """
    distributions = CategoricalDistr(
        logits=ac_output.distributions.logits.view(num_steps, num_samplers, -1),
    )
    values = ac_output.values.view(num_steps, num_samplers, num_agents)
    extras = ac_output.extras  # ignore shape
    # TODO confirm the shape of the auxiliary distribution is the same as the actor's
    if "auxiliary_distributions" in extras:
        extras["auxiliary_distributions"] = CategoricalDistr(
            logits=extras["auxiliary_distributions"].logits.view(
                num_steps, num_samplers, -1  # assume single-agent
            ),
        )

    hidden_states = hidden_states.view(num_layers, num_samplers * num_agents, -1)

    # Unflatten all observation batch dims
    def recursively_adapt_observations(obs):
        for entry in obs:
            value = obs[entry]
            # Builtin `dict` rather than the `typing.Dict` alias for the
            # runtime type check.
            if isinstance(value, dict):
                recursively_adapt_observations(value)
                continue

            assert isinstance(value, torch.Tensor)
            if entry in ["minigrid_ego_image", "minigrid_mission"]:
                final_dims = value.shape[1:]  # assumes no agents dim in observations!
                obs[entry] = value.view(
                    num_steps, num_samplers * num_agents, *final_dims
                )

    recursively_adapt_observations(observations)

    return (
        ActorCriticOutput(distributions=distributions, values=values, extras=extras),
        hidden_states,
    )
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Apply the linear layer to `observations[self.key]` and split its output
    into main logits, auxiliary logits and a value.

    A 3-dimensional linear output gets an extra dimension inserted before the
    last one, so downstream slicing always sees 4 dimensions.

    # Parameters

    observations : Mapping from which `observations[self.key]` is read.
    memory : Not used by this method.
    prev_actions : Not used by this method.
    masks : Not used by this method.

    # Returns

    Tuple of the `ActorCriticOutput` (auxiliary distribution stored in
    `extras["auxiliary_distributions"]`) and `None` in place of a memory.
    """
    linear_out = self.linear(cast(torch.Tensor, observations[self.key]))

    rank = len(linear_out.shape)
    assert rank in [
        3,
        4,
    ], "observations must be [step, sampler, data] or [step, sampler, agent, data]"

    if rank == 3:
        # [step, sampler, data] -> [step, sampler, agent, data]
        linear_out = linear_out.unsqueeze(-2)

    # Last-dimension layout: [main logits | auxiliary logits | value].
    n_main = self.num_actions
    policy_logits = linear_out[..., :n_main]
    auxiliary_logits = linear_out[..., n_main:-1]
    state_values = linear_out[..., -1:]

    # Keep the first two dims and flatten whatever follows.
    flat_values = state_values.view(state_values.shape[:2] + (-1,))

    # noinspection PyArgumentList
    output = ActorCriticOutput(
        distributions=cast(DistributionType, CategoricalDistr(logits=policy_logits)),
        values=cast(torch.FloatTensor, flat_values),
        extras={
            "auxiliary_distributions": CategoricalDistr(logits=auxiliary_logits),
        },
    )
    return output, None
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Fuse current and unshuffled image features with spatial attention, run
    the recurrent state encoder and produce actor/critic outputs.

    # Parameters

    observations : Mapping from which `self.rgb_uuid` and
        `self.unshuffled_rgb_uuid` are read.
    memory : `Memory` whose `"rnn"` tensor holds the recurrent hidden state.
    prev_actions : Not used by this method.
    masks : Forwarded unchanged to `state_encoder`.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the result of
    storing the new hidden state under `"rnn"` in `memory`.
    """
    current_feats = observations[self.rgb_uuid]
    goal_feats = observations[self.unshuffled_rgb_uuid]

    # Stack current, goal and their elementwise product along dim -3.
    stacked = torch.cat(
        (current_feats, goal_feats, current_feats * goal_feats), dim=-3,
    )

    # Collapse every leading dim into one so the modules see a flat batch.
    lead_shape = stacked.shape[:-3]
    flat = stacked.view(-1, *stacked.shape[-3:])
    n_flat = flat.shape[0]

    # Softmax over all attention outputs of each item, then view back to a
    # single-channel map matching the last two dims of the input.
    attention_logits = self.visual_attention(flat).view(n_flat, -1)
    attention_probs = torch.softmax(attention_logits, dim=-1).view(
        n_flat, 1, *flat.shape[-2:]
    )

    weighted = self.visual_encoder(flat) * attention_probs
    pooled = weighted.mean(-1).mean(-1).view(*lead_shape, -1)

    beliefs, new_hidden_states = self.state_encoder(
        pooled, memory.tensor("rnn"), masks
    )

    return (
        ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras={},
        ),
        memory.set_tensor("rnn", new_hidden_states),
    )
def forward(
    self,
    observations: ObservationType,
    recurrent_hidden_states: torch.FloatTensor,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
):
    """Compute actor/critic outputs for a flattened batch of rollout steps.

    Inputs are first normalised by `adapt_inputs`. When `self.lang_model` is
    not `"gru"`, the work is delegated to `forward_loop`. Otherwise images
    are embedded in one batch and, if `self.use_memory` is set, the memory
    RNN is unrolled over blocks of consecutive time steps that contain no
    mask reset.

    # Parameters

    observations : Mapping read for `"minigrid_ego_image"` and, when
        `self.use_instr` is set, `"minigrid_mission"`.
    recurrent_hidden_states : Hidden state; after `adapt_inputs` its first
        dimension must be 1 on the `"gru"` path.
    prev_actions : Passed through `adapt_inputs` (and on to `forward_loop`);
        otherwise unused here.
    masks : Entries different from 1.0 mark where instruction embeddings are
        recomputed and where the hidden state is multiplied by the mask.

    # Returns

    Whatever `adapt_result` returns for the computed `ActorCriticOutput` and
    hidden states.

    # Raises

    NotImplementedError : If `self.use_instr` is set and `self.lang_model`
        is `"attgru"`. NOTE(review): this looks unreachable, since any
        `lang_model` other than `"gru"` returns early above — confirm.
    """
    (
        observations,
        recurrent_hidden_states,
        prev_actions,
        masks,
        num_steps,
        num_samplers,
        num_agents,
        num_layers,
    ) = self.adapt_inputs(observations, recurrent_hidden_states, prev_actions, masks)

    if self.lang_model != "gru":
        ac_output, hidden_states = self.forward_loop(
            observations=observations,
            recurrent_hidden_states=recurrent_hidden_states,
            prev_actions=prev_actions,
            masks=masks,  # type: ignore
        )

        # Only the last entry of the per-step hidden states is kept.
        return self.adapt_result(
            ac_output,
            hidden_states[-1:],
            num_steps,
            num_samplers,
            num_agents,
            num_layers,
            observations,
        )

    assert recurrent_hidden_states.shape[0] == 1

    images = cast(torch.FloatTensor, observations["minigrid_ego_image"])
    if self.use_cnn2:
        images_shape = images.shape
        # Offsets 0/11/22 are added to the three entries of the last dim
        # before the embedding lookup — presumably so each channel indexes a
        # disjoint range of the embedding table; TODO confirm.
        # noinspection PyArgumentList
        images = images + torch.LongTensor(
            [0, 11, 22]).view(  # type:ignore
                1, 1, 1, 3).to(images.device)
        images = self.semantic_embedding(images).view(  # type:ignore
            *images_shape[:3], 24)
    # Move the last dim to position 1 (channels-first for `image_conv`,
    # presumably).
    images = images.permute(0, 3, 1, 2).float()  # type:ignore

    _, nsamplers, _ = recurrent_hidden_states.shape
    # The flat batch dim of `images` is split as (rollouts_len, nsamplers).
    rollouts_len = images.shape[0] // nsamplers

    masks = cast(torch.FloatTensor,
                 masks.view(rollouts_len, nsamplers, *masks.shape[1:]))
    instrs: Optional[torch.Tensor] = None
    if "minigrid_mission" in observations and self.use_instr:
        instrs = cast(torch.FloatTensor, observations["minigrid_mission"])
        instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

    # True wherever the mask differs from 1.0; the first time step is always
    # marked. `masks != 1.0` is a new tensor, so `masks` itself is untouched.
    needs_instr_reset_mask = masks != 1.0
    needs_instr_reset_mask[0] = 1
    needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)
    # Time indices at which at least one sampler is marked, plus a final
    # sentinel so consecutive pairs delimit blocks.
    blocking_inds: List[int] = np.where(
        needs_instr_reset_mask.view(rollouts_len, -1).any(-1).cpu().numpy())[0].tolist()
    blocking_inds.append(rollouts_len)

    instr_embeddings: Optional[torch.Tensor] = None
    if self.use_instr:
        # (time, sampler) pairs that are marked, in `np.where` order.
        instr_reset_multi_inds = list((int(a), int(b)) for a, b in zip(
            *np.where(needs_instr_reset_mask.cpu().numpy())))
        time_ind_to_which_need_instr_reset: List[List] = [
            [] for _ in range(rollouts_len)
        ]
        # Position of each (time, sampler) pair within
        # `unique_instr_embeddings`.
        reset_multi_ind_to_index = {
            mi: i for i, mi in enumerate(instr_reset_multi_inds)
        }
        for a, b in instr_reset_multi_inds:
            time_ind_to_which_need_instr_reset[a].append(b)

        # Embed only the marked instructions in a single call.
        unique_instr_embeddings = self._get_instr_embedding(
            instrs[needs_instr_reset_mask])

        # All samplers are marked at time 0, so the first `nsamplers` rows
        # are the embeddings for the first step.
        instr_embeddings_list = [unique_instr_embeddings[:nsamplers]]
        current_instr_embeddings_list = list(instr_embeddings_list[-1])

        for time_ind in range(1, rollouts_len):
            if len(time_ind_to_which_need_instr_reset[time_ind]) == 0:
                # Nothing marked at this step: reuse the previous embeddings.
                instr_embeddings_list.append(instr_embeddings_list[-1])
            else:
                for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
                        time_ind]:
                    current_instr_embeddings_list[
                        sampler_needing_reset_ind] = unique_instr_embeddings[
                            reset_multi_ind_to_index[(
                                time_ind, sampler_needing_reset_ind)]]

                instr_embeddings_list.append(
                    torch.stack(current_instr_embeddings_list, dim=0))

        instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

    # The following code can be used to compute the instr_embeddings in another way
    # and thus verify that the above logic is (more likely to be) correct
    # needs_instr_reset_mask = (masks != 1.0)
    # needs_instr_reset_mask[0] *= 0
    # needs_instr_reset_inds = needs_instr_reset_mask.view(nrollouts, -1).any(-1).cpu().numpy()
    #
    # # Get inds where a new task has started
    # blocking_inds: List[int] = np.where(needs_instr_reset_inds)[0].tolist()
    # blocking_inds.append(needs_instr_reset_inds.shape[0])
    # if nrollouts != 1:
    #     pdb.set_trace()
    # if blocking_inds[0] != 0:
    #     blocking_inds.insert(0, 0)
    # if self.use_instr:
    #     instr_embeddings_list = []
    #     for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
    #         instr_embeddings_list.append(
    #             self._get_instr_embedding(instrs[ind0])
    #             .unsqueeze(0)
    #             .repeat(ind1 - ind0, 1, 1)
    #         )
    #     tmp_instr_embeddings = torch.cat(instr_embeddings_list, dim=0)
    # assert (instr_embeddings - tmp_instr_embeddings).abs().max().item() < 1e-6

    # Embed images
    # images = images.view(nrollouts, nsamplers, *images.shape[1:])
    image_embeddings = self.image_conv(images)
    if self.arch.startswith("expert_filmcnn"):
        # Merge the first two dims of the instruction embeddings before
        # handing them to each controller.
        instr_embeddings_flatter = instr_embeddings.view(
            -1, *instr_embeddings.shape[2:])
        for controller in self.controllers:
            image_embeddings = controller(image_embeddings,
                                          instr_embeddings_flatter)
        image_embeddings = F.relu(self.film_pool(image_embeddings))

    image_embeddings = image_embeddings.view(rollouts_len, nsamplers, -1)

    if self.use_instr and self.lang_model == "attgru":
        raise NotImplementedError("Currently attgru is not implemented.")

    memory = None
    if self.use_memory:
        assert recurrent_hidden_states.shape[0] == 1
        # The stored state is split along the last dim into the two tensors
        # passed to `memory_rnn` as its hidden state.
        hidden = (
            recurrent_hidden_states[:, :, :self.semi_memory_size],
            recurrent_hidden_states[:, :, self.semi_memory_size:],
        )
        embeddings_list = []
        # Unroll the RNN one block at a time; the hidden state is multiplied
        # by the mask of the block's first step before each call.
        for ind0, ind1 in zip(blocking_inds[:-1], blocking_inds[1:]):
            hidden = (hidden[0] * masks[ind0], hidden[1] * masks[ind0])
            rnn_out, hidden = self.memory_rnn(image_embeddings[ind0:ind1],
                                              hidden)
            embeddings_list.append(rnn_out)

        # embedding = hidden[0]
        embedding = torch.cat(embeddings_list, dim=0)
        memory = torch.cat(hidden, dim=-1)
    else:
        embedding = image_embeddings

    if self.use_instr and not "filmcnn" in self.arch:
        embedding = torch.cat((embedding, instr_embeddings), dim=-1)

    if hasattr(self, "aux_info") and self.aux_info:
        extra_predictions = {
            info: self.extra_heads[info](embedding)
            for info in self.extra_heads
        }
    else:
        extra_predictions = dict()

    # Flatten (rollouts_len, nsamplers) back into one batch dim for the heads.
    embedding = embedding.view(rollouts_len * nsamplers, -1)

    ac_output = ActorCriticOutput(
        distributions=CategoricalDistr(logits=self.actor(embedding), ),
        values=self.critic(embedding),
        extras=extra_predictions if not self.include_auxiliary_head else {
            **extra_predictions,
            "auxiliary_distributions":
            CategoricalDistr(logits=self.aux(embedding)),
        },
    )
    # `memory` is still None here when `self.use_memory` is falsy.
    hidden_states = memory

    return self.adapt_result(
        ac_output,
        hidden_states,
        num_steps,
        num_samplers,
        num_agents,
        num_layers,
        observations,
    )
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Processes input batched observations to produce new actor and critic
    values.

    Observation and previous-action embeddings are concatenated and fed to
    every encoder in `self.state_encoders`; the resulting beliefs are fused
    by `fuse_beliefs` and drive the actor and critic.

    # Parameters

    observations : Batched input observations.
    memory : `Memory` holding one hidden-state tensor per key of
        `self.state_encoders`.
    prev_actions : Tensor of previous actions taken.
    masks : Masks forwarded to each state encoder. See `RNNStateEncoder`.

    # Returns

    Tuple of the `ActorCriticOutput` and the updated memory. `extras` has one
    entry per uuid in `self.auxiliary_uuids` (keys `"beliefs"`,
    `"obs_embeds"`, `"aux_model"`) and, when `self.multiple_beliefs` is set,
    the task weights under `MultiAuxTaskNegEntropyLoss.UUID`.
    """
    # 1.1 use perception model (i.e. encoder) to get observation embeddings
    obs_embeds = self.forward_encoder(observations)

    # 1.2 use embedding model to get prev_action embeddings
    prev_actions_embeds = self.prev_action_embedder(prev_actions)
    joint_embeds = torch.cat((obs_embeds, prev_actions_embeds), dim=-1)  # (T, N, *)

    # 2. use RNNs to get single/multiple beliefs
    beliefs_dict = {}
    for key, model in self.state_encoders.items():
        beliefs_dict[key], rnn_hidden_states = model(
            joint_embeds, memory.tensor(key), masks
        )
        # Use the returned memory, as the other models do, instead of
        # relying on `set_tensor` mutating its receiver.
        memory = memory.set_tensor(key, rnn_hidden_states)

    # 3. fuse beliefs for multiple belief models
    beliefs, task_weights = self.fuse_beliefs(beliefs_dict, obs_embeds)  # fused beliefs

    # 4. prepare output
    extras = {}
    if self.auxiliary_uuids is not None:
        extras = {
            aux_uuid: {
                # Each auxiliary task sees its own belief when there are
                # several, otherwise the fused one.
                "beliefs": (
                    beliefs_dict[aux_uuid] if self.multiple_beliefs else beliefs
                ),
                "obs_embeds": obs_embeds,
                "aux_model": (
                    self.aux_models[aux_uuid]
                    if aux_uuid in self.aux_models
                    else None
                ),
            }
            for aux_uuid in self.auxiliary_uuids
        }

    if self.multiple_beliefs:
        extras[MultiAuxTaskNegEntropyLoss.UUID] = task_weights

    actor_critic_output = ActorCriticOutput(
        distributions=self.actor(beliefs),
        values=self.critic(beliefs),
        extras=extras,
    )

    return actor_critic_output, memory
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Two-phase (walkthrough / unshuffle) forward pass with attention-pooled
    visual features.

    The state encoder is unrolled one step at a time because each step's
    input contains the walkthrough encoding produced by the previous steps.

    # Parameters

    observations : Mapping from which `self.in_walkthrough_phase_uuid`,
        `self.rgb_uuid` and `self.unshuffled_rgb_uuid` are read.
    memory : `Memory` with tensors `"rnn"` and `"walkthrough_encoding"`.
    prev_actions : Previous actions, embedded via `prev_action_embedder`.
    masks : Hidden-state masks; also combined with the phase flag below.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the memory
    with both `"walkthrough_encoding"` and `"rnn"` updated.
    """
    walkthrough_flags = observations[self.in_walkthrough_phase_uuid]
    unshuffle_flags = ~walkthrough_flags
    walkthrough_weight = walkthrough_flags.float()
    unshuffle_weight = unshuffle_flags.float()

    # Don't reset hidden state at start of the unshuffle task
    masks_no_unshuffle_reset = (masks.bool() | unshuffle_flags).float()
    masks_with_unshuffle_reset = masks.float()
    del masks  # Just to make sure we don't accidentally use `masks when we want `masks_no_unshuffle_reset`

    # Visual features: stack current, goal and their product along dim -3.
    current_feats = observations[self.rgb_uuid]
    goal_feats = observations[self.unshuffled_rgb_uuid]
    stacked = torch.cat(
        (current_feats, goal_feats, current_feats * goal_feats), dim=-3,
    )

    lead_shape = stacked.shape[:-3]
    flat = stacked.view(-1, *stacked.shape[-3:])
    n_flat = flat.shape[0]

    attention_logits = self.visual_attention(flat).view(n_flat, -1)
    attention_probs = torch.softmax(attention_logits, dim=-1).view(
        n_flat, 1, *flat.shape[-2:]
    )

    weighted = self.visual_encoder(flat) * attention_probs
    vis_features = weighted.mean(-1).mean(-1).view(*lead_shape, -1)

    # Various embeddings
    # The index is `prev_action + 1` where the mask is 0 and 0 elsewhere.
    # NOTE(review): confirm this polarity is the intended one.
    prev_action_indices = (~masks_with_unshuffle_reset.bool()).long() * (
        prev_actions.unsqueeze(-1) + 1
    )
    prev_action_embeddings = self.prev_action_embedder(
        prev_action_indices
    ).squeeze(-2)
    phase_embedding = self.is_walkthrough_phase_embedder(
        walkthrough_flags.long()
    ).squeeze(-2)

    rnn_hidden_states = memory.tensor("rnn")
    obs_for_rnn = torch.cat(
        [vis_features, prev_action_embeddings, phase_embedding], dim=-1
    )
    last_walkthrough_encoding = memory.tensor("walkthrough_encoding")

    per_step_outputs = []
    for step in range(masks_with_unshuffle_reset.shape[0]):
        window = slice(step, step + 1)

        # The walkthrough encoding fed to the state encoder is multiplied by
        # the no-unshuffle-reset mask for this step.
        step_input = torch.cat(
            (
                obs_for_rnn[window],
                last_walkthrough_encoding * masks_no_unshuffle_reset[window],
            ),
            dim=-1,
        )
        rnn_out, rnn_hidden_states = self.state_encoder(
            step_input, rnn_hidden_states, masks_with_unshuffle_reset[window],
        )
        per_step_outputs.append(rnn_out)

        walkthrough_encoding, _ = self.walkthrough_encoder(
            rnn_out, last_walkthrough_encoding, masks_no_unshuffle_reset[window],
        )
        # Keep the old encoding in the unshuffle phase, take the new one in
        # the walkthrough phase.
        last_walkthrough_encoding = (
            last_walkthrough_encoding * unshuffle_weight[window]
            + walkthrough_encoding * walkthrough_weight[window]
        )

    memory = memory.set_tensor("walkthrough_encoding", last_walkthrough_encoding)

    rnn_out = torch.cat(per_step_outputs, dim=0)
    walkthrough_dist, walkthrough_vals = self.walkthrough_ac(rnn_out)
    unshuffle_dist, unshuffle_vals = self.unshuffle_ac(rnn_out)

    logit_rank = len(walkthrough_dist.logits.shape)
    assert len(walkthrough_weight.shape) == logit_rank

    walkthrough_logits = walkthrough_dist.logits
    if self.walkthrough_good_action_logits is not None:
        # View the bias with leading singleton dims so it broadcasts over
        # every dim but the last.
        walkthrough_logits = (
            walkthrough_logits
            + self.walkthrough_good_action_logits.view(
                *((1,) * (logit_rank - 1)), -1
            )
        )

    # Select per element between the two heads using the phase flags.
    actor = CategoricalDistr(
        logits=walkthrough_weight * walkthrough_logits
        + unshuffle_weight * unshuffle_dist.logits
    )
    values = walkthrough_weight * walkthrough_vals + unshuffle_weight * unshuffle_vals

    ac_output = ActorCriticOutput(distributions=actor, values=values, extras={})
    return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
def forward(  # type:ignore
    self,
    observations: ObservationType,
    memory: Memory,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
    """Two-phase (walkthrough / unshuffle) forward pass on concatenated
    images.

    The state encoder is unrolled one step at a time because each step's
    input contains the walkthrough encoding produced by the previous steps.

    # Parameters

    observations : Mapping from which `self.in_walkthrough_phase_uuid`,
        `self.rgb_uuid` and `self.unshuffled_rgb_uuid` are read.
    memory : `Memory` with tensors `"rnn"` and `"walkthrough_encoding"`.
    prev_actions : Previous actions, embedded via `prev_action_embedder`.
    masks : Hidden-state masks; also combined with the phase flag below.

    # Returns

    Tuple of the `ActorCriticOutput` (with empty `extras`) and the memory
    with both `"walkthrough_encoding"` and `"rnn"` updated.
    """
    walkthrough_flags = observations[self.in_walkthrough_phase_uuid]
    unshuffle_flags = ~walkthrough_flags
    walkthrough_weight = walkthrough_flags.float()
    unshuffle_weight = unshuffle_flags.float()

    # Don't reset hidden state at start of the unshuffle task
    masks_no_unshuffle_reset = (masks.bool() | unshuffle_flags).float()

    image_pair = torch.cat(
        (observations[self.rgb_uuid], observations[self.unshuffled_rgb_uuid]),
        dim=-1,
    )

    # Various embeddings
    vis_features = self.visual_encoder({self.concat_rgb_uuid: image_pair})

    # The index is `prev_action + 1` where the mask is 0 and 0 elsewhere.
    # NOTE(review): confirm this polarity is the intended one.
    prev_action_indices = (~masks.bool()).long() * (prev_actions.unsqueeze(-1) + 1)
    prev_action_embeddings = self.prev_action_embedder(
        prev_action_indices
    ).squeeze(-2)
    phase_embedding = self.is_walkthrough_phase_embedder(
        walkthrough_flags.long()
    ).squeeze(-2)

    rnn_hidden_states = memory.tensor("rnn")
    obs_for_rnn = torch.cat(
        [vis_features, prev_action_embeddings, phase_embedding], dim=-1
    )
    last_walkthrough_encoding = memory.tensor("walkthrough_encoding")

    per_step_outputs = []
    for step in range(masks.shape[0]):
        window = slice(step, step + 1)

        step_input = torch.cat(
            (obs_for_rnn[window], last_walkthrough_encoding), dim=-1
        )
        rnn_out, rnn_hidden_states = self.state_encoder(
            step_input, rnn_hidden_states, masks[window],
        )
        per_step_outputs.append(rnn_out)

        walkthrough_encoding, _ = self.walkthrough_encoder(
            rnn_out, last_walkthrough_encoding, masks_no_unshuffle_reset[window],
        )
        # Keep the old encoding in the unshuffle phase, take the new one in
        # the walkthrough phase.
        last_walkthrough_encoding = (
            last_walkthrough_encoding * unshuffle_weight[window]
            + walkthrough_encoding * walkthrough_weight[window]
        )

    memory = memory.set_tensor("walkthrough_encoding", last_walkthrough_encoding)

    rnn_out = torch.cat(per_step_outputs, dim=0)
    walkthrough_dist, walkthrough_vals = self.walkthrough_ac(rnn_out)
    unshuffle_dist, unshuffle_vals = self.unshuffle_ac(rnn_out)

    logit_rank = len(walkthrough_dist.logits.shape)
    assert len(walkthrough_weight.shape) == logit_rank

    walkthrough_logits = walkthrough_dist.logits
    if self.walkthrough_good_action_logits is not None:
        # View the bias with leading singleton dims so it broadcasts over
        # every dim but the last.
        walkthrough_logits = (
            walkthrough_logits
            + self.walkthrough_good_action_logits.view(
                *((1,) * (logit_rank - 1)), -1
            )
        )

    # Select per element between the two heads using the phase flags.
    actor = CategoricalDistr(
        logits=walkthrough_weight * walkthrough_logits
        + unshuffle_weight * unshuffle_dist.logits
    )
    values = walkthrough_weight * walkthrough_vals + unshuffle_weight * unshuffle_vals

    ac_output = ActorCriticOutput(distributions=actor, values=values, extras={})
    return ac_output, memory.set_tensor("rnn", rnn_hidden_states)
def forward_loop(
    self,
    observations: ObservationType,
    recurrent_hidden_states: torch.FloatTensor,
    prev_actions: torch.Tensor,
    masks: torch.FloatTensor,
):
    """Step-by-step forward pass that calls `forward_once` for each time step.

    Instruction embeddings are computed once for every (time, sampler) pair
    whose mask differs from 1.0 (the first time step is always included) and
    carried forward unchanged until the next such pair for that sampler.

    # Parameters

    observations : Mapping read for `"minigrid_ego_image"` and, if present,
        `"minigrid_mission"`.
    recurrent_hidden_states : Hidden state whose first dimension must be 1
        and whose second dimension gives the number of samplers.
    prev_actions : Not used by this method.
    masks : Viewed as `(rollouts_len, nsamplers, ...)`; multiplies the memory
        before each step.

    # Returns

    Tuple of the `ActorCriticOutput` built from the concatenated per-step
    embeddings and a tensor stacking every step's memory along a new first
    dimension.
    """
    results = []
    images = cast(torch.FloatTensor, observations["minigrid_ego_image"]).float()
    instrs: Optional[torch.Tensor] = None
    if "minigrid_mission" in observations:
        instrs = cast(torch.Tensor, observations["minigrid_mission"])

    _, nsamplers, _ = recurrent_hidden_states.shape
    # The flat batch dim of `images` is split as (rollouts_len, nsamplers).
    rollouts_len = images.shape[0] // nsamplers

    obs = babyai.rl.DictList()

    images = images.view(rollouts_len, nsamplers, *images.shape[1:])
    masks = masks.view(rollouts_len, nsamplers, *masks.shape[1:])  # type:ignore

    # needs_reset = (masks != 1.0).view(nrollouts, -1).any(-1)
    if instrs is not None:
        instrs = instrs.view(rollouts_len, nsamplers, instrs.shape[-1])

    # `masks != 1.0` is a new tensor, so marking the first step does not
    # modify `masks`.
    needs_instr_reset_mask = masks != 1.0
    needs_instr_reset_mask[0] = 1
    needs_instr_reset_mask = needs_instr_reset_mask.squeeze(-1)

    instr_embeddings: Optional[torch.Tensor] = None
    if self.use_instr:
        # (time, sampler) pairs that are marked, in `np.where` order.
        instr_reset_multi_inds = [
            (int(a), int(b))
            for a, b in zip(*np.where(needs_instr_reset_mask.cpu().numpy()))
        ]
        time_ind_to_which_need_instr_reset: List[List] = [
            [] for _ in range(rollouts_len)
        ]
        # Position of each (time, sampler) pair within
        # `unique_instr_embeddings`.
        reset_multi_ind_to_index = {
            mi: i for i, mi in enumerate(instr_reset_multi_inds)
        }
        for a, b in instr_reset_multi_inds:
            time_ind_to_which_need_instr_reset[a].append(b)

        # Embed only the marked instructions in a single call.
        unique_instr_embeddings = self._get_instr_embedding(
            instrs[needs_instr_reset_mask]
        )

        # All samplers are marked at time 0, so the first `nsamplers` rows
        # are the embeddings for the first step.
        instr_embeddings_list = [unique_instr_embeddings[:nsamplers]]
        current_instr_embeddings_list = list(instr_embeddings_list[-1])

        for time_ind in range(1, rollouts_len):
            samplers_needing_reset = time_ind_to_which_need_instr_reset[time_ind]
            if not samplers_needing_reset:
                # Nothing marked at this step: reuse the previous embeddings.
                instr_embeddings_list.append(instr_embeddings_list[-1])
                continue

            for sampler_needing_reset_ind in samplers_needing_reset:
                current_instr_embeddings_list[
                    sampler_needing_reset_ind
                ] = unique_instr_embeddings[
                    reset_multi_ind_to_index[(time_ind, sampler_needing_reset_ind)]
                ]

            instr_embeddings_list.append(
                torch.stack(current_instr_embeddings_list, dim=0)
            )

        instr_embeddings = torch.stack(instr_embeddings_list, dim=0)

    assert recurrent_hidden_states.shape[0] == 1
    memory = recurrent_hidden_states[0]
    # instr_embedding: Optional[torch.Tensor] = None
    for i in range(rollouts_len):
        obs.image = images[i]
        if "minigrid_mission" in observations:
            obs.instr = instrs[i]

        # reset = needs_reset[i].item()
        # if self.baby_ai_model.use_instr and (reset or i == 0):
        #     instr_embedding = self.baby_ai_model._get_instr_embedding(obs.instr)

        # `instr_embeddings` stays None when `self.use_instr` is falsy;
        # indexing it would raise, so pass None through instead.
        step_instr_embedding = (
            None if instr_embeddings is None else instr_embeddings[i]
        )
        results.append(
            self.forward_once(
                obs,
                memory=memory * masks[i],
                instr_embedding=step_instr_embedding,
            )
        )
        memory = results[-1]["memory"]

    embedding = torch.cat([r["embedding"] for r in results], dim=0)

    extra_predictions_list = [r["extra_predictions"] for r in results]
    extra_predictions = {
        key: torch.cat([ep[key] for ep in extra_predictions_list], dim=0)
        for key in extra_predictions_list[0]
    }
    return (
        ActorCriticOutput(
            distributions=CategoricalDistr(logits=self.actor(embedding), ),
            values=self.critic(embedding),
            extras=extra_predictions if not self.include_auxiliary_head else {
                **extra_predictions,
                "auxiliary_distributions": cast(
                    Any, CategoricalDistr(logits=self.aux(embedding))
                ),
            },
        ),
        torch.stack([r["memory"] for r in results], dim=0),
    )