import torch
import torch.nn.functional as F

# AuxLosses is the auxiliary-loss registry provided by the surrounding codebase.


def forward(self, observations, rnn_hidden_states, prev_actions, masks):
    r"""
    instruction_embedding: [batch_size x INSTRUCTION_ENCODER.output_size]
    depth_embedding: [batch_size x DEPTH_ENCODER.output_size]
    rgb_embedding: [batch_size x RGB_ENCODER.output_size]
    """
    # Encode the instruction with BERT in place of the stock instruction encoder.
    # instruction_embedding = self.instruction_encoder(observations)
    instruction_embedding = self._get_bert_embedding(observations)
    depth_embedding = self.depth_encoder(observations)
    rgb_embedding = self.rgb_encoder(observations)

    # Zero out individual modalities for ablation experiments.
    if self.model_config.ablate_instruction:
        instruction_embedding = instruction_embedding * 0
    if self.model_config.ablate_depth:
        depth_embedding = depth_embedding * 0
    if self.model_config.ablate_rgb:
        rgb_embedding = rgb_embedding * 0

    x = torch.cat([instruction_embedding, depth_embedding, rgb_embedding], dim=1)

    if self.model_config.SEQ2SEQ.use_prev_action:
        # Shift action indices by +1 so that 0 means "no previous action";
        # masks zero the entry at the start of a new episode.
        prev_actions_embedding = self.prev_action_embedding(
            ((prev_actions.float() + 1) * masks).long().view(-1)
        )
        x = torch.cat([x, prev_actions_embedding], dim=1)

    x, rnn_hidden_states = self.state_encoder(x, rnn_hidden_states, masks)

    if self.model_config.PROGRESS_MONITOR.use and AuxLosses.is_active():
        # Auxiliary progress-monitor loss: regress the predicted progress
        # (squashed to [-1, 1]) onto the ground-truth episode progress.
        progress_hat = torch.tanh(self.progress_monitor(x))
        progress_loss = F.mse_loss(
            progress_hat.squeeze(1), observations["progress"], reduction="none"
        )
        AuxLosses.register_loss(
            "progress_monitor",
            progress_loss,
            self.model_config.PROGRESS_MONITOR.alpha,
        )

    return x, rnn_hidden_states
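
# `_get_bert_embedding` is called above but not defined in this snippet. Below is
# a minimal sketch of one way it could work, assuming observations["instruction"]
# holds 0-padded token ids and `self.bert` is a HuggingFace `BertModel`; the
# attribute names and the pooling choice are illustrative assumptions, not the
# actual helper.
def _get_bert_embedding(self, observations):
    instruction = observations["instruction"].long()  # [batch_size x seq_len]
    attention_mask = (instruction != 0).long()  # treat token id 0 as padding
    outputs = self.bert(input_ids=instruction, attention_mask=attention_mask)
    # Pooled [CLS] vector as a fixed-size instruction embedding.
    return outputs.pooler_output  # [batch_size x bert_hidden_size]
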
def forward(self, observations, rnn_hidden_states, prev_actions, masks):
    r"""
    instruction_embedding: [batch_size x INSTRUCTION_ENCODER.output_size]
    depth_embedding: [batch_size x DEPTH_ENCODER.output_size]
    rgb_embedding: [batch_size x RGB_ENCODER.output_size]
    """
    instruction_embedding = self.instruction_encoder(observations)

    # Flatten the spatial visual features to [batch_size x channels x (H * W)]
    # so they can serve as key/value sequences for cross-modal attention.
    depth_embedding = self.depth_encoder(observations)
    depth_embedding = torch.flatten(depth_embedding, 2)

    rgb_embedding = self.rgb_encoder(observations)
    rgb_embedding = torch.flatten(rgb_embedding, 2)

    # Shift action indices by +1 so that 0 means "no previous action";
    # masks zero the entry at the start of a new episode.
    prev_actions = self.prev_action_embedding(
        ((prev_actions.float() + 1) * masks).long().view(-1)
    )

    # Zero out individual modalities for ablation experiments.
    if self.model_config.ablate_instruction:
        instruction_embedding = instruction_embedding * 0
    if self.model_config.ablate_depth:
        depth_embedding = depth_embedding * 0
    if self.model_config.ablate_rgb:
        rgb_embedding = rgb_embedding * 0

    # First recurrent state encoder over visual features and the previous action.
    if self.rcm_state_encoder:
        (
            state,
            rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
        ) = self.state_encoder(
            rgb_embedding,
            depth_embedding,
            prev_actions,
            rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
            masks,
        )
    else:
        rgb_in = self.rgb_linear(rgb_embedding)
        depth_in = self.depth_linear(depth_embedding)

        state_in = torch.cat([rgb_in, depth_in, prev_actions], dim=1)
        (
            state,
            rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
        ) = self.state_encoder(
            state_in,
            rnn_hidden_states[0 : self.state_encoder.num_recurrent_layers],
            masks,
        )

    # Attend over the instruction tokens with the recurrent state as the query;
    # padded token positions are all-zero vectors and get masked out.
    text_state_q = self.state_q(state)
    text_state_k = self.text_k(instruction_embedding)
    text_mask = (instruction_embedding == 0.0).all(dim=1)
    text_embedding = self._attn(
        text_state_q, text_state_k, instruction_embedding, text_mask
    )

    # Attend over the visual features with the attended instruction as the query.
    rgb_k, rgb_v = torch.split(
        self.rgb_kv(rgb_embedding), self._hidden_size // 2, dim=1
    )
    depth_k, depth_v = torch.split(
        self.depth_kv(depth_embedding), self._hidden_size // 2, dim=1
    )

    text_q = self.text_q(text_embedding)
    rgb_embedding = self._attn(text_q, rgb_k, rgb_v)
    depth_embedding = self._attn(text_q, depth_k, depth_v)

    x = torch.cat(
        [state, text_embedding, rgb_embedding, depth_embedding, prev_actions],
        dim=1,
    )
    x = self.second_state_compress(x)

    # Second recurrent state encoder over the compressed cross-modal features.
    (
        x,
        rnn_hidden_states[self.state_encoder.num_recurrent_layers :],
    ) = self.second_state_encoder(
        x, rnn_hidden_states[self.state_encoder.num_recurrent_layers :], masks
    )

    if self.model_config.PROGRESS_MONITOR.use and AuxLosses.is_active():
        # Auxiliary progress-monitor loss: regress the predicted progress
        # (squashed to [-1, 1]) onto the ground-truth episode progress.
        progress_hat = torch.tanh(self.progress_monitor(x))
        progress_loss = F.mse_loss(
            progress_hat.squeeze(1), observations["progress"], reduction="none"
        )
        AuxLosses.register_loss(
            "progress_monitor",
            progress_loss,
            self.model_config.PROGRESS_MONITOR.alpha,
        )

    return x, rnn_hidden_states
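
# `_attn` is used above but not defined in this snippet. Below is a minimal
# sketch of single-query scaled dot-product attention consistent with the
# shapes above, assuming q is [batch x dim], k and v are [batch x dim x seq],
# mask (optional) is [batch x seq] with True at padded positions, and
# `self._scale` is an assumed temperature such as 1 / sqrt(dim).
def _attn(self, q, k, v, mask=None):
    # Similarity of the single query vector against every key position.
    logits = torch.einsum("nc,nci->ni", q, k)
    if mask is not None:
        # Drive padded positions toward zero attention weight.
        logits = logits - mask.float() * 1e8
    attn = F.softmax(logits * self._scale, dim=1)
    # Attention-weighted sum of values: one attended vector per batch element.
    return torch.einsum("ni,nci->nc", attn, v)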