def test_pytorch_decoder_mask(self):
    """pytorch_decoder_mask builds (a) a causal target-side mask and (b) a
    source-side mask hiding candidates already selected at earlier steps;
    both are tiled per attention head along the batch dimension."""
    batch_size, src_seq_len, num_heads = 3, 4, 2
    memory = torch.randn(batch_size, src_seq_len, num_heads)
    tgt_in_idx = torch.tensor(
        [[1, 2, 3], [1, 4, 2], [1, 5, 4]], dtype=torch.long
    )

    tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
        memory, tgt_in_idx, num_heads
    )

    # Causal mask: True strictly above the diagonal, one copy per
    # (batch element, head) pair.
    causal = torch.ones(3, 3, dtype=torch.bool).triu(diagonal=1)
    expected_tgt_tgt_mask = causal.unsqueeze(0).repeat(
        batch_size * num_heads, 1, 1
    )

    # Source mask: step t hides the candidates picked at steps < t
    # (indices are offset by the two reserved start/padding symbols).
    expected_tgt_src_mask = torch.tensor(
        [
            [
                [False, False, False, False],
                [True, False, False, False],
                [True, True, False, False],
            ],
            [
                [False, False, False, False],
                [False, False, True, False],
                [True, False, True, False],
            ],
            [
                [False, False, False, False],
                [False, False, False, True],
                [False, False, True, True],
            ],
        ]
    ).repeat_interleave(num_heads, dim=0)

    assert torch.all(tgt_tgt_mask == expected_tgt_tgt_mask)
    assert torch.all(tgt_src_mask == expected_tgt_src_mask)
def decode(self, memory, state, tgt_in_idx, tgt_in_seq):
    """Turn encoder memory into per-step probabilities over candidates.

    :param memory: encoder output (attention over each input symbol),
        shape (batch_size, src_seq_len, dim_model)
    :param state: user/context state, fed to the state embedder
    :param tgt_in_idx: decoder input indices,
        shape (batch_size, tgt_seq_len)
    :param tgt_in_seq: decoder input features,
        shape (batch_size, tgt_seq_len, dim_candidate)
    :return: probabilities over symbols,
        shape (batch_size, tgt_seq_len, candidate_size)
    :raises NotImplementedError: for unsupported output architectures
    """
    batch_size, src_seq_len, _ = memory.shape
    tgt_seq_len = tgt_in_idx.shape[1]
    # Two extra slots reserved for the start and padding symbols.
    candidate_size = src_seq_len + 2

    if self.output_arch == Seq2SlateOutputArch.FRECHET_SORT:
        # One scalar score per source item: (batch_size, src_seq_len)
        item_scores = self.encoder_scorer(memory).squeeze(dim=2)
        # Broadcast the same scores to every decoding step; the two
        # reserved symbol slots can never be emitted.
        logits = torch.zeros(batch_size, tgt_seq_len, candidate_size).to(
            item_scores.device
        )
        logits[:, :, :2] = float("-inf")
        logits[:, :, 2:] = item_scores.repeat(1, tgt_seq_len).reshape(
            batch_size, tgt_seq_len, src_seq_len
        )
        # Forbid re-selecting items already placed in the prefix.
        logits = mask_logits_by_idx(logits, tgt_in_idx)
        return torch.softmax(logits, dim=2)

    if self.output_arch == Seq2SlateOutputArch.AUTOREGRESSIVE:
        # candidate_embed: (batch_size, tgt_seq_len, dim_model/2)
        candidate_embed = self.candidate_embedder(tgt_in_seq)
        # state_embed: (batch_size, dim_model/2), tiled to
        # (batch_size, tgt_seq_len, dim_model/2)
        state_embed = self.state_embedder(state).repeat(1, tgt_seq_len).reshape(
            batch_size, tgt_seq_len, -1
        )
        # tgt_embed: (batch_size, tgt_seq_len, dim_model)
        tgt_embed = self.positional_encoding_decoder(
            torch.cat((state_embed, candidate_embed), dim=2)
        )
        # tgt_tgt_mask: (batch_size * num_heads, tgt_seq_len, tgt_seq_len)
        # tgt_src_mask: (batch_size * num_heads, tgt_seq_len, src_seq_len)
        tgt_tgt_mask, tgt_src_mask = pytorch_decoder_mask(
            memory, tgt_in_idx, self.num_heads
        )
        # Decoder emits probabilities over symbols:
        # (batch_size, tgt_seq_len, candidate_size)
        return self.decoder(tgt_embed, memory, tgt_src_mask, tgt_tgt_mask)

    raise NotImplementedError()
def _log_probs(
    self,
    state,
    src_seq,
    tgt_in_seq,
    src_src_mask,
    tgt_tgt_mask,
    tgt_in_idx,
    tgt_out_idx,
    mode,
):
    """Compute log of generative probabilities of given tgt sequences
    (used for REINFORCE training).

    :param state: user/context state, passed through to encode/decode
    :param src_seq: source candidate features,
        shape (batch_size, src_seq_len, dim_candidate)
    :param tgt_in_seq: decoder input features,
        shape (batch_size, tgt_seq_len, dim_candidate)
    :param src_src_mask: encoder self-attention mask
    :param tgt_tgt_mask: unused; kept for interface compatibility —
        decode() builds its own masks from memory and tgt_in_idx
    :param tgt_in_idx: decoder input indices, shape (batch_size, tgt_seq_len)
    :param tgt_out_idx: decoder output indices, shape (batch_size, tgt_seq_len)
    :param mode: Seq2SlateMode selecting per-symbol vs. per-sequence output
    :return: if mode == PER_SYMBOL_LOG_PROB_DIST_MODE:
        (batch_size, tgt_seq_len, candidate_size); otherwise (batch_size, 1)
    """
    # encoder_output shape: batch_size, src_seq_len, dim_model
    encoder_output = self.encode(state, src_seq, src_src_mask)

    tgt_seq_len = tgt_in_seq.shape[1]
    src_seq_len = src_seq.shape[1]
    assert tgt_seq_len <= src_seq_len

    # Bug fix: decode(memory, state, tgt_in_idx, tgt_in_seq) does not
    # accept mask keyword arguments — it computes tgt_tgt_mask and
    # tgt_src_mask internally via pytorch_decoder_mask. Passing masks
    # here raised TypeError, so the redundant mask computation and the
    # mask kwargs are removed.
    # decoder_probs shape: batch_size, tgt_seq_len, candidate_size
    decoder_probs = self.decode(
        memory=encoder_output,
        state=state,
        tgt_in_idx=tgt_in_idx,
        tgt_in_seq=tgt_in_seq,
    )

    if mode == Seq2SlateMode.PER_SYMBOL_LOG_PROB_DIST_MODE:
        # Clamp avoids log(0); shape: batch_size, tgt_seq_len, candidate_size
        return torch.log(torch.clamp(decoder_probs, min=EPSILON))

    # shape: batch_size, 1
    return torch.log(per_symbol_to_per_seq_probs(decoder_probs, tgt_out_idx))
def _rank(self, state, src_seq, src_src_mask, tgt_seq_len, greedy):
    """Decode (rank) sequences based on given inputs.

    :param state: user/context state, passed through to encode/decode
    :param src_seq: source candidate features,
        shape (batch_size, src_seq_len, dim_candidate)
    :param src_src_mask: encoder self-attention mask
    :param tgt_seq_len: length of the slate to produce
    :param greedy: whether the generator picks argmax (True) or samples;
        must not be None for autoregressive decoding
    :return: (ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx)
        with shapes (batch_size, tgt_seq_len, candidate_size),
        (batch_size, 1), and (batch_size, tgt_seq_len)
    """
    device = src_seq.device
    batch_size, src_seq_len, candidate_dim = src_seq.shape
    candidate_size = src_seq_len + 2

    # candidate_features is used as a look-up table for candidate
    # features. The second dim is src_seq_len + 2 because we also want
    # to include features of the start symbol and padding symbol.
    candidate_features = torch.zeros(
        batch_size, src_seq_len + 2, candidate_dim, device=device
    )
    # TODO: T62502977 create learnable feature vectors for start symbol
    # and padding symbol
    candidate_features[:, 2:, :] = src_seq

    # memory shape: batch_size, src_seq_len, dim_model
    memory = self.encode(state, src_seq, src_src_mask)

    ranked_per_symbol_probs = torch.zeros(
        batch_size, tgt_seq_len, candidate_size, device=device
    )
    # Fix: allocate on the same device as the other outputs so the
    # ENCODER_SCORE path cannot return a CPU tensor alongside device tensors.
    ranked_per_seq_probs = torch.zeros(batch_size, 1, device=device)

    if self.output_arch == Seq2SlateOutputArch.ENCODER_SCORE:
        # encoder_scores shape: batch_size, src_seq_len
        encoder_scores = self.encoder_scorer(memory).squeeze(dim=2)
        tgt_out_idx = torch.argsort(encoder_scores, dim=1, descending=True)[
            :, :tgt_seq_len
        ]
        # +2 to account for start symbol and padding symbol
        tgt_out_idx += 2
        # Every position has propensity 1 because we are just using argsort.
        ranked_per_symbol_probs = ranked_per_symbol_probs.scatter(
            2, tgt_out_idx.unsqueeze(2), 1.0
        )
        ranked_per_seq_probs[:, :] = 1.0
        return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx

    # Autoregressive decoding: start every sequence with the start symbol.
    tgt_in_idx = (
        torch.ones(batch_size, 1, device=device)
        .fill_(self._DECODER_START_SYMBOL)
        .type(torch.long)
    )
    assert greedy is not None
    for step in range(tgt_seq_len):
        tgt_in_seq = gather(candidate_features, tgt_in_idx)
        # Bug fix: decode(memory, state, tgt_in_idx, tgt_in_seq) does not
        # accept mask keyword arguments — it builds tgt_tgt_mask and
        # tgt_src_mask internally from memory and tgt_in_idx. The previous
        # pytorch_decoder_mask call here was redundant and the mask kwargs
        # raised TypeError.
        # probs shape: batch_size, step + 1, candidate_size
        probs = self.decode(
            memory=memory,
            state=state,
            tgt_in_idx=tgt_in_idx,
            tgt_in_seq=tgt_in_seq,
        )
        # next_candidate shape: batch_size, 1
        # next_candidate_sample_prob shape: batch_size, candidate_size
        next_candidate, next_candidate_sample_prob = self.generator(
            probs, greedy
        )
        ranked_per_symbol_probs[:, step, :] = next_candidate_sample_prob
        tgt_in_idx = torch.cat([tgt_in_idx, next_candidate], dim=1)

    # Remove the decoder start symbol.
    # tgt_out_idx shape: batch_size, tgt_seq_len
    tgt_out_idx = tgt_in_idx[:, 1:]

    ranked_per_seq_probs = per_symbol_to_per_seq_probs(
        ranked_per_symbol_probs, tgt_out_idx
    )

    # ranked_per_symbol_probs shape: batch_size, tgt_seq_len, candidate_size
    # ranked_per_seq_probs shape: batch_size, 1
    # tgt_out_idx shape: batch_size, tgt_seq_len
    return ranked_per_symbol_probs, ranked_per_seq_probs, tgt_out_idx