def forward(
    self,
    input: rlt.PreprocessedRankingInput,
    mode: str,
    greedy: Optional[bool] = None,
):
    """Mock ranking forward pass returning fixed, hard-coded outputs.

    Evaluation data pages only call this with the two modes asserted below;
    anything else is a test setup error.
    """
    # The creation of evaluation data pages only uses these specific arguments
    assert mode in (Seq2SlateMode.RANK_MODE, Seq2SlateMode.PER_SEQ_LOG_PROB_MODE)
    if mode != Seq2SlateMode.RANK_MODE:
        # Per-sequence log-prob mode: fixed log-probabilities for 3 sequences.
        return rlt.RankingOutput(
            log_probs=torch.log(torch.tensor([0.4, 0.3, 0.7]))
        )
    assert greedy
    # Rank mode: fixed slate orderings for a batch of 3.
    fixed_slates = torch.tensor([[2, 3], [3, 2], [2, 3]]).long()
    return rlt.RankingOutput(ranked_tgt_out_idx=fixed_slates)
def forward(
    self,
    input: rlt.PreprocessedRankingInput,
    mode: str,
    tgt_seq_len: Optional[int] = None,
    greedy: Optional[bool] = None,
):
    """Run the data-parallel seq2slate model and wrap its raw output.

    :param input: preprocessed ranking batch forwarded to the wrapped model
    :param mode: which head to run (ranking vs. log-prob computation)
    :param tgt_seq_len: target slate length, passed through to the model
    :param greedy: whether ranking should be greedy, passed through
    :return: rlt.RankingOutput holding either ranked indices/probs or log-probs
    :raises NotImplementedError: for any unrecognized mode
    """
    res = self.data_parallel(
        input, mode=mode, tgt_seq_len=tgt_seq_len, greedy=greedy
    )
    if mode == RANK_MODE:
        # res is a pair: (ranked_tgt_out_probs, ranked_tgt_out_idx)
        return rlt.RankingOutput(
            ranked_tgt_out_idx=res[1], ranked_tgt_out_probs=res[0]
        )
    elif mode == LOG_PROB_MODE:
        return rlt.RankingOutput(log_probs=res)
    else:
        # Fix: previously an unhandled mode fell through and returned None
        # silently; fail loudly instead, consistent with the sibling
        # data-parallel wrapper in this file.
        raise NotImplementedError()
def forward(
    self,
    input: rlt.PreprocessedRankingInput,
    mode: str,
    tgt_seq_len: Optional[int] = None,
    greedy: Optional[bool] = None,
):
    """Dispatch through the data-parallel model and wrap the per-mode result.

    Rank mode returns ranked indices and their probabilities; the two
    log-prob modes return log-probabilities; any other mode raises.
    """
    model_out = self.data_parallel(
        input, mode=mode, tgt_seq_len=tgt_seq_len, greedy=greedy
    )
    if mode == Seq2SlateMode.RANK_MODE:
        # model_out is (ranked_tgt_out_probs, ranked_tgt_out_idx)
        ranked_probs, ranked_idx = model_out[0], model_out[1]
        return rlt.RankingOutput(
            ranked_tgt_out_idx=ranked_idx, ranked_tgt_out_probs=ranked_probs
        )
    log_prob_modes = (
        Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
        Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
    )
    if mode not in log_prob_modes:
        raise NotImplementedError()
    return rlt.RankingOutput(log_probs=model_out)
def forward(self, input: rlt.PreprocessedRankingInput, mode: str, greedy: bool):
    """Greedily rank each state in the batch via the per-state _forward helper."""
    # The creation of evaluation data pages only uses these specific arguments
    assert greedy and mode == RANK_MODE
    states = input.state.float_features
    # One ranked slate per batch element, produced independently per state.
    slates = [self._forward(states[i]) for i in range(states.shape[0])]
    return rlt.RankingOutput(ranked_tgt_out_idx=torch.tensor(slates).long())