Example #1
    def forward(self, input: rlt.PreprocessedRankingInput):
        input = self._convert_seq2slate_to_reward_model_format(input)
        state, src_seq, tgt_in_seq, src_src_mask, tgt_tgt_mask = (
            input.state.float_features,
            input.src_seq.float_features,
            input.tgt_in_seq.float_features,
            input.src_src_mask,
            input.tgt_tgt_mask,
        )
        # encoder_output shape: batch_size, src_seq_len, dim_model
        encoder_output = self.encode(state, src_seq, src_src_mask)

        batch_size, tgt_seq_len, _ = tgt_in_seq.shape
        # tgt_src_mask shape: batch_size, tgt_seq_len, src_seq_len
        tgt_src_mask = torch.ones(batch_size,
                                  tgt_seq_len,
                                  self.max_src_seq_len,
                                  device=src_src_mask.device)

        # decoder_output shape: batch_size, tgt_seq_len, dim_model
        decoder_output = self.decode(
            memory=encoder_output,
            state=state,
            tgt_src_mask=tgt_src_mask,
            tgt_in_seq=tgt_in_seq,
            tgt_tgt_mask=tgt_tgt_mask,
            tgt_seq_len=tgt_seq_len,
        )

        # use the decoder's last step embedding to predict the slate reward
        pred_reward = self.proj(decoder_output[:, -1, :])
        return rlt.RewardNetworkOutput(predicted_reward=pred_reward)
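
The encode/decode calls above belong to the surrounding Seq2Slate-style reward model and are not reproduced here. As a rough, self-contained sketch of only the final pooling-and-projection step, the snippet below uses made-up dimensions and a plain nn.Linear standing in for self.proj; it just illustrates how the decoder's last step embedding is mapped to a scalar reward.

import torch
import torch.nn as nn

# Hypothetical dimensions, chosen only for illustration.
batch_size, tgt_seq_len, dim_model = 4, 5, 8

# Stand-in for the decoder output: batch_size x tgt_seq_len x dim_model.
decoder_output = torch.randn(batch_size, tgt_seq_len, dim_model)

# Stand-in for self.proj: a linear head mapping dim_model -> 1.
proj = nn.Linear(dim_model, 1)

# Keep only the last decoding step, then project it to a scalar reward.
last_step = decoder_output[:, -1, :]   # batch_size x dim_model
pred_reward = proj(last_step)          # batch_size x 1
print(pred_reward.shape)               # torch.Size([4, 1])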
Example #2
    def forward(self, input: rlt.PreprocessedRankingInput):
        input = self._convert_seq2slate_to_reward_model_format(input)
        state = input.state.float_features
        tgt_in_seq = input.tgt_in_seq.float_features

        # shape: batch_size, src_seq_len + 1, dim_model
        tgt_in_embed = self.embed(state, tgt_in_seq)

        # output shape: batch_size, src_seq_len + 1, dim_model
        output, hn = self.gru(tgt_in_embed)
        # hn shape from the GRU: num_layers, batch_size, dim_model
        hn = hn[-1]  # keep the top layer's hidden state: batch_size, dim_model

        # scaled dot-product attention: the last hidden state is the query,
        # the GRU outputs are the keys and values
        # shape: batch_size, src_seq_len + 1
        attn_weights = F.softmax(
            torch.bmm(
                output,
                hn.unsqueeze(2) /
                torch.sqrt(torch.tensor(self.candidate_dim).float()),
            ).squeeze(2),
            dim=1,
        )
        # shape: batch_size, dim_model
        context_vector = torch.bmm(attn_weights.unsqueeze(1),
                                   output).squeeze(1)

        # the reward prediction combines the last step's hidden state with the context vector
        # shape: batch_size, 2 * dim_model
        seq_embed = torch.cat((hn, context_vector), dim=1)

        # shape: batch_size, 1
        pred_reward = self.proj(seq_embed)
        return rlt.RewardNetworkOutput(predicted_reward=pred_reward)
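
To run the attention pooling in isolation, here is a minimal, self-contained sketch with made-up dimensions and plain nn.GRU / nn.Linear stand-ins for self.gru and self.proj; it mirrors the scaled dot-product attention and concatenation above, but it is only an illustrative sketch, not the model's actual wiring.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical dimensions, chosen only for illustration.
batch_size, seq_len, candidate_dim, dim_model = 4, 6, 16, 32

# Stand-ins for self.gru and self.proj.
gru = nn.GRU(candidate_dim, dim_model, num_layers=2, batch_first=True)
proj = nn.Linear(2 * dim_model, 1)

# Dummy embedded target sequence: batch_size x seq_len x candidate_dim.
tgt_in_embed = torch.randn(batch_size, seq_len, candidate_dim)

# output: batch_size x seq_len x dim_model; hn: num_layers x batch_size x dim_model.
output, hn = gru(tgt_in_embed)
hn = hn[-1]  # top layer's last hidden state: batch_size x dim_model

# Scaled dot-product attention: hn is the query, the GRU outputs are keys and values.
scores = torch.bmm(output, hn.unsqueeze(2)).squeeze(2) / candidate_dim ** 0.5
attn_weights = F.softmax(scores, dim=1)                        # batch_size x seq_len
context_vector = torch.bmm(attn_weights.unsqueeze(1), output).squeeze(1)

# Concatenate the last hidden state with the context vector, then project to a reward.
seq_embed = torch.cat((hn, context_vector), dim=1)             # batch_size x 2*dim_model
pred_reward = proj(seq_embed)                                  # batch_size x 1
print(pred_reward.shape)                                       # torch.Size([4, 1])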