Example 1
 def forward(
     self,
     input: rlt.PreprocessedRankingInput,
     mode: str,
     greedy: Optional[bool] = None,
 ):
     # The creation of evaluation data pages only uses these specific arguments
     assert mode in (Seq2SlateMode.RANK_MODE,
                     Seq2SlateMode.PER_SEQ_LOG_PROB_MODE)
     if mode == Seq2SlateMode.RANK_MODE:
         assert greedy
         return rlt.RankingOutput(ranked_tgt_out_idx=torch.tensor(
             [[2, 3], [3, 2], [2, 3]]).long())
     return rlt.RankingOutput(
         log_probs=torch.log(torch.tensor([0.4, 0.3, 0.7])))
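
Example 1 is a test double: it ignores the actual input and returns fixed rankings or log-probabilities depending on the mode. Below is a minimal, self-contained sketch of how such a stub could be exercised on its own, using simplified stand-ins for rlt.RankingOutput, Seq2SlateMode and the fake model (all names in the sketch are illustrative assumptions, not ReAgent's actual API):

    from typing import NamedTuple, Optional

    import torch


    # Simplified stand-ins for rlt.RankingOutput and Seq2SlateMode, for illustration only.
    class RankingOutput(NamedTuple):
        ranked_tgt_out_idx: Optional[torch.Tensor] = None
        log_probs: Optional[torch.Tensor] = None


    RANK_MODE = "rank"
    PER_SEQ_LOG_PROB_MODE = "per_seq_log_prob"


    class FakeRanker:
        """Test double that returns fixed outputs regardless of the input."""

        def forward(self, input, mode: str, greedy: Optional[bool] = None) -> RankingOutput:
            if mode == RANK_MODE:
                assert greedy
                return RankingOutput(
                    ranked_tgt_out_idx=torch.tensor([[2, 3], [3, 2], [2, 3]]).long()
                )
            return RankingOutput(log_probs=torch.log(torch.tensor([0.4, 0.3, 0.7])))


    ranker = FakeRanker()
    ranked = ranker.forward(None, RANK_MODE, greedy=True)
    assert ranked.ranked_tgt_out_idx.shape == (3, 2)
    scored = ranker.forward(None, PER_SEQ_LOG_PROB_MODE)
    assert scored.log_probs.shape == (3,)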
Example 2
 def forward(
     self,
     input: rlt.PreprocessedRankingInput,
     mode: str,
     tgt_seq_len: Optional[int] = None,
     greedy: Optional[bool] = None,
 ):
     res = self.data_parallel(
         input, mode=mode, tgt_seq_len=tgt_seq_len, greedy=greedy
     )
     if mode == RANK_MODE:
         return rlt.RankingOutput(
             ranked_tgt_out_idx=res[1], ranked_tgt_out_probs=res[0]
         )
     elif mode == LOG_PROB_MODE:
         return rlt.RankingOutput(log_probs=res)
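     # Any other mode falls through here and implicitly returns None.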
Example 3
 def forward(
     self,
     input: rlt.PreprocessedRankingInput,
     mode: str,
     tgt_seq_len: Optional[int] = None,
     greedy: Optional[bool] = None,
 ):
     res = self.data_parallel(input,
                              mode=mode,
                              tgt_seq_len=tgt_seq_len,
                              greedy=greedy)
     if mode == Seq2SlateMode.RANK_MODE:
         return rlt.RankingOutput(ranked_tgt_out_idx=res[1],
                                  ranked_tgt_out_probs=res[0])
     elif mode in (
             Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
             Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
     ):
         return rlt.RankingOutput(log_probs=res)
     else:
         raise NotImplementedError()
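
Examples 2 and 3 share a pattern: delegate to an inner network (here via self.data_parallel), then wrap the raw result in an rlt.RankingOutput whose populated fields depend on the requested mode. Below is a self-contained sketch of that dispatch, with simplified stand-ins for the modes, the output type and the inner network (all names in the sketch are illustrative assumptions, not ReAgent's actual API):

    from typing import NamedTuple, Optional

    import torch


    # Simplified stand-ins, for illustration only.
    class RankingOutput(NamedTuple):
        ranked_tgt_out_idx: Optional[torch.Tensor] = None
        ranked_tgt_out_probs: Optional[torch.Tensor] = None
        log_probs: Optional[torch.Tensor] = None


    RANK_MODE = "rank"
    PER_SEQ_LOG_PROB_MODE = "per_seq_log_prob"
    PER_SYMBOL_LOG_PROB_DIST_MODE = "per_symbol_log_prob_dist"


    def inner_net(mode: str, batch_size: int, tgt_seq_len: int):
        # Hypothetical inner network: returns fake outputs of the right shapes.
        if mode == RANK_MODE:
            probs = torch.rand(batch_size, tgt_seq_len)
            idx = torch.arange(tgt_seq_len).repeat(batch_size, 1)
            return probs, idx
        return torch.log(torch.rand(batch_size))


    def forward(mode: str, batch_size: int = 3, tgt_seq_len: int = 2) -> RankingOutput:
        # Dispatch on mode and wrap the raw results, mirroring Examples 2 and 3.
        res = inner_net(mode, batch_size, tgt_seq_len)
        if mode == RANK_MODE:
            return RankingOutput(ranked_tgt_out_idx=res[1], ranked_tgt_out_probs=res[0])
        elif mode in (PER_SYMBOL_LOG_PROB_DIST_MODE, PER_SEQ_LOG_PROB_MODE):
            return RankingOutput(log_probs=res)
        else:
            raise NotImplementedError(f"Unknown mode: {mode}")


    print(forward(RANK_MODE).ranked_tgt_out_idx.shape)     # torch.Size([3, 2])
    print(forward(PER_SEQ_LOG_PROB_MODE).log_probs.shape)  # torch.Size([3])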
Example 4
    def forward(self, input: rlt.PreprocessedRankingInput, mode: str, greedy: bool):
        # The creation of evaluation data pages only uses these specific arguments
        assert greedy and mode == RANK_MODE
        batch_size = input.state.float_features.shape[0]
        ranked_tgt_out_idx = []

        for i in range(batch_size):
            ranked_tgt_out_idx.append(self._forward(input.state.float_features[i]))

        return rlt.RankingOutput(
            ranked_tgt_out_idx=torch.tensor(ranked_tgt_out_idx).long()
        )