def main():
    max_sentence_num = 100
    src_path = "/data/rrjin/NMT/data/bible_data/corpus/train_src_combine_joint_bpe_22000.txt"
    tgt_path = "/data/rrjin/NMT/data/bible_data/corpus/train_tgt_en_joint_bpe_22000.txt"

    src_data = read_data(src_path)
    tgt_data = read_data(tgt_path)

    src_freq = frequency(src_data)
    tgt_freq = frequency(tgt_data)

    src_freq = sort_freq(src_freq)
    tgt_freq = sort_freq(tgt_freq)

    _, src_file_name = os.path.split(src_path)
    _, tgt_file_name = os.path.split(tgt_path)

    plot_freq(src_freq, "statistic of " + src_file_name + ".jpg")
    plot_freq(tgt_freq, "statistic of " + tgt_file_name + ".jpg")

    sentence_number = len(src_data)

    cnt = 0
    for length, freq in src_freq:
        if length <= max_sentence_num:
            cnt += freq

    print(cnt)
    print(sentence_number)
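# The scripts in this file rely on read_data / write_data / frequency / sort_freq
# helpers defined elsewhere in the repo. The following is only a minimal sketch of
# what they are assumed to do (one sentence per line on disk; frequency counts
# sentence lengths); the real implementations may differ.

def read_data(path):
    # read a text file into a list of lines, without trailing newlines
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def write_data(data, path):
    # write one sentence per line; token lists are joined with spaces
    with open(path, "w", encoding="utf-8") as f:
        for line in data:
            if isinstance(line, (list, tuple)):
                line = " ".join(line)
            f.write(line + "\n")


def frequency(data):
    # map sentence length (in tokens) -> number of sentences with that length
    freq = {}
    for sentence in data:
        length = len(sentence.split())
        freq[length] = freq.get(length, 0) + 1
    return freq


def sort_freq(freq):
    # return (length, count) pairs sorted by sentence length
    return sorted(freq.items())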
def remove_long_sentence(src_file_path_list: list, tgt_file_path_list: list, max_sentence_length: int):
    assert len(src_file_path_list) == len(tgt_file_path_list)

    for src_file_path, tgt_file_path in zip(src_file_path_list, tgt_file_path_list):
        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        for src_sentence, tgt_sentence in zip(src_data, tgt_data):
            if len(src_sentence.split()) <= max_sentence_length and \
                    len(tgt_sentence.split()) <= max_sentence_length:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        new_src_file_path = get_file_path(src_file_path)
        new_tgt_file_path = get_file_path(tgt_file_path)

        write_data(src_sentence_filtered, new_src_file_path)
        write_data(tgt_sentence_filtered, new_tgt_file_path)
def remove_long_sentence(args):
    assert args.max_sentence_length is not None

    for src_file_path, tgt_file_path, filtered_src_file_path, filtered_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):
        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        for src_sentence, tgt_sentence in zip(src_data, tgt_data):
            if len(src_sentence.split()) <= args.max_sentence_length and \
                    len(tgt_sentence.split()) <= args.max_sentence_length:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        write_data(src_sentence_filtered, filtered_src_file_path)
        write_data(tgt_sentence_filtered, filtered_tgt_file_path)
def value_linkage_attack(path1, path2, domain_path, attr, value, sample_size=100):
    print('value linkage attack')
    # Find records in path1 whose attr equals value.
    # Find records in path2 whose anchor (common) attributes match those records;
    # their attr values should then correspond to the same underlying value.
    # This builds a value mapping for the unknown attribute.
    # First compute the marginal of attr to double-check the mapping.
    data1, domain, columns = tools.read_data(path1, domain_path)
    data2, domain, columns = tools.read_data(path2, domain_path, dtype=None)
    # print(data1[324], data2[324])
    data_num1 = len(data1)
    data_num2 = len(data2)

    temp_data1 = data1[:, anchor_attr]
    temp_data2 = data2[:, anchor_attr].astype(int)
    # print(temp_data1[324], temp_data2[324])

    print(f' finding unique values')
    unique, unique_indices1, unique_counts = np.unique(temp_data1, axis=0, return_index=True, return_counts=True)
    unique_map1 = {tuple(unique[i]): (unique_counts[i], unique_indices1[i]) for i in range(len(unique))}
    print(f' found unique value {len(unique)}')

    unique, unique_indices2, unique_counts = np.unique(temp_data2, axis=0, return_index=True, return_counts=True)
    unique_map2 = {tuple(unique[i]): (unique_counts[i], unique_indices2[i]) for i in range(len(unique))}
    print(f' found unique value {len(unique)}')

    with open('./info/unique_line_idx.txt', 'w') as out_file:
        print('generating voting map')
        map_vote = {}
        for value, cnt_idx1 in unique_map1.items():
            cnt1, idx1 = cnt_idx1
            # only link records whose anchor values are unique in both datasets
            if cnt1 == 1 and value in unique_map2:
                cnt2, idx2 = unique_map2[value]
                if cnt2 == 1:
                    out_file.write(f'{idx1} {idx2}\n')
                    attr_value1 = str(data1[idx1, attr])
                    attr_value2 = str(data2[idx2, attr])
                    if attr_value1 not in map_vote:
                        map_vote[attr_value1] = {}
                    if attr_value2 not in map_vote[attr_value1]:
                        map_vote[attr_value1][attr_value2] = 0
                    map_vote[attr_value1][attr_value2] += 1

    print(len(map_vote))
    json.dump(map_vote, open('./info/map_vote.json', 'w'))

    for key, votes in map_vote.items():
        if len(votes) == 1:
            print(key, list(votes.items()))
        else:
            print(' ', key, list(votes.items()))
def __init__(self, src_file_path: str, tgt_file_path: str, seed: int):
    self.src_data = read_data(src_file_path)
    self.tgt_data = read_data(tgt_file_path)
    self.seed = seed

    assert len(self.src_data) == len(self.tgt_data)

    self.shuffled_src_data = []
    self.shuffled_tgt_data = []
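# The shuffled_src_data / shuffled_tgt_data lists above are filled by another
# method of the same class, which is not shown here. A minimal sketch of what
# such a step could look like (the method name and body are hypothetical, not
# taken from this repo):
#
#     def shuffle(self):
#         indices = list(range(len(self.src_data)))
#         random.Random(self.seed).shuffle(indices)
#         self.shuffled_src_data = [self.src_data[i] for i in indices]
#         self.shuffled_tgt_data = [self.tgt_data[i] for i in indices]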
def check_unique_line(path1, path2, print_num=10, check_num=10000, attr=None):
    unique_idx = []
    with open('./info/unique_line_idx.txt', 'r') as in_file:
        for line in in_file:
            unique_idx.append([int(item) for item in line.split(' ')])
    random.shuffle(unique_idx)

    data1, domain, columns = tools.read_data(path1, domain_path)
    data2, domain, columns = tools.read_data(path2, domain_path, dtype=None)
    # print(data1[324], data2[324])
    # data3, columns3 = tools.read_csv('./data/2019data/taxi_trips_2019.csv')
    data_num1 = len(data1)
    data_num2 = len(data2)

    # print(data1[4924570])
    # print(data2[9665075])

    for i in range(print_num):
        if attr is None:
            print(data1[unique_idx[i][0]])
            print(data2[unique_idx[i][1]])
        else:
            print(data1[unique_idx[i][0]][attr])
            print(data2[unique_idx[i][1]][attr])
            # print(data3[unique_idx[i][1]][2])
            # print(data3[unique_idx[i][1]][3])
        print('')

    pass_flag = True
    for i in range(check_num):
        if attr is None:
            value1 = data1[unique_idx[i][0]][1:]
            value2 = data2[unique_idx[i][1]][1:]
        else:
            value1 = data1[unique_idx[i][0]][attr]
            value2 = data2[unique_idx[i][1]][attr]
        if not (value1 == value2).all():
            pass_flag = False
            print(i, unique_idx[i])
            print(value1)
            print(value2)
            break
    print(pass_flag)
def main(): parser = argparse.ArgumentParser() parser.add_argument("--data_path", nargs="+") parser.add_argument("--do_lower_case", action="store_true") parser.add_argument("--strip_accents", action="store_true") parser.add_argument("--tokenize_chinese_chars", action="store_true") args, unknown = parser.parse_known_args() print("do_lower_case: {}, strip_accents: {}, tokenize_chinese_chars: {}". format(args.do_lower_case, args.strip_accents, args.tokenize_chinese_chars)) tokenizer = BasicTokenizer( do_lower_case=args.do_lower_case, strip_accents=args.strip_accents, tokenize_chinese_chars=args.tokenize_chinese_chars) for data_path in args.data_path: data = read_data(data_path) tok_data = [] if data_path.endswith(".en"): tok_data = [tokenizer.tokenize(sentence) for sentence in data] else: for sentence in data: sentence = sentence.strip() first_blank_pos = sentence.find(" ") if first_blank_pos != -1: # do not process language identify token lang_identify_token = sentence[:first_blank_pos] sentence_after_process = tokenizer.tokenize( sentence[first_blank_pos + 1:]) sentence_after_process = " ".join( [lang_identify_token, sentence_after_process]) tok_data.append(sentence_after_process) else: tok_data.append(sentence) idx = data_path.rfind(".") assert idx != -1 tok_data_path = data_path[:idx] + "_tok" + data_path[idx:] print("{} > {}".format(data_path, tok_data_path)) write_data(tok_data, tok_data_path)
def preprocess(input_path, output_path):
    print(f'preprocessing {input_path} -> {output_path}')
    data, domain, columns = tools.read_data(input_path, domain_path)
    attr_num = data.shape[1]
    print(data.shape)

    for attr in range(attr_num):
        if columns[attr] in ['pickup_community_area', 'dropoff_community_area']:
            print(attr, columns[attr], np.min(data[:, attr]))
            mask = data[:, attr] == -1
            data[:, attr][mask] = 0
        elif columns[attr] in ['fare', 'tips', 'trip_total', 'trip_seconds', 'trip_miles']:
            print(attr, columns[attr], np.min(data[:, attr]))
            data[:, attr] += 1
        elif columns[attr] == 'payment_type':
            # caution: 4 is not a valid value of payment_type in parameters.json
            print(attr, columns[attr], np.sum(data[:, attr] == 4))
            mask = data[:, attr] == -1
            data[:, attr][mask] = 4

    # print(np.min(data))
    print(columns)

    data = list(data)
    with open(output_path, 'w') as output_file:
        writer = csv.writer(output_file)
        writer.writerow(columns)
        for line in data:
            writer.writerow(line)
def convert(language_code_path: str, output_file_path: str):
    language_code = read_data(language_code_path)
    language_dict = OrderedDict()

    for code in language_code:
        alpha = {}
        code = code[1:-1]

        if len(code) == 2:
            alpha["alpha_2"] = code
        elif len(code) == 3:
            alpha["alpha_3"] = code
        else:
            language_dict[code] = {"ISO639-3": "unknown", "language name": "unknown"}
            continue

        language_data = pycountry.languages.get(**alpha)

        if language_data is None:
            language_dict[code] = {"ISO639-3": "unknown", "language name": "unknown"}
        else:
            language_dict[code] = {"ISO639-3": language_data.alpha_3, "language name": language_data.name}

    with open(output_file_path, "w") as f:
        json.dump(language_dict, f)
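# pycountry lookup illustration: languages can be resolved by either their
# two-letter or three-letter code, and the returned object exposes the alpha_3
# and name attributes used above.
#
#     import pycountry
#     lang = pycountry.languages.get(alpha_2="fr")
#     lang.alpha_3   # 'fra'
#     lang.name      # 'French'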
def main(): parser = argparse.ArgumentParser() parser.add_argument("--file_path", "-f", required=True) args, unknown = parser.parse_known_args() data = read_data(args.file_path) directory, file_name = os.path.split(args.file_path) data_after_processing = [] for line in data: line = "".join(c for c in line if c != "@") data_after_processing.append(line) idx = file_name.rfind(".") if idx == -1: new_file_name = file_name + "_remove_at" else: new_file_name = file_name[:idx] + "_remove_at" + file_name[idx:] write_data(data_after_processing, os.path.join(directory, new_file_name))
def main(): parser = argparse.ArgumentParser() parser.add_argument("--data_path", required=True) parser.add_argument("--picture_path", required=True) parser.add_argument("--language_data", required=True) args, unknown = parser.parse_known_args() with open(args.language_data) as f: language_dict = json.load(f) data = read_data(args.data_path) sentence_num_per_language = {} for sentence in data: token_list = sentence.split() lang_code = token_list[0] assert lang_code.startswith("<") and lang_code.endswith(">") lang_code = lang_code[1:-1] sentence_num_per_language[lang_code] = sentence_num_per_language.get(lang_code, 0) + 1 sentence_num_per_language = {k: v for k, v in sorted(sentence_num_per_language.items(), key=cmp)} plot_data(sentence_num_per_language, language_dict, args.picture_path)
def main(): parser = argparse.ArgumentParser() parser.add_argument("--feature_name", required=True) parser.add_argument("--classify_method", required=True, choices=["svm", "logistic"]) parser.add_argument("--lang_vec_path", required=True) parser.add_argument("--lang_name_path", required=True) parser.add_argument("--output_file_path", required=True) args, unknown = parser.parse_known_args() lang_name = read_data(args.lang_name_path) lang_vec = load_lang_vec(args.lang_vec_path) lang_alpha3 = {} for lang in lang_name: alpha3 = get_language_alpha3(lang[1:-1]) if check_alpha3(alpha3): lang_alpha3[lang] = alpha3 feature_name = args.feature_name features = l2v.get_features(list(lang_alpha3.values()), feature_name, header=True) train(args, lang_vec, lang_alpha3, features)
def remove_same_sentence(args):
    if args.src_memory_path:
        src_memory_list = []
        for memory_path in args.src_memory_path:
            src_memory_list.extend(read_data(memory_path))
        src_memory = set(src_memory_list)
    else:
        src_memory = set()

    for src_file_path, tgt_file_path, filtered_src_file_path, filtered_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):
        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        sentence_visited = set()
        removed_sentence_id = set()

        for i, sentence in enumerate(src_data):
            if sentence in sentence_visited:
                removed_sentence_id.add(i)
            elif sentence in src_memory:
                removed_sentence_id.add(i)
                sentence_visited.add(sentence)
            else:
                sentence_visited.add(sentence)

        for i, (src_sentence, tgt_sentence) in enumerate(zip(src_data, tgt_data)):
            if i not in removed_sentence_id:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        write_data(src_sentence_filtered, filtered_src_file_path)
        write_data(tgt_sentence_filtered, filtered_tgt_file_path)
def sort_sentence(args):
    for src_file_path, tgt_file_path, sorted_src_file_path, sorted_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):
        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_data = [sentence.split() for sentence in src_data]
        tgt_data = [sentence.split() for sentence in tgt_data]

        src_data, tgt_data = sort_src_sentence_by_length(list(zip(src_data, tgt_data)))

        write_data(src_data, sorted_src_file_path)
        write_data(tgt_data, sorted_tgt_file_path)
def bilingual_bleu_calculation(args):
    bleu_score_type = args.bleu_score_type
    corpus_bleu = nltk_corpus_bleu if bleu_score_type == "nltk_bleu" else sacre_corpus_bleu

    reference_data = read_data(args.reference_path)
    if bleu_score_type == "nltk_bleu":
        reference_data = [[sentence.split()] for sentence in reference_data]
    else:
        reference_data = [reference_data]

    args.translation_path_list.sort(key=cmp)

    bleu_score_data = []

    for translation_path in args.translation_path_list:
        print("Translation in: {}\n".format(translation_path))
        translation_data = read_data(translation_path)

        if bleu_score_type == "nltk_bleu":
            translation_data = [sentence.split() for sentence in translation_data]
            bleu_score = corpus_bleu(reference_data, translation_data) * 100
        else:
            bleu_score = corpus_bleu(translation_data, reference_data)

        bleu_score_data.append(bleu_score if bleu_score_type == "nltk_bleu" else bleu_score.score)
        print("bleu:{}".format(bleu_score))

    print("Writing data to: {}".format(args.bleu_score_data_path))
    with open(args.bleu_score_data_path, "w") as f:
        json.dump(bleu_score_data, f)
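# The two BLEU backends expect differently shaped inputs, which is why the
# reference data is wrapped differently above. A minimal illustration of the two
# call conventions (the sentences are made up; scores differ slightly between the
# two implementations):
#
#     from nltk.translate.bleu_score import corpus_bleu as nltk_corpus_bleu
#     from sacrebleu import corpus_bleu as sacre_corpus_bleu
#
#     refs = ["the cat sat on the mat"]
#     hyps = ["the cat sat on the mat"]
#
#     nltk_corpus_bleu([[r.split()] for r in refs], [h.split() for h in hyps])  # float in [0, 1]
#     sacre_corpus_bleu(hyps, [refs]).score                                     # percentage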
def check_word(args):
    assert args.vocab_path is not None

    with open(args.vocab_path, "rb") as f:
        vocab_data = pickle.load(f)

    vocab_data_all = []
    for vocab in vocab_data.values():
        vocab_data_all.extend(list(vocab))
    vocab_data_all = set(vocab_data_all)

    lang_identifier = read_data(args.lang_identifier_path)

    for translation_file_path in args.translation_file_path_list:
        translation_data = read_data(translation_file_path)

        wrong_token = 0
        num_token = 0
        unk_token = 0

        for i, sentence in enumerate(translation_data):
            lang_code = lang_identifier[i]
            sentence = sentence.split()

            for token in sentence:
                if token not in vocab_data[lang_code]:
                    assert token in vocab_data_all
                    if token == "UNK":
                        unk_token += 1
                    else:
                        wrong_token += 1

            num_token += len(sentence)

        print("Number of wrong tokens: {}, ratio of wrong tokens: {}, number of UNK tokens: {}, ratio of UNK tokens: {}"
              .format(wrong_token, wrong_token / num_token, unk_token, unk_token / num_token))
def check_marginal_distribution(path1, path2, attr_list):
    data1, domain, columns = tools.read_data(path1, domain_path)
    temp_data2, domain, columns = tools.read_data(path2, domain_path, dtype=None)

    data2 = np.zeros(shape=temp_data2.shape)
    data2[:, anchor_attr] = temp_data2[:, anchor_attr]

    for attr in attr_list:
        marginal1 = tools.get_marginal(data1, domain, (attr, ))
        marginal2 = tools.get_marginal(data2, domain, (attr, ))

        print(f'attr {attr}')
        print(marginal1[:20])
        print(marginal2[:20])

        dist_array = marginal1 - marginal2
        print(dist_array[:20])

        check_array = (marginal1 - marginal2) / (marginal1 + 1)
        print(check_array[:20])

        abs_array = np.abs(check_array)
        print(f'{len(check_array)}: {np.sum(abs_array > 0.01)} {np.sum(abs_array > 0.05)} {np.sum(abs_array > 0.10)}\n')
def truncate_and_plot(data_path, domain_path, prefix):
    data, domain, headings = tools.read_data(data_path, domain_path)

    threshold_map = {'fare': 100, 'tips': 20, 'trip_miles': 40, 'trip_total': 100, 'trip_seconds': 10000}
    for attr in threshold_map:
        data, domain = atools.truncate_data(data, domain, headings.index(attr), threshold_map[attr])

    for attr in common_attrs:
        attr_id = headings.index(attr)
        ptools.plot_attr(attr_id, data, domain,
                         path=f'./info/{prefix}_{attr}_{attr_id}.pdf')

    print(tools.get_marginal(data, domain, (headings.index('payment_type'), )))
def remove_blank(args):
    for file in args.zh_corpus_list:
        data = read_data(file)
        data = ["".join(sentence.strip().split()) for sentence in data]

        directory, file_name = os.path.split(file)
        idx = file_name.rfind(".")
        assert idx != -1

        new_file_name = file_name[:idx] + "_no_blank" + file_name[idx:]
        new_file = os.path.join(directory, new_file_name)

        write_data(data, new_file)
def main(): parser = argparse.ArgumentParser() parser.add_argument("--multi_gpu_translation_dir", required=True) parser.add_argument("--is_tok", action="store_true") parser.add_argument("--merged_translation_dir", required=True) args, unknown = parser.parse_known_args() translations_dict_per_model = {} for file in os.listdir(args.multi_gpu_translation_dir): file_name_prefix, extension = os.path.splitext(file) if args.is_tok and not file_name_prefix.endswith("_tok"): continue elif not args.is_tok and file_name_prefix.endswith("_tok"): continue assert extension[1:5] == "rank" rank = int(extension[5:]) data = read_data(os.path.join(args.multi_gpu_translation_dir, file)) if file_name_prefix in translations_dict_per_model: translations_dict_per_model[file_name_prefix].append((data, rank)) else: translations_dict_per_model[file_name_prefix] = [(data, rank)] for file_name_prefix in translations_dict_per_model: translations_dict_per_model[file_name_prefix].sort( key=lambda item: item[1]) if not os.path.isdir(args.merged_translation_dir): os.makedirs(args.merged_translation_dir) for file_name_prefix in translations_dict_per_model: merged_translations = [] for translations, rank in translations_dict_per_model[ file_name_prefix]: merged_translations.extend(translations) write_data( merged_translations, os.path.join(args.merged_translation_dir, "{}.txt".format(file_name_prefix)))
def preprocess(input_path, domain_path, output_path):
    print(f'preprocessing {input_path} -> {output_path}')
    df, domain, columns = tools.read_data(input_path, domain_path, return_df=True)
    attr_num = df.shape[1]
    # print(df.shape)
    # print(columns)
    # print(df.iloc[:10, ])

    for col in df.columns:
        if col in BINS:
            df.loc[:, col] = pd.cut(df[col], BINS[col], right=False, labels=False)
            if len(BINS[col]) < 255:
                df.loc[:, col] = df[col].astype(np.uint8)

    data = df.to_numpy(dtype=int)
    # print(np.unique(data[:, 2], return_counts=True))
    # print(np.unique(data[:, 5], return_counts=True))

    for attr in range(attr_num):
        if columns[attr] in ['pickup_community_area', 'dropoff_community_area']:
            mask = data[:, attr] == -1
            print(attr, columns[attr], np.min(data[:, attr]), np.sum(mask))
            data[:, attr][mask] = 0
            # mask = data[:, attr] == 0
            # print(attr, columns[attr], np.min(data[:, attr]), np.sum(mask))
        elif columns[attr] == 'payment_type':
            # caution: 4 is not a valid value of payment_type in parameters.json
            print(attr, columns[attr], np.sum(data[:, attr] == 4))
            mask = data[:, attr] == -1
            data[:, attr][mask] = 4

    # print(np.min(data))
    # print(columns)
    df = pd.DataFrame(data, columns=columns, dtype=int)
    df.to_csv(output_path, index=False)
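# pd.cut with labels=False maps each value to the integer index of its bin, which
# is what the binning loop above relies on. A small illustration (the bin edges
# are made up for the example):
#
#     import pandas as pd
#     pd.cut(pd.Series([0, 3, 7, 12]), bins=[0, 5, 10, 15], right=False, labels=False)
#     # -> 0, 0, 1, 2   (intervals [0, 5), [5, 10), [10, 15))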
def truncate_and_plot(data_path, domain_path, prefix):
    data, domain, headings = tools.read_data(data_path, domain_path)

    threshold_map = {'fare': 100, 'tips': 20, 'trip_miles': 40, 'trip_total': 100, 'trip_seconds': 10000}
    for attr in threshold_map:
        data, domain = atools.truncate_data(data, domain, headings.index(attr), threshold_map[attr])

    parameters_json = json.load(open('./data/parameters.json', 'r'))
    attrs = list(parameters_json['schema'].keys())

    for attr in attrs:
        attr_id = headings.index(attr)
        ptools.plot_attr(attr_id, data, domain,
                         path=f'./info/{prefix}_{attr}_{attr_id}.pdf')

    print(tools.get_marginal(data, domain, (headings.index('payment_type'), )))
def main(): parser = argparse.ArgumentParser() parser.add_argument("--log_file_path", "-lf", required=True) parser.add_argument("--json_file_path", "-jf", required=True) args, unknown = parser.parse_known_args() log_data = read_data(args.log_file_path) data = [] element = {} for line in log_data: line = line.lower() if line.startswith("load model"): element["model_name"] = line[16:] elif line.startswith("bleu"): l = line.find("=") assert l != -1 l += 2 r = line.find(",") assert r != -1 bleu_score = float(line[l:r]) element["bleu"] = bleu_score element["bleu_details"] = line data.append(element) element = {} data.sort(key=lambda item: -item["bleu"]) with open(args.json_file_path, "w") as f: json.dump(data, f)
def repair_data(input_path, domain_path, output_path):
    print(f'repairing final data {input_path} -> {output_path}')
    df, domain, columns = tools.read_data(input_path, domain_path, return_df=True)
    attr_num = df.shape[1]

    data = df.to_numpy(dtype=int)
    for attr in range(attr_num):
        if columns[attr] in ['pickup_community_area', 'dropoff_community_area']:
            mask = data[:, attr] == 0
            print(np.sum(mask))
            data[:, attr][mask] = -1

    df = pd.DataFrame(data, columns=columns, dtype=int)
    df.to_csv(output_path, index=False)
def tokenize(args):
    assert os.path.isdir(args.raw_corpus_dir)

    if not os.path.isdir(args.tokenized_corpus_dir):
        os.makedirs(args.tokenized_corpus_dir)

    tokenizer = sacrebleu.TOKENIZERS[sacrebleu.DEFAULT_TOKENIZER]

    for file in os.listdir(args.raw_corpus_dir):
        file_path = os.path.join(args.raw_corpus_dir, file)

        idx = file.rfind(".")
        assert idx != -1
        new_file_name = "{}.{}.{}".format(file[:idx], "tok", file[idx + 1:])

        data = read_data(file_path)
        data_tok = [tokenizer(sentence) for sentence in data]

        new_file_path = os.path.join(args.tokenized_corpus_dir, new_file_name)
        write_data(data_tok, new_file_path)
def merge(corpus_path: str):
    train_data_src = []
    train_data_tgt = []

    dev_data_src = []
    dev_data_tgt = []

    test_data_src = []
    test_data_tgt = []

    unused_corpus = {"pt-br_en", "fr-ca_en", "eo_en", "calv_en"}

    for corpus_dir in os.listdir(corpus_path):
        if corpus_dir in unused_corpus:
            continue

        corpus_dir = os.path.join(corpus_path, corpus_dir)

        for corpus_file_name in os.listdir(corpus_dir):
            corpus_file_path = os.path.join(corpus_dir, corpus_file_name)
            print(corpus_file_path)

            data = read_data(corpus_file_path)
            is_en = True

            if not corpus_file_name.endswith(".en"):
                idx = corpus_file_name.rfind(".")
                assert idx != -1
                idx += 1
                lang_identify_token = "".join(["<", corpus_file_name[idx:], ">"])
                data = [" ".join([lang_identify_token, sentence]) for sentence in data]
                is_en = False

            if corpus_file_name.startswith("train"):
                if is_en:
                    train_data_tgt.extend(data)
                else:
                    train_data_src.extend(data)
            elif corpus_file_name.startswith("dev"):
                if is_en:
                    dev_data_tgt.extend(data)
                else:
                    dev_data_src.extend(data)
            elif corpus_file_name.startswith("test"):
                if is_en:
                    test_data_tgt.extend(data)
                else:
                    test_data_src.extend(data)

    assert check(train_data_src, train_data_tgt) and check(dev_data_src, dev_data_tgt) and \
        check(test_data_src, test_data_tgt)

    output_dir = "/data/rrjin/NMT/data/ted_data/corpus"

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    write_data(train_data_src, os.path.join(output_dir, "train_data_src.combine"))
    write_data(train_data_tgt, os.path.join(output_dir, "train_data_tgt.en"))

    write_data(dev_data_src, os.path.join(output_dir, "dev_data_src.combine"))
    write_data(dev_data_tgt, os.path.join(output_dir, "dev_data_tgt.en"))

    write_data(test_data_src, os.path.join(output_dir, "test_data_src.combine"))
    write_data(test_data_tgt, os.path.join(output_dir, "test_data_tgt.en"))
def evaluation(local_rank, args):
    rank = args.nr * args.gpus + local_rank
    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # List[str]
    src_data = read_data(args.test_src_path)

    tgt_prefix_data = None
    if args.tgt_prefix_file_path is not None:
        tgt_prefix_data = read_data(args.tgt_prefix_file_path)

    max_src_len = max(len(line.split()) for line in src_data) + 2
    max_tgt_len = max_src_len * 3
    logging.info("max src sentence length: {}".format(max_src_len))

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    padding_value = src_vocab.get_index(src_vocab.mask_token)
    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    src_data = convert_data_to_index(src_data, src_vocab)
    dataset = DataPartition(src_data, args.world_size, tgt_prefix_data,
                            args.work_load_per_process).dataset(rank)
    logging.info("dataset size: {}, rank: {}".format(len(dataset), rank))

    data_loader = DataLoader(dataset=dataset,
                             batch_size=(args.batch_size if args.batch_size else 1),
                             shuffle=False,
                             pin_memory=True,
                             drop_last=False,
                             collate_fn=lambda batch: collate_eval(
                                 batch, padding_value,
                                 batch_first=(True if args.transformer else False)))

    if not os.path.isdir(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    if args.beam_size:
        logging.info("Beam size: {}".format(args.beam_size))

    if args.is_prefix:
        args.model_load = args.model_load + "*"

    for model_path in glob.glob(args.model_load):
        logging.info("Load model from: {}, rank: {}".format(model_path, rank))

        if args.transformer:
            s2s = load_transformer(model_path,
                                   len(src_vocab),
                                   max_src_len,
                                   len(tgt_vocab),
                                   max_tgt_len,
                                   padding_value,
                                   training=False,
                                   share_dec_pro_emb=args.share_dec_pro_emb,
                                   device=device)
        else:
            s2s = load_model(model_path, device=device)

        s2s.eval()

        if args.record_time:
            import time
            start_time = time.time()

        pred_data = []

        for data, tgt_prefix_batch in data_loader:
            if args.beam_size:
                pred_data.append(
                    beam_search_decoding(s2s, data.to(device, non_blocking=True),
                                         tgt_vocab, args.beam_size, device))
            else:
                pred_data.extend(
                    greedy_decoding(s2s, data.to(device, non_blocking=True),
                                    tgt_vocab, device, tgt_prefix_batch))

        if args.record_time:
            end_time = time.time()
            logging.info("Time spend: {} seconds, rank: {}".format(end_time - start_time, rank))

        _, model_name = os.path.split(model_path)

        if args.beam_size:
            translation_file_name_prefix = "{}_beam_size_{}".format(model_name, args.beam_size)
        else:
            translation_file_name_prefix = "{}_greedy".format(model_name)

        p = os.path.join(args.translation_output_dir,
                         "{}_translations.rank{}".format(translation_file_name_prefix, rank))
        write_data(pred_data, p)

        if args.need_tok:
            # replace '@@ ' with ''
            p_tok = os.path.join(args.translation_output_dir,
                                 "{}_translations_tok.rank{}".format(translation_file_name_prefix, rank))
            tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(p, p_tok)
            call(tok_command, shell=True)
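# Hedged launch sketch: evaluation(local_rank, args) is written for one process per
# GPU, so it would typically be started with torch.multiprocessing.spawn, which
# passes the process index as the first argument. The argument names below follow
# the fields read inside the function; the surrounding argument parsing is not
# shown in this file.
#
#     import torch.multiprocessing as mp
#
#     if __name__ == "__main__":
#         mp.spawn(evaluation, args=(args,), nprocs=args.gpus)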
parser.add_argument("--need_tok", action="store_true") parser.add_argument("--batch_size", type=int) parser.add_argument("--bleu_script_path", required=True) args, unknown = parser.parse_known_args() device = args.device torch.cuda.set_device(device) src_vocab = Vocab.load(args.src_vocab_path) tgt_vocab = Vocab.load(args.tgt_vocab_path) src_data = read_data(args.test_src_path) # |------ for transformer ------| max_src_len = max(len(line.split()) for line in src_data) + 2 max_tgt_len = max_src_len * 3 print("max src sentence length: {}".format(max_src_len)) # |------------ end ------------| src_data = convert_data_to_index(src_data, src_vocab) src_data = SrcData(src_data) padding_value = src_vocab.get_index(src_vocab.mask_token) assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token) if args.batch_size: assert args.beam_size is None, "batch translation do not support bream search now"
def multilingual_bleu_calculation(args):
    assert args.lang_identifier_file_path is not None and args.language_data is not None

    bleu_score_type = args.bleu_score_type
    corpus_bleu = nltk_corpus_bleu if bleu_score_type == "nltk_bleu" else sacre_corpus_bleu

    lang_identifier_list = read_data(args.lang_identifier_file_path)
    lang_identifier = {}

    with open(args.language_data) as f:
        language_dict = json.load(f)

    for i, sentence in enumerate(lang_identifier_list):
        lang_code = sentence
        assert lang_code.startswith("<") and lang_code.endswith(">")
        lang_code = lang_code[1:-1]

        if lang_code not in lang_identifier:
            lang_identifier[lang_code] = [i]
        else:
            lang_identifier[lang_code].append(i)

    show_corpus_statistics(lang_identifier, language_dict)

    reference_data = read_data(args.reference_path)

    if bleu_score_type == "nltk_bleu":
        reference_data = [[sentence.split()] for sentence in reference_data]

    reference_data_per_language = {}
    for k, v in lang_identifier.items():
        reference_data_per_language[k] = [reference_data[line_number] for line_number in v]
        if bleu_score_type == "sacrebleu":
            reference_data_per_language[k] = [reference_data_per_language[k]]

    args.translation_path_list.sort(key=cmp)

    bleu_score_dict = {lang: [] for lang in lang_identifier}

    for translation_path in args.translation_path_list:
        print("Translation in: {}\n".format(translation_path))
        translation_data = read_data(translation_path)

        if bleu_score_type == "nltk_bleu":
            translation_data = [sentence.split() for sentence in translation_data]

        translation_data_per_language = {}
        for k, v in lang_identifier.items():
            translation_data_per_language[k] = [translation_data[line_number] for line_number in v]

        for lang in lang_identifier:
            if bleu_score_type == "nltk_bleu":
                bleu_score = corpus_bleu(reference_data_per_language[lang],
                                         translation_data_per_language[lang]) * 100
            else:
                bleu_score = corpus_bleu(translation_data_per_language[lang],
                                         reference_data_per_language[lang])

            print("ISO639-2:{},ISO639-3:{},language name:{},bleu:{}".format(
                lang, language_dict[lang]["ISO639-3"],
                language_dict[lang]["language name"], bleu_score))

            bleu_score_dict[lang].append(bleu_score if bleu_score_type == "nltk_bleu" else bleu_score.score)

        print()

    print("Writing data to: {}".format(args.bleu_score_data_path))
    with open(args.bleu_score_data_path, "w") as f:
        json.dump(bleu_score_dict, f)
train_tgt = []

dev_src = []
dev_tgt = []

test_src = []
test_tgt = []

for directory in corpus_name:
    directory = os.path.join(prefix, directory)

    for file in os.listdir(directory):
        file_path = os.path.join(directory, file)
        data = read_data(file_path)

        if file.startswith("train"):
            if file.endswith(".en"):
                train_tgt.extend(data)
            else:
                train_src.extend(data)
        elif file.startswith("dev"):
            if file.endswith(".en"):
                dev_tgt.extend(data)
            else:
                dev_src.extend(data)