예제 #1
0
파일: seq2slate.py 프로젝트: saonam/ReAgent
 def forward(
     self,
     input: rlt.PreprocessedRankingInput,
     mode: str,
     tgt_seq_len: Optional[int] = None,
     greedy: Optional[bool] = None,
 ):
     """Run ``self.data_parallel`` and wrap its result in a ``rlt.RankingOutput``.

     ``input``, ``mode``, ``tgt_seq_len`` and ``greedy`` are forwarded
     unchanged to ``self.data_parallel``. How the result is packaged depends
     on ``mode``:

     * ``RANK_MODE``: the result is indexed positionally; items 0, 1 and 2
       become ``ranked_per_symbol_probs``, ``ranked_per_seq_probs`` and
       ``ranked_tgt_out_idx``.
     * ``PER_SYMBOL_LOG_PROB_DIST_MODE`` / ``PER_SEQ_LOG_PROB_MODE``: the
       whole result is returned as ``log_probs``.
     * ``ENCODER_SCORE_MODE``: the whole result is returned as
       ``encoder_scores``.

     Raises:
         NotImplementedError: if ``mode`` is none of the modes above. Note
             that ``self.data_parallel`` has already been invoked by then.
     """
     outputs = self.data_parallel(
         input, mode=mode, tgt_seq_len=tgt_seq_len, greedy=greedy
     )

     if mode == Seq2SlateMode.RANK_MODE:
         # Positional layout of the rank-mode result: 0, 1, 2.
         per_symbol_probs = outputs[0]
         per_seq_probs = outputs[1]
         tgt_out_idx = outputs[2]
         return rlt.RankingOutput(
             ranked_per_symbol_probs=per_symbol_probs,
             ranked_per_seq_probs=per_seq_probs,
             ranked_tgt_out_idx=tgt_out_idx,
         )

     if mode in (
         Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
         Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
     ):
         # Both log-prob modes share the same output field.
         return rlt.RankingOutput(log_probs=outputs)

     if mode == Seq2SlateMode.ENCODER_SCORE_MODE:
         return rlt.RankingOutput(encoder_scores=outputs)

     raise NotImplementedError()
예제 #2
0
 def forward(
     self,
     input: rlt.PreprocessedRankingInput,
     mode: Seq2SlateMode,
     tgt_seq_len: Optional[int] = None,
     greedy: Optional[bool] = None,
 ):
     """Call ``self.seq2slate`` for ``mode`` and wrap the result.

     Every branch passes ``mode.value``, ``input.state.float_features`` and
     ``input.src_seq.float_features``; the remaining keyword arguments
     depend on the mode:

     * ``RANK_MODE``: also passes ``tgt_seq_len`` and ``greedy``; the
       ranked fields of the result are copied into the output.
     * ``PER_SYMBOL_LOG_PROB_DIST_MODE`` / ``PER_SEQ_LOG_PROB_MODE``:
       asserts that ``input.tgt_in_seq``, ``input.tgt_in_idx`` and
       ``input.tgt_out_idx`` are set and passes them along; the output's
       ``log_probs`` is ``per_symbol_log_probs`` when that is not ``None``,
       otherwise ``per_seq_log_probs``.
     * ``ENCODER_SCORE_MODE``: asserts that ``input.tgt_out_idx`` is set and
       passes it along; the output carries ``encoder_scores``.

     Raises:
         AssertionError: if a target field required by the mode is ``None``.
         NotImplementedError: if ``mode`` is none of the modes above.
     """
     if mode == Seq2SlateMode.RANK_MODE:
         # pyre-fixme[16]: `Seq2SlateNet` has no attribute `seq2slate`.
         ranked = self.seq2slate(
             mode=mode.value,
             state=input.state.float_features,
             src_seq=input.src_seq.float_features,
             tgt_seq_len=tgt_seq_len,
             greedy=greedy,
         )
         return rlt.RankingOutput(
             ranked_per_symbol_probs=ranked.ranked_per_symbol_probs,
             ranked_per_seq_probs=ranked.ranked_per_seq_probs,
             ranked_tgt_out_idx=ranked.ranked_tgt_out_idx,
         )

     if mode in (
         Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
         Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
     ):
         # The target-side fields are checked before any of them is read.
         assert input.tgt_in_seq is not None
         assert input.tgt_in_idx is not None
         assert input.tgt_out_idx is not None
         scored = self.seq2slate(
             mode=mode.value,
             state=input.state.float_features,
             src_seq=input.src_seq.float_features,
             # pyre-fixme[16]: `Optional` has no attribute `float_features`.
             tgt_in_seq=input.tgt_in_seq.float_features,
             tgt_in_idx=input.tgt_in_idx,
             tgt_out_idx=input.tgt_out_idx,
         )
         # Prefer the per-symbol result; fall back to the per-sequence one.
         log_probs = scored.per_symbol_log_probs
         if log_probs is None:
             log_probs = scored.per_seq_log_probs
         return rlt.RankingOutput(log_probs=log_probs)

     if mode == Seq2SlateMode.ENCODER_SCORE_MODE:
         assert input.tgt_out_idx is not None
         encoded = self.seq2slate(
             mode=mode.value,
             state=input.state.float_features,
             src_seq=input.src_seq.float_features,
             tgt_out_idx=input.tgt_out_idx,
         )
         return rlt.RankingOutput(encoder_scores=encoded.encoder_scores)

     raise NotImplementedError()
예제 #3
0
 def forward(
     self,
     input: rlt.PreprocessedRankingInput,
     mode: str,
     greedy: Optional[bool] = None,
 ):
     """Return fixed, hard-coded ranking outputs; ``input`` is never read.

     * ``RANK_MODE``: asserts ``greedy`` is truthy and returns a
       ``rlt.RankingOutput`` whose ``ranked_tgt_out_idx`` is the constant
       long tensor ``[[2, 3], [3, 2], [2, 3]]``.
     * ``PER_SEQ_LOG_PROB_MODE``: returns a ``rlt.RankingOutput`` whose
       ``log_probs`` is the element-wise log of ``[0.4, 0.3, 0.7]``.

     Raises:
         AssertionError: if ``mode`` is not one of the two modes above, or
             if ``greedy`` is falsy in rank mode.
     """
     # Evaluation data page creation only exercises these two modes.
     allowed_modes = (
         Seq2SlateMode.RANK_MODE,
         Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
     )
     assert mode in allowed_modes

     if mode == Seq2SlateMode.RANK_MODE:
         assert greedy
         fixed_ranking = torch.tensor([[2, 3], [3, 2], [2, 3]]).long()
         return rlt.RankingOutput(ranked_tgt_out_idx=fixed_ranking)

     fixed_probs = torch.tensor([0.4, 0.3, 0.7])
     return rlt.RankingOutput(log_probs=torch.log(fixed_probs))