def CDLM_combine_sample(list_of_zh_words, list_of_en_words, _tokenizer, keep_origin_rate=0.2, TLM_ratio=0.7,
                        max_ratio=0.2, max_num=4):
    """Build a mixed set of TLM (paired) and MLM (single-language) samples.

    For every index of ``list_of_zh_words`` one draw from the global
    ``random`` generator picks the sample kind:

    * with probability ``TLM_ratio`` the zh/en pair is passed to
      ``CDLM_pos_for_zh_en`` and the result is merged by ``TLM_concat``
      into a single sample; the pair is skipped when either side of the
      returned inputs or soft-position outputs is empty;
    * otherwise each sentence is passed to ``CDLM_pos`` on its own and
      contributes one sample per language, provided all five returned
      fields are non-empty.

    Args:
        list_of_zh_words: Sequence of Chinese sentences (one entry per pair).
        list_of_en_words: Sequence of English sentences, indexed in parallel
            with ``list_of_zh_words``.
        _tokenizer: Tokenizer forwarded unchanged to the masking helpers.
        keep_origin_rate: Forwarded unchanged to the masking helpers.
        TLM_ratio: Probability that a pair is turned into a TLM sample
            rather than two MLM samples.
        max_ratio: Forwarded unchanged to the masking helpers.
        max_num: Forwarded unchanged to the masking helpers.

    Returns:
        A 5-tuple ``(inputs, outputs, lan_inputs, lan_outputs,
        soft_pos_outputs)``; each element is a tuple holding one entry per
        kept sample.  All five are empty tuples when no sample was kept.

    Raises:
        IndexError: If ``list_of_en_words`` is shorter than
            ``list_of_zh_words``.
    """
    samples = []
    for idx, zh_words in enumerate(list_of_zh_words):
        en_words = list_of_en_words[idx]

        if random.random() < TLM_ratio:
            # TLM sample: mask the sentence pair jointly.
            _inputs, _outputs, _lan_inputs, _lan_outputs, _soft_pos_outputs = CDLM_pos_for_zh_en(
                zh_words, en_words, _tokenizer, keep_origin_rate, max_ratio, max_num
            )
            if not _inputs[0] or not _inputs[1] or not _soft_pos_outputs[0] or not _soft_pos_outputs[1]:
                continue
            samples.append(TLM_concat(_inputs, _outputs, _lan_inputs, _lan_outputs, _soft_pos_outputs))
        else:
            # MLM sample: mask each language independently, Chinese first.
            # The third positional argument of CDLM_pos is True for the
            # Chinese sentence and False for the English one.
            for words, is_zh in ((zh_words, True), (en_words, False)):
                mlm_input, mlm_output, mlm_lan_input, mlm_lan_output, mlm_soft_pos = CDLM_pos(
                    words, _tokenizer, is_zh, keep_origin_rate, max_ratio, max_num)
                if mlm_input and mlm_output and mlm_lan_input and mlm_lan_output and mlm_soft_pos:
                    samples.append([mlm_input, mlm_output, mlm_lan_input, mlm_lan_output, mlm_soft_pos])

    if not samples:
        # zip(*[]) yields nothing, so the 5-way unpack below would raise a
        # confusing ValueError; report "no samples" explicitly instead.
        return (), (), (), (), ()

    # Transpose the per-sample records into per-field columns.
    _inputs, _outputs, _lan_inputs, _lan_outputs, _soft_pos_outputs = zip(*samples)
    return _inputs, _outputs, _lan_inputs, _lan_outputs, _soft_pos_outputs
def CDLM_TLM_sample(list_of_zh_words, list_of_en_words, _tokenizer, keep_origin_rate=0.2, max_ratio=0.2, max_num=4):
    """Build shuffled TLM (paired zh/en) samples from parallel sentences.

    Every zh/en pair is passed to ``CDLM_pos_for_zh_en``; pairs whose
    returned inputs are empty on either side are dropped, and the remaining
    results are merged per pair by ``TLM_concat``.  The samples are then
    shuffled after reseeding the global ``random`` generator with the
    module-level ``random_state``.

    NOTE: ``random.seed(random_state)`` resets the process-wide generator,
    so later ``random`` calls elsewhere are affected by calling this
    function.

    Args:
        list_of_zh_words: Sequence of Chinese sentences.
        list_of_en_words: Sequence of English sentences paired positionally
            with ``list_of_zh_words``; pairing stops at the shorter sequence.
        _tokenizer: Tokenizer forwarded unchanged to ``CDLM_pos_for_zh_en``.
        keep_origin_rate: Forwarded unchanged to ``CDLM_pos_for_zh_en``.
        max_ratio: Forwarded unchanged to ``CDLM_pos_for_zh_en``.
        max_num: Forwarded unchanged to ``CDLM_pos_for_zh_en``.

    Returns:
        A 5-tuple ``(inputs, outputs, lan_inputs, lan_outputs,
        soft_pos_outputs)``; each element is a tuple holding one entry per
        kept sample.  All five are empty tuples when no pair was kept.
    """
    # Keep the three stages separate (mask all, filter, concat all) so that
    # any random draws made inside the helpers happen in the same order.
    masked = [
        CDLM_pos_for_zh_en(zh_words, en_words, _tokenizer, keep_origin_rate, max_ratio, max_num)
        for zh_words, en_words in zip(list_of_zh_words, list_of_en_words)
    ]
    kept = [fields for fields in masked if fields[0][0] and fields[0][1]]
    samples = [TLM_concat(*fields) for fields in kept]

    random.seed(random_state)
    random.shuffle(samples)

    if not samples:
        # zip(*[]) yields nothing, so the 5-way unpack below would raise a
        # confusing ValueError; report "no samples" explicitly instead.
        return (), (), (), (), ()

    # Transpose the per-sample records into per-field columns.
    _inputs, _outputs, _lan_inputs, _lan_outputs, _soft_pos_outputs = zip(*samples)
    return _inputs, _outputs, _lan_inputs, _lan_outputs, _soft_pos_outputs