# Requires: import os, tempfile, torch; TextProcessor and Seq2Seq come from the
# project under test.
def test_albert_seq2seq_init(self):
        path_dir_name = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(path_dir_name, "sample.txt")

        with tempfile.TemporaryDirectory() as tmpdirname:
            processor = TextProcessor()
            processor.train_tokenizer([data_path],
                                      vocab_size=1000,
                                      to_save_dir=tmpdirname,
                                      languages={
                                          "<en>": 0,
                                          "<fa>": 1
                                      })
            seq2seq = Seq2Seq(text_processor=processor)
            src_inputs = torch.tensor([[
                1, 2, 3, 4, 5,
                processor.pad_token_id(),
                processor.pad_token_id()
            ], [1, 2, 3, 4, 5, 6, processor.pad_token_id()]])
            tgt_inputs = torch.tensor(
                [[6, 8, 7,
                  processor.pad_token_id(),
                  processor.pad_token_id()],
                 [6, 8, 7, 8, processor.pad_token_id()]])
            src_mask = (src_inputs != processor.pad_token_id())
            tgt_mask = (tgt_inputs != processor.pad_token_id())
            src_langs = torch.tensor([[0], [0]]).squeeze()
            tgt_langs = torch.tensor([[1], [1]]).squeeze()
            seq_output = seq2seq(src_inputs,
                                 tgt_inputs,
                                 src_mask,
                                 tgt_mask,
                                 src_langs,
                                 tgt_langs,
                                 log_softmax=True)
            # 5 = non-pad target tokens minus the first token of each sentence:
            # (3 - 1) + (4 - 1) predicted positions.
            assert list(seq_output.size()) == [5, processor.vocab_size()]

            seq_output = seq2seq(src_inputs, tgt_inputs, src_mask, tgt_mask,
                                 src_langs, tgt_langs)
            assert list(seq_output.size()) == [5, processor.vocab_size()]
Example #2
import math
import random
from typing import Dict

import torch
from torch.nn.utils.rnn import pad_sequence

# TextProcessor is the project's tokenizer wrapper (its import path is not
# shown in this snippet).


def mass_mask(mask_prob, pad_indices, src_text,
              text_processor: TextProcessor) -> Dict:
    """
        20% of times, mask from start to middle
        20% of times, mask from middle to end
        60% of times, mask a random index
    """
    index_range = pad_indices - (1 - mask_prob) * pad_indices
    src_mask = torch.zeros(src_text.size(), dtype=torch.bool)
    to_recover = []
    to_recover_pos = []
    for i, irange in enumerate(index_range):
        range_size = int(pad_indices[i] / 2)  # masked span covers half the sentence
        r = random.random()
        last_idx = int(math.ceil(float(irange)))
        if r > 0.8:  # 20%: mask from the start
            start = 1
        elif r > 0.6:  # 20%: mask from the middle to the end
            start = last_idx
        else:  # 60%: mask from a random start index
            start = random.randint(2, last_idx) if last_idx >= 2 else 2

        end = start + range_size
        src_mask[i, start:end] = True
        to_recover.append(src_text[i, start - 1:end])
        to_recover_pos.append(torch.arange(start - 1, end))
    to_recover = pad_sequence(to_recover,
                              batch_first=True,
                              padding_value=text_processor.pad_token_id())
    to_recover_pos = pad_sequence(to_recover_pos,
                                  batch_first=True,
                                  padding_value=int(src_text.size(-1)) - 1)

    # Original ids of the masked tokens; these become the recovery targets.
    masked_ids = src_text[:, 1:][src_mask[:, 1:]]
    mask_idx = src_text[src_mask]
    # BERT-style corruption: 80% of masked tokens become the mask token, 10% a
    # random non-special token, and 10% keep their original id.
    random_index = lambda: random.randint(len(text_processor.special_tokens),
                                          text_processor.vocab_size() - 1)
    rand_select = lambda r, c: text_processor.mask_token_id() if r < 0.8 else (
        random_index() if r < 0.9 else int(mask_idx[c]))
    replacements = list(
        map(lambda i: rand_select(random.random(), i),
            range(mask_idx.size(0))))
    src_text[src_mask] = torch.LongTensor(replacements)
    return {
        "src_mask": src_mask,
        "targets": masked_ids,
        "src_text": src_text,
        "to_recover": to_recover,
        "positions": to_recover_pos,
        "mask_idx": mask_idx
    }
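
A minimal usage sketch (not part of the original snippet; the toy tensors and
the text_processor instance are hypothetical): it takes pad_indices as each
row's unpadded length and clones src_text because mass_mask overwrites the
masked positions in place.

pad_id = text_processor.pad_token_id()
src_text = torch.tensor([[1, 11, 12, 13, 14, pad_id, pad_id],
                         [1, 11, 12, 13, 14, 15, 16]])
pad_indices = (src_text != pad_id).sum(dim=1)  # unpadded length of each row
batch = mass_mask(0.5, pad_indices, src_text.clone(), text_processor)
# batch["src_text"] is the corrupted encoder input; batch["targets"] holds the
# original ids of the masked span, and batch["positions"] their indices.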
Example #3
# Uses the same imports as the previous example (random, torch, TextProcessor).
def mask_text(mask_prob,
              pads,
              texts,
              text_processor: TextProcessor,
              mask_eos: bool = True):
    assert 0 < mask_prob < 1
    # Draw an independent Bernoulli(mask_prob) decision for every position.
    mask = torch.empty(texts.size()).uniform_(0, 1) < mask_prob
    mask[~pads] = False  # We should not mask pads (`pads` is True on real tokens).
    if not mask_eos:
        # We should not mask end-of-sentence (usually in case of BART training).
        eos_idx = texts == text_processor.sep_token_id()
        mask[eos_idx] = False

    masked_ids = texts[mask]
    # BERT-style corruption: 80% of masked tokens become the mask token, 10% a
    # random non-special token, and 10% keep their original id.
    random_index = lambda: random.randint(len(text_processor.special_tokens),
                                          text_processor.vocab_size() - 1)
    rand_select = lambda r, c: text_processor.mask_token_id() if r < 0.8 else (
        random_index() if r < 0.9 else int(masked_ids[c]))
    replacements = list(
        map(lambda i: rand_select(random.random(), i),
            range(masked_ids.size(0))))
    texts[mask] = torch.LongTensor(replacements)
    return mask, masked_ids, texts
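
A minimal usage sketch (again not from the original snippet; the toy tensors
are hypothetical): pads marks real tokens with True, and texts is cloned
because mask_text writes the replacements back in place.

pad_id = text_processor.pad_token_id()
texts = torch.tensor([[1, 11, 12, 13, pad_id],
                      [1, 11, 12, 13, 14]])
pads = texts != pad_id
mask, masked_ids, corrupted = mask_text(0.15, pads, texts.clone(), text_processor)
# masked_ids are the original ids at the positions where mask is True; they are
# the targets for a masked-language-model loss over those positions.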