def forward(
    self,
    input: rlt.PreprocessedRankingInput,
    mode: Seq2SlateMode,
    tgt_seq_len: Optional[int] = None,
    greedy: Optional[bool] = None,
) -> rlt.RankingOutput:
    """Run ``self.data_parallel`` in ``mode`` and wrap its raw output.

    Args:
        input: Preprocessed ranking input, forwarded unchanged to
            ``self.data_parallel``.
        mode: Seq2SlateMode member. It is forwarded to the wrapped model and
            also selects which ``rlt.RankingOutput`` field(s) are filled.
        tgt_seq_len: Forwarded unchanged to ``self.data_parallel``.
        greedy: Forwarded unchanged to ``self.data_parallel``.

    Returns:
        An ``rlt.RankingOutput`` with:

        - RANK_MODE: ``ranked_per_symbol_probs``, ``ranked_per_seq_probs``
          and ``ranked_tgt_out_idx`` taken from elements 0, 1 and 2 of the
          raw output;
        - PER_SYMBOL_LOG_PROB_DIST_MODE / PER_SEQ_LOG_PROB_MODE:
          ``log_probs`` set to the raw output;
        - ENCODER_SCORE_MODE: ``encoder_scores`` set to the raw output.

    Raises:
        NotImplementedError: If ``mode`` is none of the modes above (raised
            after ``self.data_parallel`` has been called).
    """
    res = self.data_parallel(
        input, mode=mode, tgt_seq_len=tgt_seq_len, greedy=greedy
    )
    if mode == Seq2SlateMode.RANK_MODE:
        # NOTE(review): assumes the wrapped model returns an indexable
        # (per-symbol probs, per-seq probs, tgt_out_idx) triple in this mode
        # — confirm against the module wrapped by data_parallel.
        return rlt.RankingOutput(
            ranked_per_symbol_probs=res[0],
            ranked_per_seq_probs=res[1],
            ranked_tgt_out_idx=res[2],
        )
    if mode in (
        Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
        Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
    ):
        return rlt.RankingOutput(log_probs=res)
    if mode == Seq2SlateMode.ENCODER_SCORE_MODE:
        return rlt.RankingOutput(encoder_scores=res)
    raise NotImplementedError(f"Unsupported Seq2SlateMode: {mode!r}")
def forward(
    self,
    input: rlt.PreprocessedRankingInput,
    mode: Seq2SlateMode,
    tgt_seq_len: Optional[int] = None,
    greedy: Optional[bool] = None,
):
    """Dispatch ``input`` to ``self.seq2slate`` according to ``mode``.

    Every supported mode passes ``mode.value`` plus the float features of
    ``input.state`` and ``input.src_seq``. The modes differ in which extra
    inputs they pass and which ``rlt.RankingOutput`` field they fill.

    Args:
        input: Preprocessed ranking input.
        mode: One of RANK_MODE, PER_SYMBOL_LOG_PROB_DIST_MODE,
            PER_SEQ_LOG_PROB_MODE or ENCODER_SCORE_MODE.
        tgt_seq_len: Forwarded to the network only in RANK_MODE.
        greedy: Forwarded to the network only in RANK_MODE.

    Returns:
        RANK_MODE: ``ranked_per_symbol_probs``, ``ranked_per_seq_probs`` and
        ``ranked_tgt_out_idx`` copied from the network result.
        Log-prob modes: ``log_probs`` set to the result's
        ``per_symbol_log_probs`` when that is not None, otherwise to its
        ``per_seq_log_probs``.
        ENCODER_SCORE_MODE: ``encoder_scores`` copied from the result.

    Raises:
        AssertionError: In the log-prob modes if ``input.tgt_in_seq``,
            ``input.tgt_in_idx`` or ``input.tgt_out_idx`` is None, and in
            ENCODER_SCORE_MODE if ``input.tgt_out_idx`` is None.
        NotImplementedError: For any other ``mode``.
    """

    def common_kwargs():
        # Built lazily, only inside a supported branch, so an unsupported
        # mode never touches ``mode.value`` or the input features.
        return dict(
            mode=mode.value,
            state=input.state.float_features,
            src_seq=input.src_seq.float_features,
        )

    if mode == Seq2SlateMode.RANK_MODE:
        # pyre-fixme[16]: `Seq2SlateNet` has no attribute `seq2slate`.
        ranked = self.seq2slate(
            **common_kwargs(), tgt_seq_len=tgt_seq_len, greedy=greedy
        )
        return rlt.RankingOutput(
            ranked_per_symbol_probs=ranked.ranked_per_symbol_probs,
            ranked_per_seq_probs=ranked.ranked_per_seq_probs,
            ranked_tgt_out_idx=ranked.ranked_tgt_out_idx,
        )

    if mode in (
        Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
        Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
    ):
        # Teacher-forced scoring needs the target sequence and its indices.
        assert input.tgt_in_seq is not None
        assert input.tgt_in_idx is not None
        assert input.tgt_out_idx is not None
        scored = self.seq2slate(
            **common_kwargs(),
            # pyre-fixme[16]: `Optional` has no attribute `float_features`.
            tgt_in_seq=input.tgt_in_seq.float_features,
            tgt_in_idx=input.tgt_in_idx,
            tgt_out_idx=input.tgt_out_idx,
        )
        # Prefer the per-symbol result; fall back to the per-sequence one.
        log_probs = (
            scored.per_symbol_log_probs
            if scored.per_symbol_log_probs is not None
            else scored.per_seq_log_probs
        )
        return rlt.RankingOutput(log_probs=log_probs)

    if mode == Seq2SlateMode.ENCODER_SCORE_MODE:
        assert input.tgt_out_idx is not None
        encoded = self.seq2slate(**common_kwargs(), tgt_out_idx=input.tgt_out_idx)
        return rlt.RankingOutput(encoder_scores=encoded.encoder_scores)

    raise NotImplementedError()
def forward(
    self,
    input: rlt.PreprocessedRankingInput,
    mode: Seq2SlateMode,
    greedy: Optional[bool] = None,
) -> rlt.RankingOutput:
    """Return hard-coded ranking outputs; ``input`` is ignored.

    This appears to be a stand-in for a real seq2slate model. It supports
    only the two call patterns used when creating evaluation data pages.

    Args:
        input: Unused.
        mode: Must be RANK_MODE or PER_SEQ_LOG_PROB_MODE.
        greedy: Must be truthy when ``mode`` is RANK_MODE.

    Returns:
        RANK_MODE: ``ranked_tgt_out_idx`` set to the fixed 3x2 int64 tensor
        ``[[2, 3], [3, 2], [2, 3]]``.
        PER_SEQ_LOG_PROB_MODE: ``log_probs`` set to
        ``log([0.4, 0.3, 0.7])``.

    Raises:
        AssertionError: For any other ``mode``, or for a non-greedy
            RANK_MODE call.
    """
    # The creation of evaluation data pages only uses these specific arguments
    assert mode in (Seq2SlateMode.RANK_MODE, Seq2SlateMode.PER_SEQ_LOG_PROB_MODE)
    if mode == Seq2SlateMode.RANK_MODE:
        # Only greedy ranking is stubbed out.
        assert greedy
        # Three rows of two ranked indices (presumably one row per slate in
        # the batch — TODO confirm against the data-page test fixtures).
        return rlt.RankingOutput(
            ranked_tgt_out_idx=torch.tensor([[2, 3], [3, 2], [2, 3]]).long()
        )
    # PER_SEQ_LOG_PROB_MODE: one fixed log-probability per row above.
    return rlt.RankingOutput(log_probs=torch.log(torch.tensor([0.4, 0.3, 0.7])))