def forward(self, input: rlt.PreprocessedRankingInput):
    input = self._convert_seq2slate_to_reward_model_format(input)
    state, src_seq, tgt_in_seq, src_src_mask, tgt_tgt_mask = (
        input.state.float_features,
        input.src_seq.float_features,
        input.tgt_in_seq.float_features,
        input.src_src_mask,
        input.tgt_tgt_mask,
    )

    # encoder_output shape: batch_size, src_seq_len, dim_model
    encoder_output = self.encode(state, src_seq, src_src_mask)

    batch_size, tgt_seq_len, _ = tgt_in_seq.shape

    # tgt_src_mask shape: batch_size, tgt_seq_len, src_seq_len
    tgt_src_mask = torch.ones(
        batch_size, tgt_seq_len, self.max_src_seq_len, device=src_src_mask.device
    )

    # decoder_output shape: batch_size, tgt_seq_len, dim_model
    decoder_output = self.decode(
        memory=encoder_output,
        state=state,
        tgt_src_mask=tgt_src_mask,
        tgt_in_seq=tgt_in_seq,
        tgt_tgt_mask=tgt_tgt_mask,
        tgt_seq_len=tgt_seq_len,
    )

    # use the decoder's last step embedding to predict the slate reward
    pred_reward = self.proj(decoder_output[:, -1, :])
    return rlt.RewardNetworkOutput(predicted_reward=pred_reward)
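
# A minimal, self-contained sketch of the last-step pooling + projection used above,
# with made-up sizes (batch_size=2, tgt_seq_len=5, dim_model=8). It assumes that
# self.proj is a Linear(dim_model, 1) head, which is not shown in this snippet.
import torch
import torch.nn as nn

batch_size, tgt_seq_len, dim_model = 2, 5, 8
decoder_output = torch.randn(batch_size, tgt_seq_len, dim_model)
proj = nn.Linear(dim_model, 1)

# keep only the embedding of the final decoding step: (batch_size, dim_model)
last_step_embed = decoder_output[:, -1, :]

# predicted slate reward: (batch_size, 1)
pred_reward = proj(last_step_embed)
assert pred_reward.shape == (batch_size, 1)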

def forward(self, input: rlt.PreprocessedRankingInput):
    input = self._convert_seq2slate_to_reward_model_format(input)
    state = input.state.float_features
    tgt_in_seq = input.tgt_in_seq.float_features

    # tgt_in_embed shape: batch_size, src_seq_len + 1, dim_model
    tgt_in_embed = self.embed(state, tgt_in_seq)

    # output shape: batch_size, src_seq_len + 1, dim_model
    output, hn = self.gru(tgt_in_embed)

    # hn shape: batch_size, dim_model
    hn = hn[-1]  # top layer's hidden state

    # attention, using hidden as query, outputs as keys and values
    # attn_weights shape: batch_size, src_seq_len + 1
    attn_weights = F.softmax(
        torch.bmm(
            output,
            hn.unsqueeze(2) / torch.sqrt(torch.tensor(self.candidate_dim).float()),
        ).squeeze(2),
        dim=1,
    )

    # context_vector shape: batch_size, dim_model
    context_vector = torch.bmm(attn_weights.unsqueeze(1), output).squeeze(1)

    # reward prediction depends on hidden state of the last step + context vector
    # seq_embed shape: batch_size, 2 * dim_model
    seq_embed = torch.cat((hn, context_vector), dim=1)

    # pred_reward shape: batch_size, 1
    pred_reward = self.proj(seq_embed)
    return rlt.RewardNetworkOutput(predicted_reward=pred_reward)
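
# A minimal, self-contained sketch of the scaled dot-product attention pooling used
# above, with made-up dimensions (batch_size=2, seq_len=5, dim_model=8, candidate_dim=4);
# the tensors are random stand-ins for the GRU outputs and hidden state, not the
# module's real inputs.
import torch
import torch.nn.functional as F

batch_size, seq_len, dim_model, candidate_dim = 2, 5, 8, 4
output = torch.randn(batch_size, seq_len, dim_model)  # GRU outputs (keys and values)
hn = torch.randn(batch_size, dim_model)               # top layer's last hidden state (query)

# scaled dot-product scores, then softmax over the sequence: (batch_size, seq_len)
scores = torch.bmm(
    output, hn.unsqueeze(2) / torch.sqrt(torch.tensor(candidate_dim).float())
).squeeze(2)
attn_weights = F.softmax(scores, dim=1)

# attention-weighted sum of the GRU outputs: (batch_size, dim_model)
context_vector = torch.bmm(attn_weights.unsqueeze(1), output).squeeze(1)
assert context_vector.shape == (batch_size, dim_model)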