def test_mask_logits_by_idx(self):
    logits = torch.tensor(
        [
            [
                [1.0, 2.0, 3.0, 4.0, 5.0],
                [2.0, 3.0, 4.0, 5.0, 6.0],
                [3.0, 4.0, 5.0, 6.0, 7.0],
            ],
            [
                [5.0, 4.0, 3.0, 2.0, 1.0],
                [6.0, 5.0, 4.0, 3.0, 2.0],
                [7.0, 6.0, 5.0, 4.0, 3.0],
            ],
        ]
    )
    tgt_in_idx = torch.tensor(
        [[DECODER_START_SYMBOL, 2, 3], [DECODER_START_SYMBOL, 4, 3]]
    )
    masked_logits = mask_logits_by_idx(logits, tgt_in_idx)
    expected_logits = torch.tensor(
        [
            [
                [float("-inf"), float("-inf"), 3.0, 4.0, 5.0],
                [float("-inf"), float("-inf"), float("-inf"), 5.0, 6.0],
                [float("-inf"), float("-inf"), float("-inf"), float("-inf"), 7.0],
            ],
            [
                [float("-inf"), float("-inf"), 3.0, 2.0, 1.0],
                [float("-inf"), float("-inf"), 4.0, 3.0, float("-inf")],
                [float("-inf"), float("-inf"), 5.0, float("-inf"), float("-inf")],
            ],
        ]
    )
    assert torch.all(torch.eq(masked_logits, expected_logits))
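# The helper exercised above, mask_logits_by_idx, is not shown in this section.
# Below is a minimal sketch that is consistent with the expected values in the
# test, assuming the two reserved indices PADDING_SYMBOL == 0 and
# DECODER_START_SYMBOL == 1; it is an illustration, not necessarily the exact
# library implementation.
import torch

PADDING_SYMBOL = 0  # assumed reserved index for padding
DECODER_START_SYMBOL = 1  # assumed reserved index for the decoder start symbol


def mask_logits_by_idx(logits, tgt_in_idx):
    # logits shape: batch_size, seq_len, candidate_size
    # tgt_in_idx shape: batch_size, seq_len
    batch_size, seq_len = tgt_in_idx.shape
    logits = logits.clone()  # avoid mutating the caller's tensor
    # the first two symbol slots are reserved (padding / decoder start) and
    # must never be emitted
    logits[:, :, :2] = float("-inf")
    # mask_indices[b, t, :] holds the symbols consumed up to and including
    # step t; entries after step t are zeroed by tril and harmlessly re-mask
    # the padding index
    mask_indices = torch.tril(
        tgt_in_idx.repeat(1, seq_len).reshape(batch_size, seq_len, seq_len),
        diagonal=0,
    )
    # a candidate already placed in the slate can no longer be selected
    return logits.scatter(2, mask_indices, float("-inf"))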
def decode(self, memory, state, tgt_in_idx, tgt_in_seq):
    # memory is the encoder output, i.e., the attended representation of each input symbol
    # memory shape: batch_size, src_seq_len, dim_model
    # tgt_in_idx shape: batch_size, tgt_seq_len
    # tgt_in_seq shape: batch_size, tgt_seq_len, dim_candidate
    batch_size, src_seq_len, _ = memory.shape
    _, tgt_seq_len = tgt_in_idx.shape
    candidate_size = src_seq_len + 2
    if self.output_arch == Seq2SlateOutputArch.FRECHET_SORT:
        # encoder_scores shape: batch_size, src_seq_len
        encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
        logits = torch.zeros(batch_size, tgt_seq_len, candidate_size).to(
            encoder_scores.device
        )
        # the first two symbols are reserved (padding and decoder start)
        logits[:, :, :2] = float("-inf")
        logits[:, :, 2:] = encoder_scores.repeat(1, tgt_seq_len).reshape(
            batch_size, tgt_seq_len, src_seq_len
        )
        logits = mask_logits_by_idx(logits, tgt_in_idx)
        probs = torch.softmax(logits, dim=2)
    elif self.output_arch == Seq2SlateOutputArch.AUTOREGRESSIVE:
        # candidate_embed shape: batch_size, tgt_seq_len, dim_model/2
        candidate_embed = self.candidate_embedder(tgt_in_seq)
        # state_embed shape: batch_size, dim_model/2
        state_embed = self.state_embedder(state)
        # state_embed shape after repeat: batch_size, tgt_seq_len, dim_model/2
        state_embed = state_embed.repeat(1, tgt_seq_len).reshape(
            batch_size, tgt_seq_len, -1
        )
        # tgt_embed shape: batch_size, tgt_seq_len, dim_model
        tgt_embed = self.positional_encoding_decoder(
            torch.cat((state_embed, candidate_embed), dim=2)
        )
        # tgt_tgt_mask shape: batch_size * num_heads, tgt_seq_len, tgt_seq_len
        # tgt_src_mask shape: batch_size * num_heads, tgt_seq_len, src_seq_len
        tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
            memory, tgt_in_idx, self.num_heads
        )
        # the decoder output is a probability distribution over symbols
        # shape: batch_size, tgt_seq_len, candidate_size
        probs = self.decoder(tgt_embed, memory, tgt_src_mask, tgt_tgt_mask)
    else:
        raise NotImplementedError()
    return probs
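# A hypothetical greedy inference loop around decode(), included only to show
# how tgt_in_idx and tgt_in_seq grow one step at a time. It assumes the model
# exposes an encode(state, src_seq, src_src_mask) entry point returning
# `memory`, that candidate features live in `src_seq`, and that item indices
# are offset by 2 because of the two reserved symbols. Names here are
# illustrative, not the library's API.
import torch


def greedy_rank(model, state, src_seq, src_src_mask, slate_len, start_symbol=1):
    batch_size, src_seq_len, candidate_dim = src_seq.shape
    device = src_seq.device
    memory = model.encode(state, src_seq, src_src_mask)  # assumed encoder entry point
    tgt_in_idx = torch.full(
        (batch_size, 1), start_symbol, dtype=torch.long, device=device
    )
    for step in range(slate_len):
        # features of previously chosen items; zeros for the start-symbol slot
        tgt_in_seq = torch.zeros(batch_size, step + 1, candidate_dim, device=device)
        if step > 0:
            chosen = tgt_in_idx[:, 1:] - 2  # shift out the two reserved symbols
            tgt_in_seq[:, 1:, :] = torch.gather(
                src_seq, 1, chosen.unsqueeze(2).expand(-1, -1, candidate_dim)
            )
        # probs shape: batch_size, step + 1, candidate_size
        probs = model.decode(memory, state, tgt_in_idx, tgt_in_seq)
        # pick the highest-probability remaining candidate at the latest step
        next_idx = probs[:, -1, :].argmax(dim=1, keepdim=True)
        tgt_in_idx = torch.cat((tgt_in_idx, next_idx), dim=1)
    # drop the start symbol; remaining entries are ranked indices in symbol
    # space (subtract 2 to recover positions in src_seq)
    return tgt_in_idx[:, 1:]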
def _log_probs(self, logits, tgt_in_idx, mode):
    """
    Return the log probability distribution at each decoding step

    :param logits: logits of the decoder output.
        Shape: batch_size, seq_len, candidate_size
    :param tgt_in_idx: the indices of candidates in the decoder input sequences.
        The first symbol is always DECODER_START_SYMBOL.
        Shape: batch_size, seq_len
    :param mode: which log-probability mode is requested
    """
    assert mode in (
        Seq2SlateMode.PER_SEQ_LOG_PROB_MODE,
        Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE,
    )
    logits = mask_logits_by_idx(logits, tgt_in_idx)
    # log_probs shape: batch_size, seq_len, candidate_size
    log_probs = F.log_softmax(logits / self.temperature, dim=2)
    return log_probs
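# In PER_SEQ_LOG_PROB_MODE the per-symbol distribution above still needs to be
# reduced to a single log-probability per slate. A sketch of that reduction,
# assuming tgt_out_idx (the decoder input shifted left by one, without
# DECODER_START_SYMBOL) is available from the batch:
import torch


def per_seq_log_prob(per_symbol_log_probs, tgt_out_idx):
    # per_symbol_log_probs shape: batch_size, seq_len, candidate_size
    # tgt_out_idx shape: batch_size, seq_len
    # gather the log-prob of the symbol actually placed at each step, then sum:
    # log P(slate) = sum_t log P(item_t | items placed before step t)
    chosen_log_probs = torch.gather(
        per_symbol_log_probs, 2, tgt_out_idx.unsqueeze(2)
    ).squeeze(2)
    # returned shape: batch_size, 1
    return chosen_log_probs.sum(dim=1, keepdim=True)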