class TransformerModel(nn.Module):

    def __init__(self, ninp, ntoken, ntoken_dec, nhid=2048, dropout=0):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder_emb = nn.Embedding(ntoken_dec, ninp)
        self.decoder_out = nn.Linear(ninp, ntoken_dec)
        self.model = Transformer(d_model=ninp, dim_feedforward=nhid)

    def forward(self, src, tgt, src_mask, tgt_mask):
        # Embed, scale by sqrt(d_model), and add positional encodings
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        tgt = self.decoder_emb(tgt) * math.sqrt(self.ninp)
        tgt = self.pos_encoder(tgt)
        # Convert 1/0 attention masks into boolean key-padding masks
        # (True marks padding positions that should be ignored)
        src_mask = src_mask != 1
        tgt_mask = tgt_mask != 1
        subseq_mask = self.model.generate_square_subsequent_mask(
            tgt.size(1)).to(tgt.device)
        # nn.Transformer expects (seq_len, batch, ninp)
        output = self.model(src.transpose(0, 1),
                            tgt.transpose(0, 1),
                            tgt_mask=subseq_mask,
                            src_key_padding_mask=src_mask,
                            tgt_key_padding_mask=tgt_mask,
                            memory_key_padding_mask=src_mask)
        output = self.decoder_out(output)
        return output

    def greedy_decode(self, src, src_mask, sos_token, max_length=20):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        src_mask = src_mask != 1
        encoded = self.model.encoder(src.transpose(0, 1),
                                     src_key_padding_mask=src_mask)
        # Start every sequence in the batch with the sos token: (batch, 1)
        generated = encoded.new_full((encoded.size(1), 1),
                                     sos_token,
                                     dtype=torch.long)
        for i in range(max_length - 1):
            subseq_mask = self.model.generate_square_subsequent_mask(
                generated.size(1)).to(src.device)
            decoder_in = self.decoder_emb(generated) * math.sqrt(self.ninp)
            decoder_in = self.pos_encoder(decoder_in)
            # Project only the last decoder position to logits
            logits = self.decoder_out(
                self.model.decoder(decoder_in.transpose(0, 1),
                                   encoded,
                                   tgt_mask=subseq_mask,
                                   memory_key_padding_mask=src_mask)[-1, :, :])
            new_generated = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, new_generated], dim=-1)
        return generated

    def save(self, file_dir):
        torch.save(self.state_dict(), file_dir)

    def load(self, file_dir):
        self.load_state_dict(torch.load(file_dir))
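# A minimal smoke-test sketch for the TransformerModel above. It assumes
# batch-first (batch, seq_len) index tensors and 1/0 attention masks
# (1 = real token, 0 = padding), which forward() converts internally;
# all sizes and the sos token id below are illustrative.
_model = TransformerModel(ninp=256, ntoken=1000, ntoken_dec=1200)
_src = torch.randint(0, 1000, (2, 10))   # (batch, src_len)
_tgt = torch.randint(0, 1200, (2, 8))    # (batch, tgt_len)
_src_mask = torch.ones(2, 10)            # 1 = attend, 0 = padding
_tgt_mask = torch.ones(2, 8)
_logits = _model(_src, _tgt, _src_mask, _tgt_mask)   # (tgt_len, batch, ntoken_dec)
_generated = _model.greedy_decode(_src, _src_mask, sos_token=1)  # (batch, <= max_length)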
class Seq2SeqTransformer(nn.Module):

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size,
                                                      dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask,
                                memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
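# Usage sketch for Seq2SeqTransformer, which follows nn.Transformer's
# seq-first (seq_len, batch) convention. The mask helper below is an
# assumption modelled on the PyTorch translation tutorial, not part of the
# class itself; pad_idx and all sizes are illustrative.
def _create_masks(src, tgt, pad_idx=1):
    tgt_len = tgt.size(0)
    # Additive causal mask: -inf strictly above the diagonal
    tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')),
                          diagonal=1)
    src_mask = torch.zeros(src.size(0), src.size(0))
    src_padding_mask = (src == pad_idx).transpose(0, 1)  # (batch, src_len)
    tgt_padding_mask = (tgt == pad_idx).transpose(0, 1)  # (batch, tgt_len)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

_s2s = Seq2SeqTransformer(3, 3, emb_size=256, nhead=8,
                          src_vocab_size=10000, tgt_vocab_size=10000)
_src = torch.randint(2, 10000, (12, 4))   # (src_len, batch)
_tgt = torch.randint(2, 10000, (11, 4))   # (tgt_len, batch)
_sm, _tm, _spm, _tpm = _create_masks(_src, _tgt)
_logits = _s2s(_src, _tgt, _sm, _tm, _spm, _tpm, _spm)  # (tgt_len, batch, tgt_vocab)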
class SimpleTransformerModel(nn.Module):

    def __init__(self, nb_tokens: int, emb_size: int, nb_layers=2,
                 nb_heads=4, hid_size=512, dropout=0.25, max_len=30):
        super(SimpleTransformerModel, self).__init__()
        from torch.nn import Transformer
        self.emb_size = emb_size
        self.max_len = max_len
        self.pos_encoder = PositionalEncoding(emb_size,
                                              dropout=dropout,
                                              max_len=max_len)
        self.embedder = nn.Embedding(nb_tokens, emb_size)
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nb_heads,
                                       num_encoder_layers=nb_layers,
                                       num_decoder_layers=nb_layers,
                                       dim_feedforward=hid_size,
                                       dropout=dropout)
        self.out_lin = nn.Linear(in_features=emb_size, out_features=nb_tokens)
        self.tgt_mask = None

    def _generate_square_subsequent_mask(self, sz):
        # Relies on a module-level `device` variable
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).to(device)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.embedder.weight.data.uniform_(-initrange, initrange)
        self.out_lin.bias.data.zero_()
        self.out_lin.weight.data.uniform_(-initrange, initrange)

    def enc_forward(self, src):
        # Embed source
        src = self.embedder(src) * math.sqrt(self.emb_size)
        # Add positional encoding + reshape into format
        # (seq element, batch element, embedding)
        src = self.pos_encoder(src.view(src.shape[0], 1, src.shape[1]))
        # Push through encoder
        output = self.transformer.encoder(src)
        return output

    def dec_forward(self, memory, tgt):
        # Generate target mask, if necessary
        if self.tgt_mask is None or self.tgt_mask.size(0) != len(tgt):
            mask = self._generate_square_subsequent_mask(len(tgt)).to(device)
            self.tgt_mask = mask
        # Embed target
        tgt = self.embedder(tgt) * math.sqrt(self.emb_size)
        # Add positional encoding + reshape into format
        # (seq element, batch element, embedding)
        tgt = self.pos_encoder(tgt.view(tgt.shape[0], 1, tgt.shape[1]))
        # Push through decoder + linear output layer
        output = self.out_lin(
            self.transformer.decoder(memory=memory,
                                     tgt=tgt,
                                     tgt_mask=self.tgt_mask))
        # If using the model to evaluate, also take softmax of the final
        # layer to obtain probabilities
        if not self.training:
            output = torch.nn.functional.softmax(output, 2)
        return output

    def forward(self, src, tgt):
        memory = self.enc_forward(src)
        output = self.dec_forward(memory, tgt)
        return output

    def greedy_decode(self, src, max_len=None, start_symbol=0,
                      stop_symbol=None):
        """
        Greedy decode input "src": generate output character one at a time,
        until "stop_symbol" is generated or the output exceeds max_len,
        whichever comes first.

        :param src: input src, 1D tensor
        :param max_len: int
        :param start_symbol: int
        :param stop_symbol: int
        :return: decoded output sequence
        """
        b_training = self.training
        if b_training:
            self.eval()
        if max_len is None:
            max_len = self.max_len
        elif max_len > self.max_len:
            raise ValueError(f"Parameter 'max_len' cannot exceed the model's "
                             f"own max_len, which is set at {self.max_len}.")
        # Get memory = output from encoder layer
        memory = self.enc_forward(src)
        # Initiate output buffer
        idxs = [start_symbol]
        # Keep track of last predicted symbol
        next_char = start_symbol
        # Keep generating output until "stop_symbol" is generated,
        # or max_len is reached
        while next_char != stop_symbol:
            if len(idxs) == max_len:
                break
            # Convert output buffer to tensor
            ys = torch.LongTensor(idxs).to(device)
            # Push through decoder
            out = self.dec_forward(memory=memory, tgt=ys)
            # Get position of max probability of newly predicted character
            _, next_char = torch.max(out[-1, :, :], dim=1)
            next_char = next_char.item()
            # Append generated character to output buffer
            idxs.append(next_char)
        if b_training:
            self.train()
        return idxs
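# Usage sketch for SimpleTransformerModel. enc_forward/dec_forward hard-code
# a batch size of 1 via .view(), so src and tgt are 1-D index tensors; the
# class also assumes a module-level `device`, defined here for the sketch.
# All token ids and sizes are illustrative.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_stm = SimpleTransformerModel(nb_tokens=100, emb_size=64).to(device)
_src = torch.randint(0, 100, (15,)).to(device)   # 1-D source sequence
_tgt = torch.randint(0, 100, (14,)).to(device)   # 1-D target sequence
_out = _stm(_src, _tgt)                          # (tgt_len, 1, nb_tokens)
_decoded = _stm.greedy_decode(_src, start_symbol=0, stop_symbol=1)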
class TransformerModel(nn.Module):

    def __init__(self, vocab_size, hidden_size, num_attention_heads,
                 num_encoder_layers, num_decoder_layers, intermediate_size,
                 dropout=0.1):
        super(TransformerModel, self).__init__()
        # self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.token_embeddings = nn.Embedding(vocab_size,
                                             hidden_size,
                                             padding_idx=1)
        self.position_embeddings = PositionalEncoding(hidden_size)
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(p=dropout)
        self.transformer = Transformer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=intermediate_size,
            dropout=dropout,
        )
        # Tie the output projection to the input embeddings
        self.decoder_embeddings = nn.Linear(hidden_size, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
            mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self,
                src=None,
                tgt=None,
                memory=None,
                src_key_padding_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Three modes: encode only (src given), decode only (tgt and memory
        given), or a full encoder-decoder pass (src and tgt given)."""
        if src is not None:
            src_embeddings = self.token_embeddings(src) * math.sqrt(
                self.hidden_size) + self.position_embeddings(src)
            src_embeddings = self.dropout(src_embeddings)
            if src_key_padding_mask is not None:
                src_key_padding_mask = src_key_padding_mask.t()
            if tgt is None:  # encode
                memory = self.transformer.encoder(
                    src_embeddings, src_key_padding_mask=src_key_padding_mask)
                return memory
        if tgt is not None:
            tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
                self.hidden_size) + self.position_embeddings(tgt)
            tgt_embeddings = self.dropout(tgt_embeddings)
            tgt_mask = self.transformer.generate_square_subsequent_mask(
                tgt.size(0)).to(tgt.device)
            if tgt_key_padding_mask is not None:
                tgt_key_padding_mask = tgt_key_padding_mask.t()
            if src is None and memory is not None:  # decode
                if memory_key_padding_mask is not None:
                    memory_key_padding_mask = memory_key_padding_mask.t()
                output = self.transformer.decoder(
                    tgt_embeddings,
                    memory,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask)
                output = self.decoder_embeddings(output)
                return output
        assert not (src is None and tgt is None)
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
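# Sketch of the three call modes supported by the forward above (seq-first
# (seq_len, batch) index tensors; all sizes are illustrative).
_tm5 = TransformerModel(vocab_size=1000, hidden_size=256,
                        num_attention_heads=8, num_encoder_layers=4,
                        num_decoder_layers=4, intermediate_size=1024)
_src = torch.randint(2, 1000, (10, 4))
_tgt = torch.randint(2, 1000, (9, 4))
_logits = _tm5(src=_src, tgt=_tgt)       # full encoder-decoder pass
_memory = _tm5(src=_src)                 # encode only
_step = _tm5(tgt=_tgt, memory=_memory)   # decode only (e.g. at inference)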
class TransformerModel(nn.Module):

    def __init__(self, vocab_size, d_model, num_attention_heads,
                 num_encoder_layers, num_decoder_layers, intermediate_size,
                 max_len, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = PositionalEncoding(d_model, max_len)
        self.hidden_size = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.transformer = Transformer(d_model=d_model,
                                       nhead=num_attention_heads,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=intermediate_size,
                                       dropout=dropout)
        # Tie the output projection to the input embeddings
        self.decoder_embeddings = nn.Linear(d_model, vocab_size)
        self.decoder_embeddings.weight = self.token_embeddings.weight
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.token_embeddings.weight.data.uniform_(-initrange, initrange)
        self.decoder_embeddings.bias.data.zero_()
        self.decoder_embeddings.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, tgt, src_key_padding_mask=None,
                tgt_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)
        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)
        output = self.transformer(src_embeddings,
                                  tgt_embeddings,
                                  tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_key_padding_mask,
                                  tgt_key_padding_mask=tgt_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output

    def encode(self, src, src_key_padding_mask=None):
        src_embeddings = self.token_embeddings(src) * math.sqrt(
            self.hidden_size) + self.position_embeddings(src)
        src_embeddings = self.dropout(src_embeddings)
        memory = self.transformer.encoder(
            src_embeddings, src_key_padding_mask=src_key_padding_mask)
        return memory

    def decode(self, tgt, memory, tgt_key_padding_mask=None,
               memory_key_padding_mask=None):
        tgt_embeddings = self.token_embeddings(tgt) * math.sqrt(
            self.hidden_size) + self.position_embeddings(tgt)
        tgt_embeddings = self.dropout(tgt_embeddings)
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt.size(0)).to(tgt.device)
        output = self.transformer.decoder(
            tgt_embeddings,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        output = self.decoder_embeddings(output)
        return output
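# A minimal greedy-decoding sketch built on the encode()/decode() pair above
# (seq-first tensors; the bos/eos token ids are assumptions).
def _greedy_generate(model, src, bos_id=0, eos_id=2, max_len=50):
    memory = model.encode(src)
    ys = torch.full((1, src.size(1)), bos_id,
                    dtype=torch.long, device=src.device)  # (1, batch)
    for _ in range(max_len - 1):
        logits = model.decode(ys, memory)     # (cur_len, batch, vocab_size)
        next_tok = logits[-1].argmax(dim=-1)  # greedy pick, (batch,)
        ys = torch.cat([ys, next_tok.unsqueeze(0)], dim=0)
        if (next_tok == eos_id).all():
            break
    return ys  # (generated_len, batch)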
class MyTransformer(Model):

    def __init__(
        self,
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        super().__init__(vocab)
        self._target_namespace = target_namespace
        self._start_index = self.vocab.get_token_index(
            START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(
            self.vocab._padding_token, self._target_namespace)
        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None
        self._seq_acc = SequenceAccuracy()
        self._max_decoding_steps = max_decoding_steps
        self._source_embedder = source_embedder
        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,
                                              transformer["dropout"])
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._transformer = Transformer(**transformer)
        self._transformer.apply(inplace_relu)
        if target_embedder is None:
            self._target_embedder = self._source_embedder
        else:
            self._target_embedder = target_embedder
        self._output_projection_layer = Linear(self._ndim, num_classes)

    def _get_mask(self, meta_data):
        mask = torch.zeros(
            1, len(meta_data),
            self.vocab.get_vocab_size(self._target_namespace)).float()
        for bidx, md in enumerate(meta_data):
            for k, v in self.vocab._token_to_index[
                    self._target_namespace].items():
                if 'position' in k and k not in md['avail_pos']:
                    mask[:, bidx, v] = float('-inf')
        return mask

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == False,
                                        float('-inf')).masked_fill(
                                            mask == True, float(0.0))
        return mask

    @overrides
    def forward(
        self,
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
        meta_data: Any = None,
    ) -> Dict[str, torch.Tensor]:
        src, src_key_padding_mask = self._encode(self._source_embedder,
                                                 source_tokens)
        memory = self._transformer.encoder(
            src, src_key_padding_mask=src_key_padding_mask)
        if meta_data is not None:
            target_vocab_mask = self._get_mask(meta_data)
            target_vocab_mask = target_vocab_mask.to(memory.device)
        else:
            target_vocab_mask = None
        output_dict = {}
        targets = None
        if target_tokens:
            targets = target_tokens["tokens"][:, 1:]
            target_mask = (util.get_text_field_mask({"tokens": targets}) == 1)
            assert targets.size(1) <= self._max_decoding_steps
        if self.training and target_tokens:
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": target_tokens["tokens"][:, :-1]})
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)
            logits = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                logits += target_vocab_mask
            class_probabilities = F.softmax(logits.detach(), dim=-1)
            _, predictions = torch.max(class_probabilities, -1)
            logits = logits.transpose(0, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
        else:
            assert self.training is False
            output_dict["loss"] = torch.tensor(0.0).to(memory.device)
            if targets is not None:
                max_target_len = targets.size(1)
            else:
                max_target_len = None
            predictions, class_probabilities = self._decoder_step_by_step(
                memory,
                src_key_padding_mask,
                target_vocab_mask,
                max_target_len=max_target_len)
            predictions = predictions.transpose(0, 1)
        output_dict["predictions"] = predictions
        output_dict["class_probabilities"] = class_probabilities.transpose(
            0, 1)
        if target_tokens:
            with torch.no_grad():
                best_predictions = output_dict["predictions"]
                if self._bleu:
                    self._bleu(best_predictions, targets)
                batch_size = targets.size(0)
                max_sz = max(best_predictions.size(1), targets.size(1),
                             target_mask.size(1))
                best_predictions_ = torch.zeros(batch_size,
                                                max_sz).to(memory.device)
                best_predictions_[:, :best_predictions.size(1)] = \
                    best_predictions
                targets_ = torch.zeros(batch_size, max_sz).to(memory.device)
                targets_[:, :targets.size(1)] = targets.cpu()
                target_mask_ = torch.zeros(batch_size,
                                           max_sz).to(memory.device)
                target_mask_[:, :target_mask.size(1)] = target_mask
                self._seq_acc(best_predictions_.unsqueeze(1), targets_,
                              target_mask_)
        return output_dict

    @overrides
    def decode(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            # shape: (batch_size, num_decoding_steps)
            predicted_indices = predicted_indices.detach().cpu().numpy()
        # class_probabilities = output_dict["class_probabilities"].detach().cpu()
        # sample_predicted_indices = []
        # for cp in class_probabilities:
        #     sample = torch.multinomial(cp, num_samples=1)
        #     sample_predicted_indices.append(sample)
        # # shape: (batch_size, num_decoding_steps, num_samples)
        # sample_predicted_indices = torch.stack(sample_predicted_indices)
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence
            # in the batch, but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
        self, embedder: TextFieldEmbedder, tokens: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        src = embedder(tokens) * math.sqrt(self._ndim)
        src = src.transpose(0, 1)
        src = self.pos_encoder(src)
        mask = util.get_text_field_mask(tokens)
        mask = (mask == 0)
        return src, mask

    def _decoder_step_by_step(
            self,
            memory: torch.Tensor,
            memory_key_padding_mask: torch.Tensor,
            target_vocab_mask: torch.Tensor = None,
            max_target_len: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = memory.size(1)
        if getattr(self, "target_limit_decode_steps",
                   False) and max_target_len is not None:
            num_decoding_steps = min(self._max_decoding_steps, max_target_len)
            print('decoding steps: ', num_decoding_steps)
        else:
            num_decoding_steps = self._max_decoding_steps
        last_predictions = memory.new_full(
            (batch_size, ), fill_value=self._start_index).long()
        step_predictions: List[torch.Tensor] = []
        all_predicts = memory.new_full((batch_size, num_decoding_steps),
                                       fill_value=0).long()
        for timestep in range(num_decoding_steps):
            all_predicts[:, timestep] = last_predictions
            tgt, tgt_key_padding_mask = self._encode(
                self._target_embedder,
                {"tokens": all_predicts[:, :timestep + 1]})
            tgt_mask = self.generate_square_subsequent_mask(timestep + 1).to(
                memory.device)
            output = self._transformer.decoder(
                tgt,
                memory,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
            output_projections = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                output_projections += target_vocab_mask
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, -1)
            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes[timestep, :]
            step_predictions.append(last_predictions)
            if ((last_predictions == self._end_index) +
                    (last_predictions == self._pad_index)).all():
                break
        # shape: (num_decoding_steps, batch_size)
        predictions = torch.stack(step_predictions)
        return predictions, class_probabilities

    @staticmethod
    def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.FloatTensor) -> torch.Tensor:
        logits = logits.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask.contiguous()
        return util.sequence_cross_entropy_with_logits(
            logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        all_metrics['seq_acc'] = self._seq_acc.get_metric(reset=reset)
        return all_metrics

    def load_state_dict(self, state_dict, strict=True):
        # Strip a possible 'module.' prefix left by DataParallel checkpoints
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[len('module.'):]] = v
            else:
                new_state_dict[k] = v
        super(MyTransformer, self).load_state_dict(new_state_dict, strict)
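# Inference-time usage sketch for MyTransformer: outside training, forward()
# runs greedy step-by-step decoding, and decode() maps predicted indices back
# to tokens. `my_model` and `instance_batch` are hypothetical objects (a
# constructed model and a batch from an AllenNLP data loader), so the sketch
# is left commented out.
# my_model.eval()
# with torch.no_grad():
#     output_dict = my_model(source_tokens=instance_batch["source_tokens"])
#     output_dict = my_model.decode(output_dict)
#     print(output_dict["predicted_tokens"])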
class FullTransformer(Module):

    def __init__(self,
                 num_vocab,
                 num_embedding=128,
                 dim_feedforward=512,
                 num_encoder_layer=4,
                 num_decoder_layer=4,
                 dropout=0.3,
                 padding_idx=1,
                 max_seq_len=140,
                 nhead=8):
        super(FullTransformer, self).__init__()
        self.padding_idx = padding_idx
        # [ x : seq_len, batch_size ]
        self.inp_embedding = Embedding(num_vocab,
                                       num_embedding,
                                       padding_idx=padding_idx)
        # [ x : seq_len, batch_size, num_embedding ]
        self.pos_embedding = PositionalEncoding(num_embedding,
                                                dropout,
                                                max_len=max_seq_len)
        self.trfm = Transformer(d_model=num_embedding,
                                dim_feedforward=dim_feedforward,
                                num_encoder_layers=num_encoder_layer,
                                num_decoder_layers=num_decoder_layer,
                                dropout=dropout,
                                nhead=nhead)
        self.linear_out = torch.nn.Linear(num_embedding, num_vocab)

    def make_pad_mask(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Build a boolean key-padding mask: 'True' elements (positions equal
        to self.padding_idx) will not be attended to at all.

        :param inp: tensor of token indices, shape (seq_len, batch_size)
        """
        return (inp == self.padding_idx).transpose(0, 1)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """
        Full encoder-decoder pass.

        :param src: source tensor
        :param tgt: target tensor
        """
        # Generate causal mask for decoder self-attention
        # tgt_mask shape = [target_seq_len, target_seq_len]
        tgt_mask = self.trfm.generate_square_subsequent_mask(len(tgt)).to(
            tgt.device)
        src_pad_mask = self.make_pad_mask(src)
        tgt_pad_mask = self.make_pad_mask(tgt)
        # [ src : seq_len, batch_size, num_embedding ]
        out_emb_enc = self.pos_embedding(self.inp_embedding(src))
        # [ tgt : seq_len, batch_size, num_embedding ]
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))
        out_trf = self.trfm(out_emb_enc,
                            out_emb_dec,
                            src_mask=None,
                            tgt_mask=tgt_mask,
                            memory_mask=None,
                            src_key_padding_mask=src_pad_mask,
                            tgt_key_padding_mask=tgt_pad_mask,
                            memory_key_padding_mask=src_pad_mask)
        # [ out_trf : seq_len, batch_size, num_embedding ]
        out_to_logit = self.linear_out(out_trf)
        # final_out : [ seq_len, batch_size, vocab_size ]
        return out_to_logit

    def forward_encoder(
            self, src: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        src_pad_mask = self.make_pad_mask(src)
        out_emb_enc = self.pos_embedding(self.inp_embedding(src))
        return self.trfm.encoder(
            out_emb_enc, src_key_padding_mask=src_pad_mask), src_pad_mask

    def forward_decoder(self, tgt: torch.Tensor, memory: torch.Tensor,
                        src_pad_mask: torch.Tensor) -> torch.Tensor:
        tgt_pad_mask = self.make_pad_mask(tgt)
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))
        tgt_mask = self.trfm.generate_square_subsequent_mask(len(tgt)).to(
            tgt.device)
        out_trf = self.trfm.decoder(out_emb_dec,
                                    memory,
                                    tgt_mask=tgt_mask,
                                    tgt_key_padding_mask=tgt_pad_mask,
                                    memory_key_padding_mask=src_pad_mask)
        out_trf = self.linear_out(out_trf)
        return out_trf
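# Step-by-step inference sketch for FullTransformer via forward_encoder /
# forward_decoder (seq-first (seq_len, batch) tensors; the bos id, vocab
# size, and lengths are illustrative assumptions).
_ft = FullTransformer(num_vocab=1000)
_src = torch.randint(2, 1000, (20, 8))      # (src_len, batch), avoids pad id 1
_memory, _src_pad_mask = _ft.forward_encoder(_src)
_ys = torch.zeros(1, 8, dtype=torch.long)   # assumed bos id = 0
for _ in range(30):
    _logits = _ft.forward_decoder(_ys, _memory, _src_pad_mask)
    _next = _logits[-1].argmax(dim=-1)      # greedy pick, (batch,)
    _ys = torch.cat([_ys, _next.unsqueeze(0)], dim=0)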