Example No. 1
def CDLM_ner(list_of_words_for_a_sentence,
             _tokenizer,
             is_zh,
             keep_origin_rate=0.2,
             max_ratio=0.2,
             max_num=4):
    """

    :params
        list_of_words_for_a_sentence (list): ['I', 'am', 'a', 'student']
        tokenizer (object): tfds tokenizer object
        lan_index (int): index for language embeddings, could be 0 or 1
        min_num (int):
        max_num (int):
        max_ratio (float):
        keep_origin_rate (float):
        language (str): zh or en or both
    :returns
        masked_input (list): list of encoded and masked token idx
        list_of_tar_token_idx (list):
        list_of_lan_idx (list):
    """

    # get n grams
    list_of_token = list(
        map(lambda x: x.strip(), list_of_words_for_a_sentence[:-1]))
    list_of_2_gram = map_dict.n_grams(list_of_token, 2)
    list_of_3_gram = map_dict.n_grams(list_of_token, 3)
    list_of_4_gram = map_dict.n_grams(list_of_token, 4)

    # map dictionary
    map_word = map_dict.zh_word if is_zh else map_dict.en_word
    map_phrase = map_dict.zh_phrase if is_zh else map_dict.en_phrase
    info_key = 'translation'

    list_of_info_for_word = list(
        map(lambda x: map_word(x, info_key), list_of_token))
    list_of_info_for_2_gram = list(
        map(lambda x: map_phrase(x, info_key), list_of_2_gram))
    list_of_info_for_3_gram = list(
        map(lambda x: map_phrase(x, info_key), list_of_3_gram))
    list_of_info_for_4_gram = list(
        map(lambda x: map_phrase(x, info_key), list_of_4_gram))

    # find the positions where the corresponding words/phrases can be mapped via the dictionary
    map_word_pos = map_dict.map_pos(list_of_info_for_word, 1)
    map_2_gram_pos = map_dict.map_pos(list_of_info_for_2_gram, 2)
    map_3_gram_pos = map_dict.map_pos(list_of_info_for_3_gram, 3)
    map_4_gram_pos = map_dict.map_pos(list_of_info_for_4_gram, 4)

    # if nothing can be mapped via the dictionary
    if not map_word_pos and not map_2_gram_pos and not map_3_gram_pos and not map_4_gram_pos:
        return [], [], [], [], []

    # BPE for each word
    list_of_list_token_idx = list(
        map(lambda x: _tokenizer.encode(x), list_of_words_for_a_sentence))

    # get all words or phrases that can be mapped with dictionary
    if map_2_gram_pos or map_3_gram_pos or map_4_gram_pos:
        samples_to_be_selected = map_dict.merge_conflict_samples(
            len(list_of_words_for_a_sentence), map_4_gram_pos, map_3_gram_pos,
            map_2_gram_pos)

    else:
        samples_to_be_selected = map_dict.merge_conflict_samples(
            len(list_of_words_for_a_sentence), map_4_gram_pos, map_3_gram_pos,
            map_2_gram_pos, map_word_pos)

    mode = random.random()
    mode = 0 if mode <= ratio_mode_0 else (1 if mode <= ratio_mode_0_1 else 2)

    # only sample one word or phrase that can be mapped with dictionary
    sample = random.sample(samples_to_be_selected, 1)[0]
    samples = random.sample(
        samples_to_be_selected,
        random.randint(
            1,
            max(
                min(max_num, len(samples_to_be_selected),
                    int(len(list_of_words_for_a_sentence) * max_ratio)), 1)))
    samples.sort()

    samples_start, samples_end = list(zip(*samples))

    # get token index
    mask_idx = _tokenizer.vocab_size + Ids.mask
    sep_idx = _tokenizer.vocab_size + Ids.sep
    src_lan_idx = LanIds.zh if is_zh else LanIds.en
    tar_lan_idx = LanIds.en if is_zh else LanIds.zh

    _input = []
    _lan_input = []
    _output = []
    _lan_output = []
    _soft_pos_output = []

    # for mode 0, we only need one sample
    offset = _tokenizer.vocab_size + Ids.offset_ner
    ner_ids = get_ner_ids(sample, offset, list_of_list_token_idx)

    # for modes 1 and 2, we need multiple samples
    ner_ids_list = [
        get_ner_ids(_sample, offset, list_of_list_token_idx)
        for _sample in samples
    ]

    if not ner_ids and not ner_ids_list:
        return [], [], [], [], []

    # mode = random.random()
    # mode = 0 if mode <= ratio_mode_0 else (1 if mode <= ratio_mode_0_1 else 2)
    # mode = random.randint(0, 2)

    start = _tokenizer.vocab_size + Ids.start_nmt
    end = _tokenizer.vocab_size + Ids.end_nmt

    # replace the masked word with <mask>, and
    #    let the ground truth be its corresponding NER
    if mode == 0:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index == sample[0]:
                len_tokens = sum([
                    len(list_of_list_token_idx[i])
                    for i in range(sample[0], sample[1])
                ])
                pos_for_mask = [len(_input), len(_input) + len_tokens]
                if random.random() < keep_origin_rate:
                    _input += reduce(
                        lambda a, b: a + b,
                        list_of_list_token_idx[sample[0]:sample[1]])
                else:
                    _input += [mask_idx] * len_tokens
                _lan_input += [tar_lan_idx] * len_tokens
                index = sample[1]
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        # translations_ids.sort()
        # new_translations_ids = list(map(lambda x: [sep_idx] + x, translations_ids[:3]))

        # get token idxs for output
        _output = ner_ids

        # get language index for output
        _lan_output = [LanIds.NER] * len(_output)

        # get soft position for output
        # _soft_pos_output = [pos_for_mask[0]] * int(len(_output))
        _soft_pos_output = list(
            map(lambda a: int(round(a)),
                np.linspace(pos_for_mask[0], pos_for_mask[1], len(_output))))
        # _soft_pos_output = list(map(
        #     lambda x: list(map(lambda a: int(round(a)), np.linspace(pos_for_mask[0], pos_for_mask[1], len(x)))),
        #     new_translations_ids
        # ))
        # _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)
        # _soft_pos_output[1] = _soft_pos_output[0]
        # _soft_pos_output.pop(0)

        start = _tokenizer.vocab_size + Ids.start_cdlm_ner_0
        end = _tokenizer.vocab_size + Ids.end_cdlm_ner_0

    # replace the masked word with <mask>, and
    #    let the ground truth be the original word
    elif mode == 1:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index in samples_start:
                sample_idx = samples_start.index(index)
                sample_end = samples_end[sample_idx]

                len_tokens = sum([
                    len(list_of_list_token_idx[i])
                    for i in range(index, sample_end)
                ])
                pos_for_mask.append([len(_input), len(_input) + len_tokens])

                if random.random() < keep_origin_rate:
                    _input += reduce(lambda a, b: a + b,
                                     list_of_list_token_idx[index:sample_end])
                else:
                    _input += [mask_idx] * len_tokens
                _lan_input += [src_lan_idx] * len_tokens

                index = sample_end
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        # get token idxs for output
        _output = [[
            list_of_list_token_idx[i] for i in range(_sample[0], _sample[1])
        ] for _sample in samples]
        _output = reduce(lambda a, b: a + b, _output)
        _output = reduce(lambda a, b: a + b, _output)

        # get language index for output
        _lan_output = [src_lan_idx] * len(_output)

        # get soft position for output
        _soft_pos_output = [
            list(range(_pos[0], _pos[1])) for _pos in pos_for_mask
        ]
        _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)

        start = _tokenizer.vocab_size + Ids.start_mlm
        end = _tokenizer.vocab_size + Ids.end_mlm

    # replace the masked word with its NER, and let the ground truth be its original word
    elif mode == 2:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index in samples_start:
                sample_idx = samples_start.index(index)
                sample_end = samples_end[sample_idx]

                _pos = [len(_input)]

                _input += ner_ids_list[sample_idx]

                _pos.append(len(_input))
                pos_for_mask.append(_pos)

                _lan_input += [LanIds.NER] * len(ner_ids_list[sample_idx])
                index = sample_end
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        # get token idxs for output
        _output = [[
            list_of_list_token_idx[i] for i in range(_sample[0], _sample[1])
        ] for _sample in samples]

        _soft_pos_output = [
            list(
                map(
                    lambda a: int(round(a)),
                    np.linspace(_pos[0], _pos[1],
                                len(reduce(lambda a, b: a + b, _output[i])))))
            for i, _pos in enumerate(pos_for_mask)
        ]
        _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)
        # _soft_pos_output = list(map(lambda a: int(round(a)), _soft_pos_output))

        _output = reduce(lambda a, b: a + b, _output)
        _output = reduce(lambda a, b: a + b, _output)

        # get language index for output
        _lan_output = [src_lan_idx] * len(_output)

        # get soft position for output
        # _soft_pos_output =
        # _soft_pos_output = list(
        #     map(lambda a: int(round(a)), np.linspace(pos_for_mask[0], pos_for_mask[1], len(_output))))
        # _soft_pos_output = [pos_for_mask[0]] * int(len(_output))

        start = _tokenizer.vocab_size + Ids.start_cdlm_ner_2
        end = _tokenizer.vocab_size + Ids.end_cdlm_ner_2

    # replace the masked word with its NER, and let the ground truth be a tag over the source sequence;
    #   the tag value is 0 or 1; 0 means the word is not replaced, 1 means it is replaced
    # elif mode == 3:
    #     pass

    # add <start> <end> token
    _input = [start] + _input + [end]
    _output = [start] + _output + [end]
    _lan_input = _lan_input[:1] + _lan_input + _lan_input[-1:]
    _lan_output = _lan_output[:1] + _lan_output + _lan_output[-1:]
    _soft_pos_output = (_soft_pos_output[:1] + _soft_pos_output +
                        _soft_pos_output[-1:])

    return _input, _output, _lan_input, _lan_output, _soft_pos_output
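
In mode 0 above, the soft output positions are produced by spreading the target tokens evenly across the span that the mask occupies in the input (the np.linspace call). A minimal, self-contained sketch of that computation, with made-up span boundaries and target length:

import numpy as np

# Hypothetical values, for illustration only: the mask covers input
#   positions 7..11 and the target (e.g. the NER ids) has 6 tokens.
span_start, span_end = 7, 11
output_len = 6

# Spread the target tokens evenly over the masked span and round to ints,
#   mirroring the _soft_pos_output computation above.
soft_positions = [int(round(p)) for p in np.linspace(span_start, span_end, output_len)]
print(soft_positions)  # [7, 8, 9, 9, 10, 11]
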
Example No. 2
def CDLM_translation(list_of_words_for_a_sentence,
                     _tokenizer,
                     is_zh,
                     keep_origin_rate=0.2):
    """

    :params
        list_of_words_for_a_sentence (list): ['I', 'am', 'a', 'student']
        tokenizer (object): tfds tokenizer object
        lan_index (int): index for language embeddings, could be 0 or 1
        min_num (int):
        max_num (int):
        max_ratio (float):
        keep_origin_rate (float):
        language (str): zh or en or both
    :returns
        masked_input (list): list of encoded and masked token idx
        list_of_tar_token_idx (list):
        list_of_lan_idx (list):
    """

    # get n grams
    list_of_token = list(
        map(lambda x: x.strip(), list_of_words_for_a_sentence[:-1]))
    list_of_2_gram = map_dict.n_grams(list_of_token, 2)
    list_of_3_gram = map_dict.n_grams(list_of_token, 3)
    list_of_4_gram = map_dict.n_grams(list_of_token, 4)

    # map dictionary
    map_word = map_dict.zh_word if is_zh else map_dict.en_word
    map_phrase = map_dict.zh_phrase if is_zh else map_dict.en_phrase
    info_key = 'translation'

    list_of_info_for_word = list(
        map(lambda x: map_word(x, info_key), list_of_token))
    list_of_info_for_2_gram = list(
        map(lambda x: map_phrase(x, info_key), list_of_2_gram))
    list_of_info_for_3_gram = list(
        map(lambda x: map_phrase(x, info_key), list_of_3_gram))
    list_of_info_for_4_gram = list(
        map(lambda x: map_phrase(x, info_key), list_of_4_gram))

    # find the positions where the corresponding words/phrases can be mapped via the dictionary
    map_word_pos = map_dict.map_pos(list_of_info_for_word, 1)
    map_2_gram_pos = map_dict.map_pos(list_of_info_for_2_gram, 2)
    map_3_gram_pos = map_dict.map_pos(list_of_info_for_3_gram, 3)
    map_4_gram_pos = map_dict.map_pos(list_of_info_for_4_gram, 4)

    # if nothing can be mapped via the dictionary
    if not map_word_pos and not map_2_gram_pos and not map_3_gram_pos and not map_4_gram_pos:
        return [], [], [], [], []

    # BPE for each word
    list_of_list_token_idx = list(
        map(lambda x: _tokenizer.encode(x), list_of_words_for_a_sentence))

    # get all words or phrases that can be mapped with dictionary
    samples_to_be_selected = map_dict.merge_conflict_samples(
        len(list_of_words_for_a_sentence), map_4_gram_pos, map_3_gram_pos,
        map_2_gram_pos, map_word_pos)

    # only sample one word or phrase that can be mapped with dictionary
    sample = random.sample(samples_to_be_selected, 1)[0]

    # get token index
    mask_idx = _tokenizer.vocab_size + Ids.mask
    sep_idx = _tokenizer.vocab_size + Ids.sep
    src_lan_idx = LanIds.zh if is_zh else LanIds.en
    tar_lan_idx = LanIds.en if is_zh else LanIds.zh

    _input = []
    _lan_input = []
    _output = []
    _lan_output = []
    _soft_pos_output = []

    n = sample[1] - sample[0]
    if n == 4:
        translations = list_of_info_for_4_gram[sample[0]]
    elif n == 3:
        translations = list_of_info_for_3_gram[sample[0]]
    elif n == 2:
        translations = list_of_info_for_2_gram[sample[0]]
    else:
        translations = list_of_info_for_word[sample[0]]
        translations = list(
            filter(lambda x: x != list_of_words_for_a_sentence[sample[0]],
                   translations))

    # filter out noise from the translations
    translations = list(filter(lambda x: x, translations))
    if is_zh:
        translations = list(
            filter(lambda x: (x[0] != '-' or x[-1] != '-') and len(x) > 1,
                   translations))
        if len(translations) >= 8:
            tmp_translations = list(filter(lambda x: len(x) >= 5,
                                           translations))
            if tmp_translations:
                translations = tmp_translations
    if not translations:
        return [], [], [], [], []

    # apply BPE for the translations
    translations_ids = list(
        map(lambda x: _tokenizer.encode(x + ' '), translations))

    mode = random.random()
    mode = 0 if mode <= ratio_mode_0 else (1 if mode <= ratio_mode_0_1 else 2)
    # mode = random.randint(0, 2)

    start = _tokenizer.vocab_size + Ids.start_nmt
    end = _tokenizer.vocab_size + Ids.end_nmt

    # replace the masked word with <mask>, and
    #    let the ground truth be its corresponding translation
    if mode == 0:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index == sample[0]:
                len_tokens = sum([
                    len(list_of_list_token_idx[i])
                    for i in range(sample[0], sample[1])
                ])
                pos_for_mask = [len(_input), len(_input) + len_tokens]
                if random.random() < keep_origin_rate:
                    _input += reduce(
                        lambda a, b: a + b,
                        list_of_list_token_idx[sample[0]:sample[1]])
                else:
                    _input += [mask_idx] * len_tokens
                _lan_input += [tar_lan_idx] * len_tokens
                index = sample[1]
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        translations_ids.sort()
        new_translations_ids = list(
            map(lambda x: [sep_idx] + x, translations_ids[:3]))

        # get token idxs for output
        _output = reduce(lambda a, b: a + b, new_translations_ids)
        _output.pop(0)

        # get language index for output
        _lan_output = [tar_lan_idx] * len(_output)

        # get soft position for output
        # _soft_pos_output = [pos_for_mask[0]] * int(len(_output))
        _soft_pos_output = list(
            map(
                lambda x: list(
                    map(lambda a: int(round(a)),
                        np.linspace(pos_for_mask[0], pos_for_mask[1], len(x)))
                ), new_translations_ids))
        _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)
        # _soft_pos_output[1] = _soft_pos_output[0]
        # _soft_pos_output.pop(0)

        start = _tokenizer.vocab_size + Ids.start_cdlm_t_0
        end = _tokenizer.vocab_size + Ids.end_cdlm_t_0

    # replace the masked word with <mask>, and
    #    let the ground truth be the original word
    elif mode == 1:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index == sample[0]:
                len_tokens = sum([
                    len(list_of_list_token_idx[i])
                    for i in range(sample[0], sample[1])
                ])
                pos_for_mask = [len(_input), len(_input) + len_tokens]
                if random.random() < keep_origin_rate:
                    _input += reduce(
                        lambda a, b: a + b,
                        list_of_list_token_idx[sample[0]:sample[1]])
                else:
                    _input += [mask_idx] * len_tokens
                _lan_input += [src_lan_idx] * len_tokens
                index = sample[1]
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        # get token idxs for output
        _output = [
            list_of_list_token_idx[i] for i in range(sample[0], sample[1])
        ]
        _output = reduce(lambda a, b: a + b, _output)

        # get language index for output
        _lan_output = [src_lan_idx] * len(_output)

        # get soft position for output
        _soft_pos_output = list(range(*pos_for_mask))

        start = _tokenizer.vocab_size + Ids.start_mlm
        end = _tokenizer.vocab_size + Ids.end_mlm

    # replace the masked word with its translation, and let the ground truth be its original word
    elif mode == 2:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index == sample[0]:
                pos_for_mask = [len(_input)]

                tmp_input = random.sample(translations_ids, 1)[0]
                _input += tmp_input

                pos_for_mask.append(len(_input))

                _lan_input += [tar_lan_idx] * len(tmp_input)
                index = sample[1]
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        # get token idxs for output
        _output = [
            list_of_list_token_idx[i] for i in range(sample[0], sample[1])
        ]
        _output = reduce(lambda a, b: a + b, _output)

        # get language index for output
        _lan_output = [src_lan_idx] * len(_output)

        # get soft position for output
        _soft_pos_output = list(
            map(lambda a: int(round(a)),
                np.linspace(pos_for_mask[0], pos_for_mask[1], len(_output))))
        # _soft_pos_output = [pos_for_mask[0]] * int(len(_output))

        start = _tokenizer.vocab_size + Ids.start_cdlm_t_2
        end = _tokenizer.vocab_size + Ids.end_cdlm_t_2

    # replace the masked word with its translation, and let the ground truth be a tag over the source sequence;
    #   the tag value is 0 or 1; 0 means the word is not replaced, 1 means it is replaced
    # elif mode == 3:
    #     pass

    # add <start> <end> token
    _input = [start] + _input + [end]
    _output = [start] + _output + [end]
    _lan_input = _lan_input[:1] + _lan_input + _lan_input[-1:]
    _lan_output = _lan_output[:1] + _lan_output + _lan_output[-1:]
    _soft_pos_output = (_soft_pos_output[:1] + _soft_pos_output +
                        _soft_pos_output[-1:])

    return _input, _output, _lan_input, _lan_output, _soft_pos_output
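
The mode selection above (and in CDLM_ner and CDLM_synonym) relies on module-level constants ratio_mode_0 and ratio_mode_0_1 that are not part of these excerpts. Assuming they are cumulative probability thresholds (the 0.4 / 0.7 values below are invented), the selection amounts to:

import random

# Assumed cumulative thresholds; the real values are defined elsewhere in the module.
ratio_mode_0 = 0.4    # P(mode 0)
ratio_mode_0_1 = 0.7  # P(mode 0) + P(mode 1)

r = random.random()
mode = 0 if r <= ratio_mode_0 else (1 if r <= ratio_mode_0_1 else 2)
# mode is 0 with probability 0.4, 1 with probability 0.3, 2 with probability 0.3
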
Example No. 3
def CDLM_definition(list_of_words_for_a_sentence,
                    _tokenizer,
                    is_zh,
                    keep_origin_rate=0.2,
                    max_ratio=0.2,
                    max_num=4):
    """

    :params
        list_of_words_for_a_sentence (list): ['I', 'am', 'a', 'student']
        tokenizer (object): tfds tokenizer object
        lan_index (int): index for language embeddings, could be 0 or 1
        min_num (int):
        max_num (int):
        max_ratio (float):
        keep_origin_rate (float):
        language (str): zh or en or both
    :returns
        masked_input (list): list of encoded and masked token idx
        list_of_tar_token_idx (list):
        list_of_lan_idx (list):
    """

    # get n grams
    list_of_token = list(
        map(lambda x: x.strip(), list_of_words_for_a_sentence[:-1]))
    list_of_2_gram = map_dict.n_grams(list_of_token, 2)
    list_of_3_gram = map_dict.n_grams(list_of_token, 3)
    list_of_4_gram = map_dict.n_grams(list_of_token, 4)

    # map dictionary
    map_word = map_dict.zh_word if is_zh else map_dict.en_word
    map_phrase = map_dict.zh_phrase if is_zh else map_dict.en_phrase

    info_key = 'src_meanings'
    list_of_info_for_word_src = list(
        map(lambda x: map_word(x, info_key), list_of_token))
    list_of_info_for_2_gram_src = list(
        map(lambda x: map_phrase(x, info_key), list_of_2_gram))
    list_of_info_for_3_gram_src = list(
        map(lambda x: map_phrase(x, info_key), list_of_3_gram))
    list_of_info_for_4_gram_src = list(
        map(lambda x: map_phrase(x, info_key), list_of_4_gram))

    info_key = 'tar_meanings'
    list_of_info_for_word_tar = list(
        map(lambda x: map_word(x, info_key), list_of_token))
    list_of_info_for_2_gram_tar = list(
        map(lambda x: map_phrase(x, info_key), list_of_2_gram))
    list_of_info_for_3_gram_tar = list(
        map(lambda x: map_phrase(x, info_key), list_of_3_gram))
    list_of_info_for_4_gram_tar = list(
        map(lambda x: map_phrase(x, info_key), list_of_4_gram))

    # find the positions where the corresponding words/phrases can be mapped via the dictionary
    map_word_pos_src = map_dict.map_pos(list_of_info_for_word_src, 1)
    map_2_gram_pos_src = map_dict.map_pos(list_of_info_for_2_gram_src, 2)
    map_3_gram_pos_src = map_dict.map_pos(list_of_info_for_3_gram_src, 3)
    map_4_gram_pos_src = map_dict.map_pos(list_of_info_for_4_gram_src, 4)

    map_word_pos_tar = map_dict.map_pos(list_of_info_for_word_tar, 1)
    map_2_gram_pos_tar = map_dict.map_pos(list_of_info_for_2_gram_tar, 2)
    map_3_gram_pos_tar = map_dict.map_pos(list_of_info_for_3_gram_tar, 3)
    map_4_gram_pos_tar = map_dict.map_pos(list_of_info_for_4_gram_tar, 4)

    no_src_def = not map_word_pos_src and not map_2_gram_pos_src and not map_3_gram_pos_src and not map_4_gram_pos_src
    no_tar_def = not map_word_pos_tar and not map_2_gram_pos_tar and not map_3_gram_pos_tar and not map_4_gram_pos_tar

    if no_src_def and no_tar_def:
        return [], [], [], [], []

    if no_src_def:
        use_tar = True
    elif no_tar_def:
        use_tar = False
    else:
        use_tar = random.random() <= 0.5

    list_of_info_for_word = list_of_info_for_word_tar if use_tar else list_of_info_for_word_src
    list_of_info_for_2_gram = list_of_info_for_2_gram_tar if use_tar else list_of_info_for_2_gram_src
    list_of_info_for_3_gram = list_of_info_for_3_gram_tar if use_tar else list_of_info_for_3_gram_src
    list_of_info_for_4_gram = list_of_info_for_4_gram_tar if use_tar else list_of_info_for_4_gram_src

    map_word_pos = map_word_pos_tar if use_tar else map_word_pos_src
    map_2_gram_pos = map_2_gram_pos_tar if use_tar else map_2_gram_pos_src
    map_3_gram_pos = map_3_gram_pos_tar if use_tar else map_3_gram_pos_src
    map_4_gram_pos = map_4_gram_pos_tar if use_tar else map_4_gram_pos_src

    # BPE for each word
    list_of_list_token_idx = list(
        map(lambda x: _tokenizer.encode(x), list_of_words_for_a_sentence))

    # get all words or phrases that can be mapped with dictionary
    samples_to_be_selected = map_dict.merge_conflict_samples(
        len(list_of_words_for_a_sentence), map_4_gram_pos, map_3_gram_pos,
        map_2_gram_pos, map_word_pos)

    # only sample one word or phrase that can be mapped with dictionary
    sample = random.sample(samples_to_be_selected, 1)[0]

    # get token index
    mask_idx = _tokenizer.vocab_size + Ids.mask
    sep_idx = _tokenizer.vocab_size + Ids.sep
    src_lan_idx = LanIds.zh if is_zh else LanIds.en
    tar_lan_idx = LanIds.en if is_zh else LanIds.zh
    lan_idx = tar_lan_idx if use_tar else src_lan_idx

    _input = []
    _lan_input = []
    _output = []
    _lan_output = []
    _soft_pos_output = []

    # for mode 0, we only need one sample
    definitions = get_definitions(sample, list_of_info_for_4_gram,
                                  list_of_info_for_3_gram,
                                  list_of_info_for_2_gram,
                                  list_of_info_for_word,
                                  list_of_words_for_a_sentence, is_zh)

    if not definitions:
        return [], [], [], [], []

    # apply BPE for the definitions
    definitions_ids = list(
        map(lambda x: _tokenizer.encode(x + ' '), definitions))

    # replace the masked word with <mask>, and
    #    let the ground truth be its corresponding definition

    index = 0
    len_words = len(list_of_list_token_idx)
    pos_for_mask = []

    while index < len_words:
        if index == sample[0]:
            len_tokens = sum([
                len(list_of_list_token_idx[i])
                for i in range(sample[0], sample[1])
            ])
            pos_for_mask = [len(_input), len(_input) + len_tokens]
            if random.random() < keep_origin_rate:
                _input += reduce(lambda a, b: a + b,
                                 list_of_list_token_idx[sample[0]:sample[1]])
            else:
                _input += [mask_idx] * len_tokens
            _lan_input += [tar_lan_idx] * len_tokens
            index = sample[1]
            continue

        _input += list_of_list_token_idx[index]
        _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
        index += 1

    definitions_ids.sort()
    definitions_ids = list(map(lambda x: [sep_idx] + x, definitions_ids[:2]))
    new_definitions_ids = copy.deepcopy(definitions_ids)

    # get token idxs for output
    _output = reduce(lambda a, b: a + b, new_definitions_ids)
    _output.pop(0)

    # get language index for output
    _lan_output = [lan_idx] * len(_output)

    # get soft position for output
    # _soft_pos_output = [pos_for_mask[0]] * int(len(_output))
    _soft_pos_output = list(
        map(
            lambda x: list(
                map(lambda a: int(round(a)),
                    np.linspace(pos_for_mask[0], pos_for_mask[1], len(x)))),
            definitions_ids))
    _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)
    _soft_pos_output[1] = _soft_pos_output[0]
    _soft_pos_output.pop(0)

    start = _tokenizer.vocab_size + Ids.start_cdlm_def
    end = _tokenizer.vocab_size + Ids.end_cdlm_def

    # replace the masked word with its definition, and let the ground truth be a tag over the source sequence;
    #   the tag value is 0 or 1; 0 means the word is not replaced, 1 means it is replaced
    # elif mode == 3:
    #     pass

    # add <start> <end> token
    _input = [start] + _input + [end]
    _output = [start] + _output + [end]
    _lan_input = _lan_input[:1] + _lan_input + _lan_input[-1:]
    _lan_output = _lan_output[:1] + _lan_output + _lan_output[-1:]
    _soft_pos_output = (_soft_pos_output[:1] + _soft_pos_output +
                        _soft_pos_output[-1:])

    return _input, _output, _lan_input, _lan_output, _soft_pos_output
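
Throughout these functions, nested lists of BPE ids (one sub-list per word, translation, or definition) are flattened with reduce(lambda a, b: a + b, ...). A small standalone illustration with invented token ids; itertools.chain.from_iterable produces the same result in linear time:

from functools import reduce
from itertools import chain

# Hypothetical per-word BPE ids for a three-word span.
list_of_list_token_idx = [[17, 4], [92], [3, 3, 8]]

flat = reduce(lambda a, b: a + b, list_of_list_token_idx)
print(flat)  # [17, 4, 92, 3, 3, 8]

# Equivalent flattening without the quadratic list concatenation:
print(list(chain.from_iterable(list_of_list_token_idx)))  # [17, 4, 92, 3, 3, 8]
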
Example No. 4
def CDLM_synonym(list_of_words_for_a_sentence,
                 _tokenizer,
                 is_zh,
                 keep_origin_rate=0.2,
                 max_ratio=0.2,
                 max_num=4):
    """

    :params
        list_of_words_for_a_sentence (list): ['I', 'am', 'a', 'student']
        tokenizer (object): tfds tokenizer object
        lan_index (int): index for language embeddings, could be 0 or 1
        min_num (int):
        max_num (int):
        max_ratio (float):
        keep_origin_rate (float):
        language (str): zh or en or both
    :returns
        masked_input (list): list of encoded and masked token idx
        list_of_tar_token_idx (list):
        list_of_lan_idx (list):
    """

    # get n grams
    list_of_token = list(
        map(lambda x: x.strip(), list_of_words_for_a_sentence[:-1]))
    list_of_2_gram = map_dict.n_grams(list_of_token, 2)
    list_of_3_gram = map_dict.n_grams(list_of_token, 3)
    list_of_4_gram = map_dict.n_grams(list_of_token, 4)

    # map dictionary
    map_word = map_dict.zh_word if is_zh else map_dict.en_word
    map_phrase = map_dict.zh_phrase if is_zh else map_dict.en_phrase

    info_key = 'src_synonyms'
    list_of_info_for_word_src = list(
        map(lambda x: map_word(x, info_key), list_of_token))
    list_of_info_for_2_gram_src = list(
        map(lambda x: map_phrase(x, info_key), list_of_2_gram))
    list_of_info_for_3_gram_src = list(
        map(lambda x: map_phrase(x, info_key), list_of_3_gram))
    list_of_info_for_4_gram_src = list(
        map(lambda x: map_phrase(x, info_key), list_of_4_gram))

    info_key = 'tar_synonyms'
    list_of_info_for_word_tar = list(
        map(lambda x: map_word(x, info_key), list_of_token))
    list_of_info_for_2_gram_tar = list(
        map(lambda x: map_phrase(x, info_key), list_of_2_gram))
    list_of_info_for_3_gram_tar = list(
        map(lambda x: map_phrase(x, info_key), list_of_3_gram))
    list_of_info_for_4_gram_tar = list(
        map(lambda x: map_phrase(x, info_key), list_of_4_gram))

    # find the positions where the corresponding words/phrases can be mapped via the dictionary
    map_word_pos_src = map_dict.map_pos(list_of_info_for_word_src, 1)
    map_2_gram_pos_src = map_dict.map_pos(list_of_info_for_2_gram_src, 2)
    map_3_gram_pos_src = map_dict.map_pos(list_of_info_for_3_gram_src, 3)
    map_4_gram_pos_src = map_dict.map_pos(list_of_info_for_4_gram_src, 4)

    map_word_pos_tar = map_dict.map_pos(list_of_info_for_word_tar, 1)
    map_2_gram_pos_tar = map_dict.map_pos(list_of_info_for_2_gram_tar, 2)
    map_3_gram_pos_tar = map_dict.map_pos(list_of_info_for_3_gram_tar, 3)
    map_4_gram_pos_tar = map_dict.map_pos(list_of_info_for_4_gram_tar, 4)

    no_src_syn = not map_word_pos_src and not map_2_gram_pos_src and not map_3_gram_pos_src and not map_4_gram_pos_src
    no_tar_syn = not map_word_pos_tar and not map_2_gram_pos_tar and not map_3_gram_pos_tar and not map_4_gram_pos_tar

    if no_src_syn and no_tar_syn:
        return [], [], [], [], []

    if no_src_syn:
        use_tar = True
    elif no_tar_syn:
        use_tar = False
    else:
        use_tar = random.random() <= 0.5

    list_of_info_for_word = list_of_info_for_word_tar if use_tar else list_of_info_for_word_src
    list_of_info_for_2_gram = list_of_info_for_2_gram_tar if use_tar else list_of_info_for_2_gram_src
    list_of_info_for_3_gram = list_of_info_for_3_gram_tar if use_tar else list_of_info_for_3_gram_src
    list_of_info_for_4_gram = list_of_info_for_4_gram_tar if use_tar else list_of_info_for_4_gram_src

    map_word_pos = map_word_pos_tar if use_tar else map_word_pos_src
    map_2_gram_pos = map_2_gram_pos_tar if use_tar else map_2_gram_pos_src
    map_3_gram_pos = map_3_gram_pos_tar if use_tar else map_3_gram_pos_src
    map_4_gram_pos = map_4_gram_pos_tar if use_tar else map_4_gram_pos_src

    # BPE for each word
    list_of_list_token_idx = list(
        map(lambda x: _tokenizer.encode(x), list_of_words_for_a_sentence))

    # get all words or phrases that can be mapped with dictionary
    samples_to_be_selected = map_dict.merge_conflict_samples(
        len(list_of_words_for_a_sentence), map_4_gram_pos, map_3_gram_pos,
        map_2_gram_pos, map_word_pos)

    mode = random.random()
    mode = 0 if mode <= ratio_mode_0 else (1 if mode <= ratio_mode_0_1 else 2)

    # only sample one word or phrase that can be mapped with dictionary
    sample = random.sample(samples_to_be_selected, 1)[0]
    samples = random.sample(
        samples_to_be_selected,
        random.randint(
            1,
            max(
                min(max_num, len(samples_to_be_selected),
                    int(len(list_of_words_for_a_sentence) * max_ratio)), 1)))
    samples.sort()

    # get token index
    mask_idx = _tokenizer.vocab_size + Ids.mask
    sep_idx = _tokenizer.vocab_size + Ids.sep
    src_lan_idx = LanIds.zh if is_zh else LanIds.en
    tar_lan_idx = LanIds.en if is_zh else LanIds.zh
    lan_idx = tar_lan_idx if use_tar else src_lan_idx

    _input = []
    _lan_input = []
    _output = []
    _lan_output = []
    _soft_pos_output = []

    # for mode 0, we only need one sample
    synonyms = get_synonyms(sample, list_of_info_for_4_gram,
                            list_of_info_for_3_gram, list_of_info_for_2_gram,
                            list_of_info_for_word,
                            list_of_words_for_a_sentence, is_zh)

    # for modes 1 and 2, we need multiple samples
    synonyms_list = [
        get_synonyms(_sample, list_of_info_for_4_gram, list_of_info_for_3_gram,
                     list_of_info_for_2_gram, list_of_info_for_word,
                     list_of_words_for_a_sentence, is_zh)
        for _sample in samples
    ]

    # remove samples that have no synonym info left after filtering
    delete_samples = []
    for i, v in enumerate(synonyms_list):
        if not v:
            delete_samples.append(i)

    delete_samples.sort(reverse=True)
    for k in delete_samples:
        del synonyms_list[k]
        del samples[k]

    if not synonyms and not synonyms_list:
        return [], [], [], [], []

    if samples:
        samples_start, samples_end = list(zip(*samples))
    else:
        samples_start = samples_end = []

    # apply BPE for the synonyms
    synonyms_ids = list(map(lambda x: _tokenizer.encode(x + ' '), synonyms))
    synonyms_ids_list = list(
        map(
            lambda _synonyms: list(
                map(lambda x: _tokenizer.encode(x + ' '), _synonyms)),
            synonyms_list))

    if not synonyms_list and synonyms:
        mode = 0
    elif not synonyms:
        mode = random.random()
        if mode < (ratio_mode_1 / (ratio_mode_1 + ratio_mode_2)):
            mode = 1
        else:
            mode = 2

    # mode = random.random()
    # mode = 0 if mode <= ratio_mode_0 else (1 if mode <= ratio_mode_0_1 else 2)
    # mode = random.randint(0, 2)

    start = _tokenizer.vocab_size + Ids.start_nmt
    end = _tokenizer.vocab_size + Ids.end_nmt

    # replace the masked word with <mask>, and
    #    let the ground truth be its corresponding synonym
    if mode == 0:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index == sample[0]:
                len_tokens = sum([
                    len(list_of_list_token_idx[i])
                    for i in range(sample[0], sample[1])
                ])
                pos_for_mask = [len(_input), len(_input) + len_tokens]
                if random.random() < keep_origin_rate:
                    _input += reduce(
                        lambda a, b: a + b,
                        list_of_list_token_idx[sample[0]:sample[1]])
                else:
                    _input += [mask_idx] * len_tokens
                _lan_input += [tar_lan_idx] * len_tokens
                index = sample[1]
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        synonyms_ids.sort()
        synonyms_ids = list(map(lambda x: [sep_idx] + x, synonyms_ids[:3]))
        new_synonyms_ids = copy.deepcopy(synonyms_ids)

        # get token idxs for output
        _output = reduce(lambda a, b: a + b, new_synonyms_ids)
        _output.pop(0)

        # get language index for output
        _lan_output = [lan_idx] * len(_output)

        # get soft position for output
        # _soft_pos_output = [pos_for_mask[0]] * int(len(_output))
        _soft_pos_output = list(
            map(
                lambda x: list(
                    map(lambda a: int(round(a)),
                        np.linspace(pos_for_mask[0], pos_for_mask[1], len(x)))
                ), synonyms_ids))
        _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)
        _soft_pos_output[1] = _soft_pos_output[0]
        _soft_pos_output.pop(0)

        start = _tokenizer.vocab_size + Ids.start_cdlm_synonym_0
        end = _tokenizer.vocab_size + Ids.end_cdlm_synonym_0

    # replace the masked word with <mask>, and
    #    let the ground truth be the original word
    elif mode == 1:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index in samples_start:
                sample_idx = samples_start.index(index)
                sample_end = samples_end[sample_idx]

                len_tokens = sum([
                    len(list_of_list_token_idx[i])
                    for i in range(index, sample_end)
                ])
                pos_for_mask.append([len(_input), len(_input) + len_tokens])

                if random.random() < keep_origin_rate:
                    _input += reduce(lambda a, b: a + b,
                                     list_of_list_token_idx[index:sample_end])
                else:
                    _input += [mask_idx] * len_tokens
                _lan_input += [src_lan_idx] * len_tokens

                index = sample_end
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        # get token idxs for output
        _output = [[
            list_of_list_token_idx[i] for i in range(_sample[0], _sample[1])
        ] for _sample in samples]
        _output = reduce(lambda a, b: a + b, _output)
        _output = reduce(lambda a, b: a + b, _output)

        # get language index for output
        _lan_output = [src_lan_idx] * len(_output)

        # get soft position for output
        _soft_pos_output = [
            list(range(_pos[0], _pos[1])) for _pos in pos_for_mask
        ]
        _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)

        start = _tokenizer.vocab_size + Ids.start_mlm
        end = _tokenizer.vocab_size + Ids.end_mlm

    # replace the masked word with its synonym, and let the ground truth be its original word
    elif mode == 2:

        index = 0
        len_words = len(list_of_list_token_idx)
        pos_for_mask = []

        while index < len_words:
            if index in samples_start:
                sample_idx = samples_start.index(index)
                sample_end = samples_end[sample_idx]

                _pos = [len(_input)]

                tmp_input = random.sample(synonyms_ids_list[sample_idx], 1)[0]
                _input += tmp_input

                _pos.append(len(_input))
                pos_for_mask.append(_pos)

                _lan_input += [lan_idx] * len(tmp_input)
                index = sample_end
                continue

            _input += list_of_list_token_idx[index]
            _lan_input += [src_lan_idx] * len(list_of_list_token_idx[index])
            index += 1

        # get token idxs for output
        _output = [[
            list_of_list_token_idx[i] for i in range(_sample[0], _sample[1])
        ] for _sample in samples]

        _soft_pos_output = [
            list(
                map(
                    lambda a: int(round(a)),
                    np.linspace(_pos[0], _pos[1],
                                len(reduce(lambda a, b: a + b, _output[i])))))
            for i, _pos in enumerate(pos_for_mask)
        ]
        _soft_pos_output = reduce(lambda a, b: a + b, _soft_pos_output)
        # _soft_pos_output = list(map(lambda a: int(round(a)), _soft_pos_output))

        _output = reduce(lambda a, b: a + b, _output)
        _output = reduce(lambda a, b: a + b, _output)

        # get language index for output
        _lan_output = [src_lan_idx] * len(_output)

        # get soft position for output
        # _soft_pos_output =
        # _soft_pos_output = list(
        #     map(lambda a: int(round(a)), np.linspace(pos_for_mask[0], pos_for_mask[1], len(_output))))
        # _soft_pos_output = [pos_for_mask[0]] * int(len(_output))

        start = _tokenizer.vocab_size + Ids.start_cdlm_synonym_2
        end = _tokenizer.vocab_size + Ids.end_cdlm_synonym_2

    # replace the masked word with its synonym, and let the ground truth be a tag over the source sequence;
    #   the tag value is 0 or 1; 0 means the word is not replaced, 1 means it is replaced
    # elif mode == 3:
    #     pass

    # add <start> <end> token
    _input = [start] + _input + [end]
    _output = [start] + _output + [end]
    _lan_input = _lan_input[:1] + _lan_input + _lan_input[-1:]
    _lan_output = _lan_output[:1] + _lan_output + _lan_output[-1:]
    _soft_pos_output = (_soft_pos_output[:1] + _soft_pos_output +
                        _soft_pos_output[-1:])

    return _input, _output, _lan_input, _lan_output, _soft_pos_output
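
In modes 1 and 2 of CDLM_ner and CDLM_synonym, the number of masked spans is drawn uniformly between 1 and a bound derived from max_num, the number of candidate spans, and max_ratio of the sentence length. A worked, self-contained example of that bound (the sentence length and candidate count are invented):

import random

max_num = 4
max_ratio = 0.2
sentence_len = 30    # hypothetical number of words in the sentence
num_candidates = 10  # hypothetical len(samples_to_be_selected)

# min(...) caps the bound, max(..., 1) keeps it at least 1 for short sentences.
upper = max(min(max_num, num_candidates, int(sentence_len * max_ratio)), 1)
num_spans = random.randint(1, upper)  # here upper == 4, so 1 to 4 spans are masked
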