def _convert_seq2slate_to_reward_model_format(
    self, input: rlt.PreprocessedRankingInput
):
    """
    Re-pack a seq2slate ranking input into the layout the reward model uses.

    A seq2slate decoder is only fed the sequence up to (not including) the
    last item, whereas the reward model's transformer decoder has to see the
    whole sequence. The decoder input is therefore rebuilt here as
    ``self.decoder_start_vec`` followed by every item of ``tgt_out_seq``.

    Args:
        input: ranking input whose ``tgt_out_seq`` is populated; its length
            must equal ``self.max_tgt_seq_len``.

    Returns:
        A new ``rlt.PreprocessedRankingInput`` carrying the original state,
        source sequence and source mask, plus the rebuilt decoder input and
        a decoder mask sized for ``tgt_seq_len + 1`` positions.
    """
    device = next(self.parameters()).device

    # pyre-fixme[16]: Optional type has no attribute `float_features`.
    tgt_out_features = input.tgt_out_seq.float_features
    # The third dimension (candidate_dim) is not needed here, but unpacking
    # three values still enforces that the tensor is 3-dimensional.
    batch_size, tgt_seq_len, _ = tgt_out_features.shape
    assert self.max_tgt_seq_len == tgt_seq_len

    # One extra position because the start vector is prepended below.
    full_seq_mask = subsequent_mask(tgt_seq_len + 1, device)

    # shape: batch_size, tgt_seq_len + 1, candidate_dim
    start_vecs = self.decoder_start_vec.repeat(batch_size, 1, 1)
    full_tgt_in_seq = torch.cat((start_vecs, tgt_out_features), dim=1)

    return rlt.PreprocessedRankingInput.from_tensors(
        state=input.state.float_features,
        src_seq=input.src_seq.float_features,
        src_src_mask=input.src_src_mask,
        tgt_in_seq=full_tgt_in_seq,
        tgt_tgt_mask=full_seq_mask,
    )
def process_tgt_seq(action):
    """
    Build decoder index/feature sequences and the decoder mask for ``action``.

    Relies on ``batch_size``, ``candidate_dim``, ``device``, ``candidates``,
    ``gather``, ``subsequent_mask`` and ``DECODER_START_SYMBOL`` from the
    enclosing scope.

    Args:
        action: 2-D tensor of candidate indices (second dimension is the
            output length), or ``None``.

    Returns:
        The 5-tuple ``(tgt_in_idx, tgt_out_idx, tgt_in_seq, tgt_out_seq,
        tgt_tgt_mask)``; every element is ``None`` when ``action`` is ``None``.
    """
    if action is None:
        return None, None, None, None, None

    _, output_size = action.shape

    # Account for decoder starting symbol and padding symbol: two all-zero
    # feature rows are placed in front of the real candidates.
    symbol_rows = torch.zeros(batch_size, 2, candidate_dim, device=device)
    candidates_augment = torch.cat((symbol_rows, candidates), dim=1)

    # Shift indices by the two rows added above.
    tgt_out_idx = action + 2

    # Decoder input indices: start symbol first, then the output indices
    # shifted right by one position.
    tgt_in_idx = torch.full(
        (batch_size, output_size), DECODER_START_SYMBOL, device=device
    )
    tgt_in_idx[:, 1:] = tgt_out_idx[:, :-1]

    # NOTE(review): `gather` is defined elsewhere; presumably it selects rows
    # of candidates_augment by index -- confirm against its definition.
    tgt_out_seq = gather(candidates_augment, tgt_out_idx)

    # Decoder input features: zeros in the first slot, then the output
    # features shifted right by one position.
    tgt_in_seq = torch.zeros(batch_size, output_size, candidate_dim, device=device)
    tgt_in_seq[:, 1:] = tgt_out_seq[:, :-1]

    tgt_tgt_mask = subsequent_mask(output_size, device)
    return tgt_in_idx, tgt_out_idx, tgt_in_seq, tgt_out_seq, tgt_tgt_mask
def test_subsequent_mask(self):
    """Check that a size-3 subsequent mask is lower-triangular (diagonal included)."""
    cpu = torch.device("cpu")
    expected = torch.tensor(
        [
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 1],
        ]
    )
    actual = subsequent_mask(3, cpu)
    # Element-wise comparison, then require every entry to match.
    assert (actual == expected).all()
def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
    """
    Decode sequences based on given inputs.

    Args:
        state: state features, forwarded to ``self.encode`` / ``self.decode``.
        src_seq: candidate features, shape (batch_size, src_seq_len,
            candidate_dim).
        src_src_mask: mask passed to the encoder; its first ``step + 1`` rows
            are also used as the decoder-to-encoder mask at each decoding
            step.
        tgt_seq_len: number of positions to rank.
        greedy: forwarded to ``self.generator``. Must not be ``None`` unless
            ``self.output_arch`` is ``Seq2SlateOutputArch.ENCODER_SCORE``,
            in which case it is not used.

    Returns:
        A tuple ``(ranked_per_symbol_probs, ranked_per_seq_probs,
        tgt_out_idx)`` with shapes (batch_size, tgt_seq_len, candidate_size),
        (batch_size, 1) and (batch_size, tgt_seq_len) respectively, where
        ``candidate_size = src_seq_len + 2``. All three are allocated on
        ``src_seq.device``.
    """
    device = src_seq.device
    batch_size, src_seq_len, candidate_dim = src_seq.shape
    # +2 because index 0 and 1 are reserved for the padding and decoder
    # start symbols.
    candidate_size = src_seq_len + 2

    # candidate_features is used as look-up table for candidate features.
    # The second dim is src_seq_len + 2 because we also want to include
    # features of start symbol and padding symbol (left as zeros).
    candidate_features = torch.zeros(
        batch_size, candidate_size, candidate_dim, device=device
    )
    # TODO: T62502977 create learnable feature vectors for start symbol
    # and padding symbol
    candidate_features[:, 2:, :] = src_seq

    # memory shape: batch_size, src_seq_len, dim_model
    memory = self.encode(state, src_seq, src_src_mask)

    ranked_per_symbol_probs = torch.zeros(
        batch_size, tgt_seq_len, candidate_size, device=device
    )

    if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
        # encoder_scores shape: batch_size, src_seq_len
        encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
        tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
            :, :tgt_seq_len
        ]
        # +2 to account for start symbol and padding symbol
        tgt_out_idx += 2
        # every position has propensity of 1 because we are just using argsort
        ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
            2, tgt_out_idx.unsqueeze(2), 1.0
        )
        # Allocate on `device` so that all returned tensors live on the same
        # device (previously this tensor was always created on the CPU).
        ranked_per_seq_probs = torch.ones(batch_size, 1, device=device)
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

    assert greedy is not None

    # Decoder input starts with just the start symbol; shape: batch_size, 1
    tgt_in_idx = torch.full(
        (batch_size, 1),
        self._DECODER_START_SYMBOL,
        dtype=torch.long,
        device=device,
    )
    # Loop-invariant batch indices used to look up candidate features.
    batch_indices = torch.arange(batch_size, device=device)

    for step in range(tgt_seq_len):
        cur_len = step + 1
        # Look up the features of everything decoded so far.
        # shape: batch_size, cur_len, candidate_dim
        tgt_in_seq = candidate_features[
            batch_indices.repeat_interleave(cur_len),
            tgt_in_idx.flatten(),
        ].view(batch_size, cur_len, -1)
        tgt_src_mask = src_src_mask[:, :cur_len, :]
        # shape batch_size, cur_len, candidate_size
        logits = self.decode(
            memory=memory,
            state=state,
            tgt_src_mask=tgt_src_mask,
            tgt_in_seq=tgt_in_seq,
            tgt_tgt_mask=subsequent_mask(cur_len, device),
            tgt_seq_len=cur_len,
        )
        # next candidate shape: batch_size, 1
        # prob shape: batch_size, candidate_size
        next_candidate, prob = self.generator(
            mode=self._DECODE_ONE_STEP_MODE,
            logits=logits,
            tgt_in_idx=tgt_in_idx,
            greedy=greedy,
        )
        ranked_per_symbol_probs[:, step, :] = prob
        tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

    # remove the decoder start symbol
    # tgt_out_idx shape: batch_size, tgt_seq_len
    tgt_out_idx = tgt_in_idx[:, 1:]

    ranked_per_seq_probs = per_symbol_to_per_seq_probs(
        ranked_per_symbol_probs, tgt_out_idx
    )

    # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
    # ranked_per_seq_probs shape: batch_size, 1
    # tgt_out_idx shape: batch_size, tgt_seq_len
    return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx