def test_per_symbol_to_per_seq_probs(self):
    batch_size = 1
    seq_len = 3
    candidate_size = seq_len + 2
    # +2 to skip the decoder start symbol and padding symbol
    tgt_out_idx = torch.tensor([[0, 2, 1]]) + 2
    per_symbol_log_probs = torch.randn(batch_size, seq_len, candidate_size)
    # the start and padding symbols (indices 0 and 1) can never be emitted
    per_symbol_log_probs[0, :, :2] = float("-inf")
    # items already placed at earlier steps cannot be emitted again:
    # index 2 is chosen at step 0; indices 2 and 4 are taken by step 2
    per_symbol_log_probs[0, 1, 2] = float("-inf")
    per_symbol_log_probs[0, 2, 2] = float("-inf")
    per_symbol_log_probs[0, 2, 4] = float("-inf")
    per_symbol_log_probs = F.log_softmax(per_symbol_log_probs, dim=2)
    per_symbol_probs = torch.exp(per_symbol_log_probs)
    expect_per_seq_probs = (
        per_symbol_probs[0, 0, 2]
        * per_symbol_probs[0, 1, 4]
        * per_symbol_probs[0, 2, 3]
    )
    computed_per_seq_probs = per_symbol_to_per_seq_probs(
        per_symbol_probs, tgt_out_idx
    )
    np.testing.assert_allclose(
        expect_per_seq_probs, computed_per_seq_probs, atol=0.001, rtol=0.0
    )
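# For reference, a minimal sketch of what `per_symbol_to_per_seq_probs` is
# assumed to do, inferred from the test above: gather each step's chosen
# symbol probability and multiply across the sequence. The name with the
# `_sketch` suffix is hypothetical; the repo's version may also clamp the
# result away from zero.
def per_symbol_to_per_seq_probs_sketch(per_symbol_probs, tgt_out_idx):
    # per_symbol_probs shape: batch_size, seq_len, candidate_size
    # tgt_out_idx shape: batch_size, seq_len
    # output shape: batch_size, 1
    chosen_probs = torch.gather(
        per_symbol_probs, 2, tgt_out_idx.unsqueeze(2)
    ).squeeze(2)
    return torch.prod(chosen_probs, dim=1, keepdim=True)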
def _log_probs(
    self,
    state: torch.Tensor,
    src_seq: torch.Tensor,
    tgt_in_seq: torch.Tensor,
    tgt_in_idx: torch.Tensor,
    tgt_out_idx: torch.Tensor,
    mode: str,
) -> Seq2SlateTransformerOutput:
    """
    Compute log of generative probabilities of given tgt sequences
    (used for REINFORCE training)
    """
    # encoder_output shape: batch_size, src_seq_len, dim_model
    encoder_output = self.encode(state, src_seq)

    tgt_seq_len = tgt_in_seq.shape[1]
    src_seq_len = src_seq.shape[1]
    assert tgt_seq_len <= src_seq_len

    # decoder_probs shape: batch_size, tgt_seq_len, candidate_size
    decoder_probs = self.decode(
        memory=encoder_output,
        state=state,
        tgt_in_idx=tgt_in_idx,
        tgt_in_seq=tgt_in_seq,
    )
    # log_probs shape:
    # if mode == PER_SEQ_LOG_PROB_MODE: batch_size, 1
    # if mode == PER_SYMBOL_LOG_PROB_DIST_MODE: batch_size, tgt_seq_len, candidate_size
    if mode == self._PER_SYMBOL_LOG_PROB_DIST_MODE:
        per_symbol_log_probs = torch.log(torch.clamp(decoder_probs, min=1e-40))
        return Seq2SlateTransformerOutput(
            ranked_per_symbol_probs=None,
            ranked_per_seq_probs=None,
            ranked_tgt_out_idx=None,
            per_symbol_log_probs=per_symbol_log_probs,
            per_seq_log_probs=None,
            encoder_scores=None,
        )

    per_seq_log_probs = torch.log(
        per_symbol_to_per_seq_probs(decoder_probs, tgt_out_idx)
    )
    return Seq2SlateTransformerOutput(
        ranked_per_symbol_probs=None,
        ranked_per_seq_probs=None,
        ranked_tgt_out_idx=None,
        per_symbol_log_probs=None,
        per_seq_log_probs=per_seq_log_probs,
        encoder_scores=None,
    )
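# Why the two modes differ, as a tiny worked example with assumed values:
# suppose one decoding step with decoder_probs = [[[0.5, 0.3, 0.2]]] and
# tgt_out_idx = [[1]]. PER_SYMBOL_LOG_PROB_DIST_MODE returns the log of the
# whole distribution, log([0.5, 0.3, 0.2]), shape (1, 1, 3), while
# PER_SEQ_LOG_PROB_MODE returns only the chosen symbol's log-probability,
# log(0.3), shape (1, 1); with more steps it is the sum of the chosen
# symbols' log-probabilities.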
def _rank(
    self, state: torch.Tensor, src_seq: torch.Tensor, tgt_seq_len: int, greedy: bool
) -> Seq2SlateTransformerOutput:
    """Decode sequences based on given inputs"""
    device = src_seq.device
    batch_size, src_seq_len, candidate_dim = src_seq.shape
    candidate_size = src_seq_len + 2

    # candidate_features is used as look-up table for candidate features.
    # the second dim is src_seq_len + 2 because we also want to include
    # features of start symbol and padding symbol
    candidate_features = torch.zeros(
        batch_size, src_seq_len + 2, candidate_dim, device=device
    )
    # TODO: T62502977 create learnable feature vectors for start symbol
    # and padding symbol
    candidate_features[:, 2:, :] = src_seq

    # memory shape: batch_size, src_seq_len, dim_model
    memory = self.encode(state, src_seq)

    if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
        tgt_out_idx, ranked_per_symbol_probs = self._encoder_rank(
            memory, tgt_seq_len
        )
    elif self.output_arch == Seq2SlateOutputArch.FRECHET_SORT and greedy:
        # greedy decoding for non-autoregressive decoder
        tgt_out_idx, ranked_per_symbol_probs = self._greedy_rank(
            state, memory, candidate_features, tgt_seq_len
        )
    else:
        assert greedy is not None
        # autoregressive decoding
        tgt_out_idx, ranked_per_symbol_probs = self._autoregressive_rank(
            state, memory, candidate_features, tgt_seq_len, greedy
        )
    # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
    # ranked_per_seq_probs shape: batch_size, 1
    ranked_per_seq_probs = per_symbol_to_per_seq_probs(
        ranked_per_symbol_probs, tgt_out_idx
    )

    # tgt_out_idx shape: batch_size, tgt_seq_len
    return Seq2SlateTransformerOutput(
        ranked_per_symbol_probs=ranked_per_symbol_probs,
        ranked_per_seq_probs=ranked_per_seq_probs,
        ranked_tgt_out_idx=tgt_out_idx,
        per_symbol_log_probs=self._OUTPUT_PLACEHOLDER,
        per_seq_log_probs=self._OUTPUT_PLACEHOLDER,
        encoder_scores=self._OUTPUT_PLACEHOLDER,
    )
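# A sketch of the `_encoder_rank` helper dispatched above, assuming it mirrors
# the inline ENCODER_SCORE branch of the non-refactored `_rank` shown further
# below: score every candidate once with the encoder, argsort, and assign a
# propensity of 1 to each chosen position. The `_sketch` name is hypothetical.
def _encoder_rank_sketch(self, memory, tgt_seq_len):
    batch_size, src_seq_len, _ = memory.shape
    candidate_size = src_seq_len + 2
    device = memory.device
    # encoder_scores shape: batch_size, src_seq_len
    encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
    # +2 to account for start symbol and padding symbol
    tgt_out_idx = (
        torch.argsort(encoder_scores, dim=1, descending=True)[:, :tgt_seq_len] + 2
    )
    # deterministic argsort => per-symbol propensity of 1 at each chosen index
    ranked_per_symbol_probs = torch.zeros(
        batch_size, tgt_seq_len, candidate_size, device=device
    ).scatter(2, tgt_out_idx.unsqueeze(2), 1.0)
    return tgt_out_idx, ranked_per_symbol_probs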
def _log_probs(
    self,
    state,
    src_seq,
    tgt_in_seq,
    src_src_mask,
    tgt_tgt_mask,
    tgt_in_idx,
    tgt_out_idx,
    mode,
):
    """
    Compute log of generative probabilities of given tgt sequences
    (used for REINFORCE training)
    """
    # encoder_output shape: batch_size, src_seq_len, dim_model
    encoder_output = self.encode(state, src_seq, src_src_mask)

    tgt_seq_len = tgt_in_seq.shape[1]
    src_seq_len = src_seq.shape[1]
    assert tgt_seq_len <= src_seq_len

    # tgt_tgt_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len
    # tgt_src_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
    tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
        encoder_output, tgt_in_idx, self.num_heads
    )

    # decoder_probs shape: batch_size, tgt_seq_len, candidate_size
    decoder_probs = self.decode(
        memory=encoder_output,
        state=state,
        tgt_src_mask=tgt_src_mask,
        tgt_in_idx=tgt_in_idx,
        tgt_in_seq=tgt_in_seq,
        tgt_tgt_mask=tgt_tgt_mask,
    )
    # log_probs shape:
    # if mode == PER_SEQ_LOG_PROB_MODE: batch_size, 1
    # if mode == PER_SYMBOL_LOG_PROB_DIST_MODE: batch_size, tgt_seq_len, candidate_size
    if mode == Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE:
        per_symbol_log_probs = torch.log(torch.clamp(decoder_probs, min=EPSILON))
        return per_symbol_log_probs

    per_seq_log_probs = torch.log(
        per_symbol_to_per_seq_probs(decoder_probs, tgt_out_idx)
    )
    return per_seq_log_probs
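# A hedged sketch of what `pytorch_decoder_mask` plausibly computes, based
# only on the shape comments above and on Seq2Slate's constraint that an item
# placed in the slate cannot be placed again. True marks a position to ignore
# (nn.MultiheadAttention's convention); the repo's actual implementation may
# differ, and the `_sketch` name is hypothetical.
def pytorch_decoder_mask_sketch(memory, tgt_in_idx, num_heads):
    batch_size, src_seq_len, _ = memory.shape
    tgt_seq_len = tgt_in_idx.shape[1]
    device = memory.device
    # causal self-attention mask: each step may not attend to later steps
    # tgt_tgt_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len
    tgt_tgt_mask = torch.triu(
        torch.ones(tgt_seq_len, tgt_seq_len, dtype=torch.bool, device=device),
        diagonal=1,
    ).repeat(batch_size * num_heads, 1, 1)
    # encoder-attention mask: at step t, ignore source items already placed
    # tgt_src_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
    tgt_src_mask = torch.zeros(
        batch_size, tgt_seq_len, src_seq_len, dtype=torch.bool, device=device
    )
    for t in range(1, tgt_seq_len):
        # tgt_in_idx[:, t] is the item placed at step t - 1; -2 undoes the
        # start/padding symbol offset (assumes no padding inside the slate)
        placed = (tgt_in_idx[:, t] - 2).unsqueeze(1)
        tgt_src_mask[:, t] = tgt_src_mask[:, t - 1].scatter(1, placed, True)
    return tgt_tgt_mask, tgt_src_mask.repeat_interleave(num_heads, dim=0)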
def _rank(
    self, state: torch.Tensor, src_seq: torch.Tensor, tgt_seq_len: int, greedy: bool
) -> Seq2SlateTransformerOutput:
    """Decode sequences based on given inputs"""
    device = src_seq.device
    batch_size, src_seq_len, candidate_dim = src_seq.shape
    candidate_size = src_seq_len + 2

    # candidate_features is used as look-up table for candidate features.
    # the second dim is src_seq_len + 2 because we also want to include
    # features of start symbol and padding symbol
    candidate_features = torch.zeros(
        batch_size, src_seq_len + 2, candidate_dim, device=device
    )
    # TODO: T62502977 create learnable feature vectors for start symbol
    # and padding symbol
    candidate_features[:, 2:, :] = src_seq

    # memory shape: batch_size, src_seq_len, dim_model
    memory = self.encode(state, src_seq)

    ranked_per_symbol_probs = torch.zeros(
        batch_size, tgt_seq_len, candidate_size, device=device
    )
    # keep on the same device as the other outputs
    ranked_per_seq_probs = torch.zeros(batch_size, 1, device=device)

    if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
        # encoder_scores shape: batch_size, src_seq_len
        encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
        tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
            :, :tgt_seq_len
        ]
        # +2 to account for start symbol and padding symbol
        tgt_out_idx += 2
        # every position has propensity of 1 because we are just using argsort
        ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
            2, tgt_out_idx.unsqueeze(2), 1.0
        )
        ranked_per_seq_probs[:, :] = 1.0
        return Seq2SlateTransformerOutput(
            ranked_per_symbol_probs=ranked_per_symbol_probs,
            ranked_per_seq_probs=ranked_per_seq_probs,
            ranked_tgt_out_idx=tgt_out_idx,
            per_symbol_log_probs=self._OUTPUT_PLACEHOLDER,
            per_seq_log_probs=self._OUTPUT_PLACEHOLDER,
            encoder_scores=self._OUTPUT_PLACEHOLDER,
        )

    tgt_in_idx = (
        torch.ones(batch_size, 1, device=device)
        .fill_(self._DECODER_START_SYMBOL)
        .long()
    )
    assert greedy is not None
    for l in range(tgt_seq_len):
        tgt_in_seq = gather(candidate_features, tgt_in_idx)
        # probs shape: batch_size, l + 1, candidate_size
        probs = self.decode(
            memory=memory,
            state=state,
            tgt_in_idx=tgt_in_idx,
            tgt_in_seq=tgt_in_seq,
        )
        # next_candidate shape: batch_size, 1
        # next_candidate_sample_prob shape: batch_size, candidate_size
        next_candidate, next_candidate_sample_prob = self.generator(probs, greedy)
        ranked_per_symbol_probs[:, l, :] = next_candidate_sample_prob
        tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

    # remove the decoder start symbol
    # tgt_out_idx shape: batch_size, tgt_seq_len
    tgt_out_idx = tgt_in_idx[:, 1:]

    ranked_per_seq_probs = per_symbol_to_per_seq_probs(
        ranked_per_symbol_probs, tgt_out_idx
    )

    # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
    # ranked_per_seq_probs shape: batch_size, 1
    # tgt_out_idx shape: batch_size, tgt_seq_len
    return Seq2SlateTransformerOutput(
        ranked_per_symbol_probs=ranked_per_symbol_probs,
        ranked_per_seq_probs=ranked_per_seq_probs,
        ranked_tgt_out_idx=tgt_out_idx,
        per_symbol_log_probs=self._OUTPUT_PLACEHOLDER,
        per_seq_log_probs=self._OUTPUT_PLACEHOLDER,
        encoder_scores=self._OUTPUT_PLACEHOLDER,
    )
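# A sketch of the `gather` helper used above, assuming it performs the same
# per-row feature lookup that the non-refactored `_rank` below spells out with
# repeat_interleave and reshaping. The `_sketch` name is hypothetical.
def gather_sketch(candidate_features, tgt_in_idx):
    # candidate_features shape: batch_size, candidate_size, candidate_dim
    # tgt_in_idx shape: batch_size, tgt_seq_len
    # output shape: batch_size, tgt_seq_len, candidate_dim
    candidate_dim = candidate_features.shape[2]
    return torch.gather(
        candidate_features,
        1,
        tgt_in_idx.unsqueeze(2).expand(-1, -1, candidate_dim),
    )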
def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
    """Decode sequences based on given inputs"""
    device = src_seq.device
    batch_size, src_seq_len, candidate_dim = src_seq.shape
    candidate_size = src_seq_len + 2

    # candidate_features is used as look-up table for candidate features.
    # the second dim is src_seq_len + 2 because we also want to include
    # features of start symbol and padding symbol
    candidate_features = torch.zeros(
        batch_size, src_seq_len + 2, candidate_dim, device=device
    )
    # TODO: T62502977 create learnable feature vectors for start symbol
    # and padding symbol
    candidate_features[:, 2:, :] = src_seq

    # memory shape: batch_size, src_seq_len, dim_model
    memory = self.encode(state, src_seq, src_src_mask)

    ranked_per_symbol_probs = torch.zeros(
        batch_size, tgt_seq_len, candidate_size, device=device
    )
    # keep on the same device as the other outputs
    ranked_per_seq_probs = torch.zeros(batch_size, 1, device=device)

    if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
        # encoder_scores shape: batch_size, src_seq_len
        encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
        tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
            :, :tgt_seq_len
        ]
        # +2 to account for start symbol and padding symbol
        tgt_out_idx += 2
        # every position has propensity of 1 because we are just using argsort
        ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
            2, tgt_out_idx.unsqueeze(2), 1.0
        )
        ranked_per_seq_probs[:, :] = 1.0
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

    tgt_in_idx = (
        torch.ones(batch_size, 1, device=device)
        .fill_(self._DECODER_START_SYMBOL)
        .type(torch.long)
    )
    assert greedy is not None
    for l in range(tgt_seq_len):
        # look up features of the partially decoded sequence
        tgt_in_seq = (
            candidate_features[
                torch.arange(batch_size, device=device).repeat_interleave(l + 1),
                tgt_in_idx.flatten(),
            ]
            .view(batch_size, l + 1, -1)
            .to(device)
        )
        tgt_src_mask = src_src_mask[:, : l + 1, :]
        # logits shape: batch_size, l + 1, candidate_size
        logits = self.decode(
            memory=memory,
            state=state,
            tgt_src_mask=tgt_src_mask,
            tgt_in_seq=tgt_in_seq,
            tgt_tgt_mask=subsequent_mask(l + 1, device),
            tgt_seq_len=l + 1,
        )
        # next_candidate shape: batch_size, 1
        # prob shape: batch_size, candidate_size
        next_candidate, prob = self.generator(
            mode=self._DECODE_ONE_STEP_MODE,
            logits=logits,
            tgt_in_idx=tgt_in_idx,
            greedy=greedy,
        )
        ranked_per_symbol_probs[:, l, :] = prob
        tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

    # remove the decoder start symbol
    # tgt_out_idx shape: batch_size, tgt_seq_len
    tgt_out_idx = tgt_in_idx[:, 1:]

    ranked_per_seq_probs = per_symbol_to_per_seq_probs(
        ranked_per_symbol_probs, tgt_out_idx
    )

    # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
    # ranked_per_seq_probs shape: batch_size, 1
    # tgt_out_idx shape: batch_size, tgt_seq_len
    return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx
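# A sketch of the `subsequent_mask` helper used above, following the standard
# transformer convention where 1 marks positions a step may attend to (itself
# and earlier steps) and 0 marks future positions; the repo's exact dtype and
# convention may differ, and the `_sketch` name is hypothetical.
def subsequent_mask_sketch(size, device):
    # output shape: 1, size, size (lower-triangular, broadcastable over batch)
    return torch.tril(torch.ones(1, size, size, dtype=torch.int8, device=device))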