def create_src_masks(src, SRC_SEQ_LEN, TEXT, use_srcmask=False):
    """Build the encoder-side masks for a Transformer forward pass.

    Args:
        src: Source token-index tensor. It is compared element-wise against
            the padding index, so both padding masks take the shape of ``src``.
        SRC_SEQ_LEN: Size passed to
            ``Transformer.generate_square_subsequent_mask`` when
            ``use_srcmask`` is true.
        TEXT: Field-like object exposing ``vocab.stoi``. Its ``'<pad>'`` entry
            is used as the padding index (presumably a torchtext legacy
            ``Field`` -- confirm against the caller).
        use_srcmask: If true, build a square subsequent mask for the source.
            Otherwise ``src_mask`` is ``None``.

    Returns:
        Tuple ``(src_mask, src_key_padding_mask, memory_key_padding_mask)``.
        ``src_mask`` is ``None`` unless ``use_srcmask`` is true. Both padding
        masks are boolean tensors on the module-level ``device``, ``True``
        wherever ``src`` holds the padding index. They are equal in value but
        are separate tensor objects.
    """
    if use_srcmask:
        src_mask = Transformer.generate_square_subsequent_mask(
            SRC_SEQ_LEN).to(device)
    else:
        src_mask = None

    # The encoder key-padding mask and the memory key-padding mask both mark
    # the same source padding positions, so build the mask only once.
    # NOTE(review): nn.Transformer documents key padding masks as
    # (batch, seq). This mask takes src's shape, so src is presumably
    # batch-first -- confirm against the caller.
    pad_idx = TEXT.vocab.stoi['<pad>']
    src_key_padding_mask = (src == pad_idx).bool().to(device)
    # Clone so the two returned masks stay independent tensors, as they were
    # when each was computed separately.
    memory_key_padding_mask = src_key_padding_mask.clone()

    return src_mask, src_key_padding_mask, memory_key_padding_mask
def create_tgt_masks(tgt, TGT_SEQ_LEN, LABEL):
    """Return the decoder-side ``(tgt_mask, tgt_key_padding_mask)`` pair.

    ``tgt_mask`` is ``Transformer.generate_square_subsequent_mask`` of size
    ``TGT_SEQ_LEN``. ``tgt_key_padding_mask`` is a boolean tensor that is
    ``True`` wherever ``tgt`` equals the ``'<pad>'`` index in
    ``LABEL.vocab.stoi``. Both tensors are moved to the module-level
    ``device``.

    NOTE(review): the padding mask takes ``tgt``'s shape, and nn.Transformer
    documents key padding masks as (batch, seq). ``tgt`` is presumably
    batch-first -- confirm against the caller.
    """
    causal_mask = Transformer.generate_square_subsequent_mask(TGT_SEQ_LEN)
    causal_mask = causal_mask.to(device)

    pad_index = LABEL.vocab.stoi['<pad>']
    padding_mask = (tgt == pad_index).bool().to(device)

    return causal_mask, padding_mask