import math
import os
import random
import tempfile
from typing import Dict

import torch
from torch.nn.utils.rnn import pad_sequence

# TextProcessor and Seq2Seq are assumed to be importable from the surrounding
# project; their import lines are omitted because the exact module paths are
# not shown in this excerpt.


def test_albert_seq2seq_init(self):
    path_dir_name = os.path.dirname(os.path.realpath(__file__))
    data_path = os.path.join(path_dir_name, "sample.txt")

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Train a small shared tokenizer on the sample file with two language tokens.
        processor = TextProcessor()
        processor.train_tokenizer([data_path], vocab_size=1000, to_save_dir=tmpdirname,
                                  languages={"<en>": 0, "<fa>": 1})
        seq2seq = Seq2Seq(text_processor=processor)

        # Two toy sentence pairs, right-padded to a common length.
        src_inputs = torch.tensor([[1, 2, 3, 4, 5, processor.pad_token_id(), processor.pad_token_id()],
                                   [1, 2, 3, 4, 5, 6, processor.pad_token_id()]])
        tgt_inputs = torch.tensor([[6, 8, 7, processor.pad_token_id(), processor.pad_token_id()],
                                   [6, 8, 7, 8, processor.pad_token_id()]])
        src_mask = src_inputs != processor.pad_token_id()
        tgt_mask = tgt_inputs != processor.pad_token_id()
        src_langs = torch.tensor([[0], [0]]).squeeze()
        tgt_langs = torch.tensor([[1], [1]]).squeeze()

        # The flattened output should have 5 rows (one per predicted target
        # token) and vocab_size columns, with and without log_softmax.
        seq_output = seq2seq(src_inputs, tgt_inputs, src_mask, tgt_mask, src_langs, tgt_langs,
                             log_softmax=True)
        assert list(seq_output.size()) == [5, processor.vocab_size()]

        seq_output = seq2seq(src_inputs, tgt_inputs, src_mask, tgt_mask, src_langs, tgt_langs)
        assert list(seq_output.size()) == [5, processor.vocab_size()]
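
# Hedged note on the assertions above (not part of the original test): the
# first output dimension appears to count the predicted target tokens, i.e.
# the non-pad positions of tgt_inputs after the first token. For the two toy
# target sentences that is 2 + 3 = 5 rows, each a distribution over the
# vocabulary. The helper below is a hypothetical sketch of that arithmetic,
# not project code.
def _expected_output_rows(tgt_inputs: torch.Tensor, pad_id: int) -> int:
    tgt_mask = tgt_inputs != pad_id
    return int(tgt_mask[:, 1:].sum().item())  # 2 + 3 == 5 for the test batch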
def mass_mask(mask_prob, pad_indices, src_text, text_processor: TextProcessor) -> Dict:
    """
    20% of the time, mask from the start to the middle;
    20% of the time, mask from the middle to the end;
    60% of the time, mask a span starting at a random index.
    """
    assert 0 < mask_prob < 1
    index_range = pad_indices - (1 - mask_prob) * pad_indices
    src_mask = torch.zeros(src_text.size(), dtype=torch.bool)
    to_recover = []
    to_recover_pos = []
    for i, irange in enumerate(index_range):
        range_size = int(pad_indices[i] / 2)
        r = random.random()
        last_idx = int(math.ceil(irange))
        if r > 0.8:  # Mask from the start of the sentence.
            start = 1
        elif r > 0.6:  # Mask from the middle to the end.
            start = last_idx
        else:  # Mask a span starting at a random position.
            start = random.randint(2, last_idx) if last_idx >= 2 else 2
        end = start + range_size
        src_mask[i, start:end] = True
        # Keep the original span (shifted one position left) so the decoder can reconstruct it.
        to_recover.append(src_text[i, start - 1:end])
        to_recover_pos.append(torch.arange(start - 1, end))
    to_recover = pad_sequence(to_recover, batch_first=True, padding_value=text_processor.pad_token_id())
    to_recover_pos = pad_sequence(to_recover_pos, batch_first=True, padding_value=int(src_text.size(-1)) - 1)

    masked_ids = src_text[:, 1:][src_mask[:, 1:]]
    mask_idx = src_text[src_mask]
    # BERT-style replacement for the masked positions: 80% mask token,
    # 10% random non-special token, 10% keep the original token.
    random_index = lambda: random.randint(len(text_processor.special_tokens), text_processor.vocab_size() - 1)
    rand_select = lambda r, c: text_processor.mask_token_id() if r < 0.8 else (
        random_index() if r < 0.9 else int(mask_idx[c]))
    replacements = list(map(lambda i: rand_select(random.random(), i), range(mask_idx.size(0))))
    src_text[src_mask] = torch.LongTensor(replacements)
    return {"src_mask": src_mask, "targets": masked_ids, "src_text": src_text,
            "to_recover": to_recover, "positions": to_recover_pos, "mask_idx": mask_idx}
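
# Hedged usage sketch for mass_mask (assumed calling convention, not from the
# original source): `pad_indices` is taken to hold the number of non-pad
# tokens per sentence, and `trained_processor` stands in for a TextProcessor
# trained as in the test above. The returned dict carries the corrupted
# `src_text`, the boolean `src_mask`, and the original span (`to_recover`,
# `positions`) that a MASS-style decoder would be trained to reconstruct.
def _mass_mask_example(trained_processor: TextProcessor):
    pad_id = trained_processor.pad_token_id()
    src_text = torch.tensor([[30, 31, 32, 33, 34, 35, pad_id],
                             [30, 31, 32, 33, pad_id, pad_id, pad_id]])
    pad_indices = (src_text != pad_id).sum(dim=1)  # non-pad lengths: [6, 4]
    # Clone because mass_mask overwrites the masked positions in place.
    batch = mass_mask(mask_prob=0.5, pad_indices=pad_indices,
                      src_text=src_text.clone(), text_processor=trained_processor)
    return batch["src_text"], batch["targets"]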
def mask_text(mask_prob, pads, texts, text_processor: TextProcessor, mask_eos: bool = True):
    assert 0 < mask_prob < 1
    # Sample roughly mask_prob of the positions uniformly at random.
    mask = torch.empty(texts.size()).uniform_(0, 1) < mask_prob
    mask[~pads] = False  # We should not mask pads.
    if not mask_eos:
        # We should not mask the end-of-sentence token (usually in case of BART training).
        eos_idx = texts == text_processor.sep_token_id()
        mask[eos_idx] = False

    masked_ids = texts[mask]
    # BERT-style replacement: 80% mask token, 10% random non-special token, 10% original token.
    random_index = lambda: random.randint(len(text_processor.special_tokens), text_processor.vocab_size() - 1)
    rand_select = lambda r, c: text_processor.mask_token_id() if r < 0.8 else (
        random_index() if r < 0.9 else int(masked_ids[c]))
    replacements = list(map(lambda i: rand_select(random.random(), i), range(masked_ids.size(0))))
    texts[mask] = torch.LongTensor(replacements)
    return mask, masked_ids, texts
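
# Hedged usage sketch for mask_text (assumed calling convention, not from the
# original source): `pads` is a boolean tensor that is True on real tokens and
# False on padding, matching the `mask[~pads] = False` line above. Roughly
# mask_prob of the real tokens are selected and replaced 80/10/10 as in BERT.
def _mask_text_example(trained_processor: TextProcessor):
    pad_id = trained_processor.pad_token_id()
    texts = torch.tensor([[30, 31, 32, 33, pad_id],
                          [30, 31, pad_id, pad_id, pad_id]])
    pads = texts != pad_id
    # Clone because mask_text overwrites the masked positions in place.
    mask, masked_ids, corrupted = mask_text(mask_prob=0.15, pads=pads,
                                            texts=texts.clone(),
                                            text_processor=trained_processor,
                                            mask_eos=False)
    return corrupted, masked_ids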