Example No. 1
def main():

    max_sentence_num = 100

    src_path = "/data/rrjin/NMT/data/bible_data/corpus/train_src_combine_joint_bpe_22000.txt"
    tgt_path = "/data/rrjin/NMT/data/bible_data/corpus/train_tgt_en_joint_bpe_22000.txt"

    src_data = read_data(src_path)
    tgt_data = read_data(tgt_path)

    src_freq = frequency(src_data)
    tgt_freq = frequency(tgt_data)

    src_freq = sort_freq(src_freq)
    tgt_freq = sort_freq(tgt_freq)

    _, src_file_name = os.path.split(src_path)
    _, tgt_file_name = os.path.split(tgt_path)

    plot_freq(src_freq, "statistic of " + src_file_name + ".jpg")
    plot_freq(tgt_freq, "statistic of " + tgt_file_name + ".jpg")

    sentence_number = len(src_data)

    cnt = 0

    for length, freq in src_freq:

        if length <= max_sentence_num:
            cnt += freq

    print(cnt)
    print(sentence_number)
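
The read_data and write_data helpers used throughout these examples are not shown. A minimal sketch, assuming line-oriented text files (one sentence per line); the real helpers may differ:

def read_data(path):
    # Assumed behaviour: return one stripped line per sentence.
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def write_data(sentences, path):
    # Assumed behaviour: write one sentence per line; token lists are joined with spaces.
    with open(path, "w", encoding="utf-8") as f:
        for sentence in sentences:
            if isinstance(sentence, (list, tuple)):
                sentence = " ".join(sentence)
            f.write(sentence + "\n")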
Example No. 2
def remove_long_sentence(src_file_path_list: list, tgt_file_path_list: list,
                         max_sentence_length: int):

    assert len(src_file_path_list) == len(tgt_file_path_list)

    for src_file_path, tgt_file_path in zip(src_file_path_list,
                                            tgt_file_path_list):

        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        for src_sentence, tgt_sentence in zip(src_data, tgt_data):

            if len(src_sentence.split()) <= max_sentence_length and len(
                    tgt_sentence.split()) <= max_sentence_length:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        new_src_file_path = get_file_path(src_file_path)
        new_tgt_file_path = get_file_path(tgt_file_path)

        write_data(src_sentence_filtered, new_src_file_path)
        write_data(tgt_sentence_filtered, new_tgt_file_path)
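
A hypothetical call, assuming get_file_path derives the filtered output path from the input path; the file names below are placeholders, not files from the repository:

remove_long_sentence(
    src_file_path_list=["corpus/train.src", "corpus/dev.src"],
    tgt_file_path_list=["corpus/train.tgt", "corpus/dev.tgt"],
    max_sentence_length=100,
)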
Example No. 3
def remove_long_sentence(args):

    assert args.max_sentence_length is not None

    for src_file_path, tgt_file_path, filtered_src_file_path, filtered_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):

        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        for src_sentence, tgt_sentence in zip(src_data, tgt_data):

            if len(src_sentence.split()) <= args.max_sentence_length and \
                    len(tgt_sentence.split()) <= args.max_sentence_length:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        write_data(src_sentence_filtered, filtered_src_file_path)
        write_data(tgt_sentence_filtered, filtered_tgt_file_path)
Example No. 4
def value_linkage_attack(path1, path2, domain_path, attr, value, sample_size=100):
    print('value linkage attack')
    # Find a few records in path1 whose attr value equals value
    # Find records in path2 whose common attributes match those records and whose attr maps to the same single value
    # This builds a value mapping for the unknown attribute

    # First compute the marginal of attr to double-check the mapping
    data1, domain, columns = tools.read_data(path1, domain_path)
    data2, domain, columns = tools.read_data(path2, domain_path, dtype=None)
    # print(data1[324], data2[324])

    data_num1 = len(data1)
    data_num2 = len(data2)

    temp_data1 = data1[:, anchor_attr]
    temp_data2 = data2[:, anchor_attr].astype(int)
    # print(temp_data1[324], temp_data2[324])

    print(f'  finding unique values')

    unique, unique_indices1, unique_counts = np.unique(temp_data1, axis=0, return_index=True, return_counts=True)
    unique_map1 = {tuple(unique[i]): (unique_counts[i], unique_indices1[i]) for i in range(len(unique))}
    print(f'  found {len(unique)} unique values')
    
    unique, unique_indices2, unique_counts = np.unique(temp_data2, axis=0, return_index=True, return_counts=True)
    unique_map2 = {tuple(unique[i]): (unique_counts[i], unique_indices2[i]) for i in range(len(unique))}
    print(f'  found {len(unique)} unique values')

    with open('./info/unique_line_idx.txt', 'w') as out_file:

        print('generating voting map')
        map_vote = {}
        for value, cnt_idx1 in unique_map1.items():
            cnt1, idx1 = cnt_idx1
            if cnt1 == 1 and value in unique_map2:
                cnt2, idx2 = unique_map2[value]
                if cnt2 == 1:
                    out_file.write(f'{idx1} {idx2}\n')

                    attr_value1 = str(data1[idx1, attr])
                    attr_value2 = str(data2[idx2, attr])

                    if attr_value1 not in map_vote:
                        map_vote[attr_value1] = {}
                    if attr_value2 not in map_vote[attr_value1]:
                        map_vote[attr_value1][attr_value2] = 0
                    map_vote[attr_value1][attr_value2] += 1

    print(len(map_vote))
    json.dump(map_vote, open('./info/map_vote.json', 'w'))

    for key, votes in map_vote.items():
        if len(votes) == 1:
            print(key, list(votes.items()))
        else:
            print('    ', key, list(votes.items()))
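
The matching step hinges on np.unique over rows with return_index and return_counts; rows whose count is 1 in both tables can be linked by their first-occurrence index. A small self-contained illustration of that call:

import numpy as np

rows = np.array([[1, 2], [1, 2], [3, 4]])
unique, indices, counts = np.unique(rows, axis=0, return_index=True, return_counts=True)
print(unique)    # [[1 2] [3 4]]
print(indices)   # [0 2]  -> index of the first occurrence of each unique row
print(counts)    # [2 1]  -> only [3 4] occurs exactly once, so only it can be linked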
Example No. 5
    def __init__(self, src_file_path: str, tgt_file_path: str, seed: int):

        self.src_data = read_data(src_file_path)
        self.tgt_data = read_data(tgt_file_path)

        self.seed = seed

        assert len(self.src_data) == len(self.tgt_data)

        self.shuffled_src_data = []
        self.shuffled_tgt_data = []
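
The rest of this class is not shown. A hypothetical shuffle method, assuming shuffled_src_data and shuffled_tgt_data should hold a seed-controlled permutation of the parallel corpus:

import random

def shuffle(self):
    # Hypothetical continuation: shuffle source/target pairs together so they stay aligned.
    paired = list(zip(self.src_data, self.tgt_data))
    random.Random(self.seed).shuffle(paired)
    self.shuffled_src_data = [src for src, _ in paired]
    self.shuffled_tgt_data = [tgt for _, tgt in paired]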
Example No. 6
def check_unique_line(path1, path2, print_num=10, check_num=10000, attr=None):
    unique_idx = []
    with open('./info/unique_line_idx.txt', 'r') as in_file:
        for line in in_file:
            unique_idx.append([int(item) for item in line.split(' ')])
    random.shuffle(unique_idx)

    data1, domain, columns = tools.read_data(path1, domain_path)
    data2, domain, columns = tools.read_data(path2, domain_path, dtype=None)
    # print(data1[324], data2[324])


    # data3, columns3 = tools.read_csv('./data/2019data/taxi_trips_2019.csv')

    data_num1 = len(data1)
    data_num2 = len(data2)

    # print(data1[4924570])
    # print(data2[9665075])

    for i in range(print_num):
        if attr is None:
            print(data1[unique_idx[i][0]])
            print(data2[unique_idx[i][1]])
        else:
            print(data1[unique_idx[i][0]][attr])
            print(data2[unique_idx[i][1]][attr])
            # print(data3[unique_idx[i][1]][2])
            # print(data3[unique_idx[i][1]][3])
        print('')

    pass_flag = True
    for i in range(check_num):
        if attr is None:
            value1 = data1[unique_idx[i][0]][1:]
            value2 = data2[unique_idx[i][1]][1:]
        else:
            value1 = data1[unique_idx[i][0]][attr]
            value2 = data2[unique_idx[i][1]][attr]

        if not (value1 == value2).all():
            pass_flag = False
            print(i, unique_idx[i])
            print(value1)
            print(value2)
            break

    print(pass_flag)
Example No. 7
def main():

    parser = argparse.ArgumentParser()

    parser.add_argument("--data_path", nargs="+")
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--strip_accents", action="store_true")
    parser.add_argument("--tokenize_chinese_chars", action="store_true")

    args, unknown = parser.parse_known_args()

    print("do_lower_case: {}, strip_accents: {}, tokenize_chinese_chars: {}".
          format(args.do_lower_case, args.strip_accents,
                 args.tokenize_chinese_chars))

    tokenizer = BasicTokenizer(
        do_lower_case=args.do_lower_case,
        strip_accents=args.strip_accents,
        tokenize_chinese_chars=args.tokenize_chinese_chars)

    for data_path in args.data_path:

        data = read_data(data_path)
        tok_data = []

        if data_path.endswith(".en"):

            tok_data = [tokenizer.tokenize(sentence) for sentence in data]

        else:
            for sentence in data:

                sentence = sentence.strip()

                first_blank_pos = sentence.find(" ")

                if first_blank_pos != -1:

                    # do not process the language-identifier token
                    lang_identify_token = sentence[:first_blank_pos]

                    sentence_after_process = tokenizer.tokenize(
                        sentence[first_blank_pos + 1:])
                    sentence_after_process = " ".join(
                        [lang_identify_token] + sentence_after_process)

                    tok_data.append(sentence_after_process)

                else:
                    tok_data.append(sentence)

        idx = data_path.rfind(".")

        assert idx != -1

        tok_data_path = data_path[:idx] + "_tok" + data_path[idx:]

        print("{} > {}".format(data_path, tok_data_path))

        write_data(tok_data, tok_data_path)
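
For reference, a quick standalone check of the BasicTokenizer call used above (import path assumed to be the Hugging Face transformers BERT module; the exact path may vary by version):

from transformers.models.bert.tokenization_bert import BasicTokenizer

tokenizer = BasicTokenizer(do_lower_case=True)
print(tokenizer.tokenize("Hello, World!"))  # typically ['hello', ',', 'world', '!']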
Example No. 8
def preprocess(input_path, output_path):
    print(f'preprocessing {input_path} -> {output_path}')

    data, domain, columns = tools.read_data(input_path, domain_path)
    attr_num = data.shape[1]
    print(data.shape)

    for attr in range(attr_num):
        if columns[attr] in ['pickup_community_area', 'dropoff_community_area']:
            print(attr, columns[attr], np.min(data[:, attr]))
            mask = data[:, attr] == -1
            data[:, attr][mask] = 0
        elif columns[attr] in ['fare', 'tips', 'trip_total', 'trip_seconds', 'trip_miles']:
            print(attr, columns[attr], np.min(data[:, attr]))
            data[:, attr] += 1
        elif columns[attr] == 'payment_type':
            # caution: 4 is not a valid value of payment_type in parameters.json
            print(attr, columns[attr], np.sum(data[:, attr] == 4))
            mask = data[:, attr] == -1
            data[:, attr][mask] = 4

    # print(np.min(data))
    print(columns)

    data = list(data)

    with open(output_path, 'w') as output_file:
        writer = csv.writer(output_file)
        writer.writerow(columns)
        for line in data:
            writer.writerow(line)
Example No. 9
def convert(language_code_path: str, output_file_path: str):

    language_code = read_data(language_code_path)

    language_dict = OrderedDict()

    for code in language_code:

        alpha = {}

        code = code[1:-1]

        if len(code) == 2:
            alpha["alpha_2"] = code
        elif len(code) == 3:
            alpha["alpha_3"] = code
        else:
            language_dict[code] = {"ISO639-3": "unknown", "language name": "unknown"}
            continue

        language_data = pycountry.languages.get(**alpha)

        if language_data is None:
            language_dict[code] = {"ISO639-3": "unknown", "language name": "unknown"}
        else:
            language_dict[code] = {"ISO639-3": language_data.alpha_3, "language name": language_data.name}

    with open(output_file_path, "w") as f:
        json.dump(language_dict, f)
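
A small illustration of the pycountry lookup that convert relies on:

import pycountry

# Look up by two-letter code; alpha_3 and name are the fields read above.
language = pycountry.languages.get(alpha_2="fr")
print(language.alpha_3, language.name)  # fra French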
Example No. 10
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--file_path", "-f", required=True)

    args, unknown = parser.parse_known_args()

    data = read_data(args.file_path)

    directory, file_name = os.path.split(args.file_path)

    data_after_processing = []

    for line in data:

        line = "".join(c for c in line if c != "@")

        data_after_processing.append(line)

    idx = file_name.rfind(".")

    if idx == -1:
        new_file_name = file_name + "_remove_at"
    else:
        new_file_name = file_name[:idx] + "_remove_at" + file_name[idx:]

    write_data(data_after_processing, os.path.join(directory, new_file_name))
Example No. 11
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", required=True)
    parser.add_argument("--picture_path", required=True)
    parser.add_argument("--language_data", required=True)

    args, unknown = parser.parse_known_args()

    with open(args.language_data) as f:
        language_dict = json.load(f)

    data = read_data(args.data_path)

    sentence_num_per_language = {}

    for sentence in data:

        token_list = sentence.split()
        lang_code = token_list[0]

        assert lang_code.startswith("<") and lang_code.endswith(">")
        lang_code = lang_code[1:-1]

        sentence_num_per_language[lang_code] = sentence_num_per_language.get(lang_code, 0) + 1

    sentence_num_per_language = {k: v for k, v in sorted(sentence_num_per_language.items(), key=cmp)}

    plot_data(sentence_num_per_language, language_dict, args.picture_path)
Example No. 12
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--feature_name", required=True)
    parser.add_argument("--classify_method",
                        required=True,
                        choices=["svm", "logistic"])
    parser.add_argument("--lang_vec_path", required=True)
    parser.add_argument("--lang_name_path", required=True)
    parser.add_argument("--output_file_path", required=True)

    args, unknown = parser.parse_known_args()

    lang_name = read_data(args.lang_name_path)
    lang_vec = load_lang_vec(args.lang_vec_path)

    lang_alpha3 = {}

    for lang in lang_name:
        alpha3 = get_language_alpha3(lang[1:-1])
        if check_alpha3(alpha3):
            lang_alpha3[lang] = alpha3

    feature_name = args.feature_name
    features = l2v.get_features(list(lang_alpha3.values()),
                                feature_name,
                                header=True)

    train(args, lang_vec, lang_alpha3, features)
Example No. 13
def remove_same_sentence(args):

    if args.src_memory_path:
        src_memory_list = []

        for memory_path in args.src_memory_path:
            src_memory_list.extend(read_data(memory_path))

        src_memory = set(src_memory_list)

    else:
        src_memory = set()

    for src_file_path, tgt_file_path, filtered_src_file_path, filtered_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):

        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_sentence_filtered = []
        tgt_sentence_filtered = []

        sentence_visited = set()
        removed_sentence_id = set()

        for i, sentence in enumerate(src_data):

            if sentence in sentence_visited:
                removed_sentence_id.add(i)
            elif sentence in src_memory:
                removed_sentence_id.add(i)
                sentence_visited.add(sentence)
            else:
                sentence_visited.add(sentence)

        for i, (src_sentence,
                tgt_sentence) in enumerate(zip(src_data, tgt_data)):

            if i not in removed_sentence_id:
                src_sentence_filtered.append(src_sentence)
                tgt_sentence_filtered.append(tgt_sentence)

        write_data(src_sentence_filtered, filtered_src_file_path)
        write_data(tgt_sentence_filtered, filtered_tgt_file_path)
Example No. 14
def sort_sentence(args):
    for src_file_path, tgt_file_path, sorted_src_file_path, sorted_tgt_file_path in zip(
            args.src_file_path_list, args.tgt_file_path_list,
            args.output_src_file_path_list, args.output_tgt_file_path_list):
        src_data = read_data(src_file_path)
        tgt_data = read_data(tgt_file_path)

        assert len(src_data) == len(tgt_data)

        src_data = [sentence.split() for sentence in src_data]
        tgt_data = [sentence.split() for sentence in tgt_data]

        src_data, tgt_data = sort_src_sentence_by_length(
            list(zip(src_data, tgt_data)))

        write_data(src_data, sorted_src_file_path)
        write_data(tgt_data, sorted_tgt_file_path)
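
sort_src_sentence_by_length is defined elsewhere. A minimal sketch of what it is assumed to do here (sort parallel token lists by source length and return the two halves):

def sort_src_sentence_by_length(pairs):
    # Hypothetical helper: pairs is a list of (src_tokens, tgt_tokens) tuples.
    pairs = sorted(pairs, key=lambda pair: len(pair[0]))
    src_data = [src for src, _ in pairs]
    tgt_data = [tgt for _, tgt in pairs]
    return src_data, tgt_data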
Example No. 15
def bilingual_bleu_calculation(args):

    bleu_score_type = args.bleu_score_type

    corpus_bleu = nltk_corpus_bleu if bleu_score_type == "nltk_bleu" else sacre_corpus_bleu

    reference_data = read_data(args.reference_path)

    if bleu_score_type == "nltk_bleu":
        reference_data = [[sentence.split()] for sentence in reference_data]
    else:
        reference_data = [reference_data]

    args.translation_path_list.sort(key=cmp)

    bleu_score_data = []

    for translation_path in args.translation_path_list:

        print("Translation in: {}\n".format(translation_path))

        translation_data = read_data(translation_path)

        if bleu_score_type == "nltk_bleu":
            translation_data = [
                sentence.split() for sentence in translation_data
            ]
            bleu_score = corpus_bleu(reference_data, translation_data) * 100

        else:
            bleu_score = corpus_bleu(translation_data, reference_data)

        bleu_score_data.append(bleu_score)

        print("bleu:{}".format(bleu_score))

    print("Writing data to: {}".format(args.bleu_score_data_path))

    with open(args.bleu_score_data_path, "w") as f:
        json.dump(bleu_score_data, f)
Example No. 16
def check_word(args):

    assert args.vocab_path is not None

    with open(args.vocab_path, "rb") as f:
        vocab_data = pickle.load(f)

    vocab_data_all = []
    for vocab in vocab_data.values():
        vocab_data_all.extend(list(vocab))

    vocab_data_all = set(vocab_data_all)

    lang_identifier = read_data(args.lang_identifier_path)

    for translation_file_path in args.translation_file_path_list:
        translation_data = read_data(translation_file_path)
        wrong_token = 0
        num_token = 0
        unk_token = 0
        for i, sentence in enumerate(translation_data):
            lang_code = lang_identifier[i]
            sentence = sentence.split()

            for token in sentence:
                if token not in vocab_data[lang_code]:
                    assert token in vocab_data_all
                    if token == "UNK":
                        unk_token += 1
                    else:
                        wrong_token += 1

            num_token += len(sentence)

        print(
            "Number of wrong tokens: {}, radio of wrong tokens: {}, number of UNK tokens: {}, radio of UNK tokens: {}"
            .format(wrong_token, wrong_token / num_token, unk_token,
                    unk_token / num_token))
Example No. 17
def check_marginal_distribution(path1, path2, attr_list):
    data1, domain, columns = tools.read_data(path1, domain_path)
    temp_data2, domain, columns = tools.read_data(path2, domain_path, dtype=None)

    data2 = np.zeros(shape=temp_data2.shape)
    data2[:, anchor_attr] = temp_data2[:, anchor_attr]

    for attr in attr_list:
        marginal1 = tools.get_marginal(data1, domain, (attr, ))
        marginal2 = tools.get_marginal(data2, domain, (attr, ))

        print(f'attr {attr}')
        print(marginal1[:20])
        print(marginal2[:20])

        dist_array = marginal1 - marginal2
        print(dist_array[:20])

        check_array = (marginal1 - marginal2)/(marginal1 + 1)
        print(check_array[:20])

        abs_array = np.abs(check_array)
        print(f'{len(check_array)}: {np.sum(abs_array>0.01)} {np.sum(abs_array>0.05)} {np.sum(abs_array>0.10)}\n')
Example No. 18
def truncate_and_plot(data_path, domain_path, prefix):
    data, domain, headings = tools.read_data(data_path, domain_path)

    threshold_map = {'fare': 100, 'tips': 20, 'trip_miles': 40, 'trip_total': 100, 'trip_seconds': 10000}
    for attr in threshold_map:
        data, domain = atools.truncate_data(data, domain, headings.index(attr), threshold_map[attr])

    for attr in common_attrs:
        attr_id = headings.index(attr)
        ptools.plot_attr(attr_id, data, domain, \
            path=f'./info/{prefix}_{attr}_{attr_id}.pdf')

    print(tools.get_marginal(data, domain, (headings.index('payment_type'), )))
Example No. 19
def remove_blank(args):
    for file in args.zh_corpus_list:

        data = read_data(file)

        data = ["".join(sentence.strip().split()) for sentence in data]

        directory, file_name = os.path.split(file)
        idx = file_name.rfind(".")

        assert idx != -1

        new_file_name = file_name[:idx] + "_no_blank" + file_name[idx:]
        new_file = os.path.join(directory, new_file_name)
        write_data(data, new_file)
Example No. 20
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--multi_gpu_translation_dir", required=True)
    parser.add_argument("--is_tok", action="store_true")
    parser.add_argument("--merged_translation_dir", required=True)

    args, unknown = parser.parse_known_args()

    translations_dict_per_model = {}

    for file in os.listdir(args.multi_gpu_translation_dir):

        file_name_prefix, extension = os.path.splitext(file)

        if args.is_tok and not file_name_prefix.endswith("_tok"):
            continue
        elif not args.is_tok and file_name_prefix.endswith("_tok"):
            continue

        assert extension[1:5] == "rank"
        rank = int(extension[5:])

        data = read_data(os.path.join(args.multi_gpu_translation_dir, file))

        if file_name_prefix in translations_dict_per_model:
            translations_dict_per_model[file_name_prefix].append((data, rank))
        else:
            translations_dict_per_model[file_name_prefix] = [(data, rank)]

    for file_name_prefix in translations_dict_per_model:
        translations_dict_per_model[file_name_prefix].sort(
            key=lambda item: item[1])

    if not os.path.isdir(args.merged_translation_dir):
        os.makedirs(args.merged_translation_dir)

    for file_name_prefix in translations_dict_per_model:

        merged_translations = []
        for translations, rank in translations_dict_per_model[
                file_name_prefix]:
            merged_translations.extend(translations)

        write_data(
            merged_translations,
            os.path.join(args.merged_translation_dir,
                         "{}.txt".format(file_name_prefix)))
Example No. 21
def preprocess(input_path, domain_path, output_path):
    print(f'preprocessing {input_path} -> {output_path}')

    df, domain, columns = tools.read_data(input_path,
                                          domain_path,
                                          return_df=True)
    attr_num = df.shape[1]
    # print(df.shape)
    # print(columns)
    # print(df.iloc[:10, ])

    for col in df.columns:
        if col in BINS:
            df.loc[:, col] = pd.cut(df[col],
                                    BINS[col],
                                    right=False,
                                    labels=False)
            if len(BINS[col]) < 255:
                df.loc[:, col] = df[col].astype(np.uint8)

    data = df.to_numpy(dtype=int)

    # print(np.unique(data[:, 2], return_counts=True))
    # print(np.unique(data[:, 5], return_counts=True))

    for attr in range(attr_num):
        if columns[attr] in [
                'pickup_community_area', 'dropoff_community_area'
        ]:
            mask = data[:, attr] == -1
            print(attr, columns[attr], np.min(data[:, attr]), np.sum(mask))
            data[:, attr][mask] = 0
            # mask = data[:, attr] == 0
            # print(attr, columns[attr], np.min(data[:, attr]), np.sum(mask))
        elif columns[attr] == 'payment_type':
            # caution: 4 is not a valid value of payment_type in parameters.json
            print(attr, columns[attr], np.sum(data[:, attr] == 4))
            mask = data[:, attr] == -1
            data[:, attr][mask] = 4

    # print(np.min(data))
    # print(columns)

    df = pd.DataFrame(data, columns=columns, dtype=int)
    df.to_csv(output_path, index=False)
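
The binning above uses pd.cut with right=False and labels=False, which maps raw values to integer indices of left-closed bins. A small illustration with made-up bin edges:

import pandas as pd

fares = pd.Series([0, 3, 7, 12])
edges = [0, 5, 10, 15]  # hypothetical bin edges, not the values in BINS
print(pd.cut(fares, edges, right=False, labels=False).tolist())  # [0, 0, 1, 2]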
Example No. 22
def truncate_and_plot(data_path, domain_path, prefix):
    data, domain, headings = tools.read_data(data_path, domain_path)

    threshold_map = {'fare': 100, 'tips': 20, 'trip_miles': 40, 'trip_total': 100, 'trip_seconds': 10000}
    for attr in threshold_map:
        data, domain = atools.truncate_data(data, domain, headings.index(attr), threshold_map[attr])

    parameters_json = json.load(open('./data/parameters.json', 'r'))
    attrs = list(parameters_json['schema'].keys())

    for attr in attrs:
        attr_id = headings.index(attr)
        ptools.plot_attr(attr_id, data, domain, \
            path=f'./info/{prefix}_{attr}_{attr_id}.pdf')

    print(tools.get_marginal(data, domain, (headings.index('payment_type'), )))
Example No. 23
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--log_file_path", "-lf", required=True)
    parser.add_argument("--json_file_path", "-jf", required=True)

    args, unknown = parser.parse_known_args()

    log_data = read_data(args.log_file_path)

    data = []

    element = {}

    for line in log_data:

        line = line.lower()

        if line.startswith("load model"):
            element["model_name"] = line[16:]

        elif line.startswith("bleu"):

            l = line.find("=")
            assert l != -1
            l += 2

            r = line.find(",")
            assert r != -1

            bleu_score = float(line[l:r])
            element["bleu"] = bleu_score
            element["bleu_details"] = line

            data.append(element)

            element = {}

    data.sort(key=lambda item: -item["bleu"])

    with open(args.json_file_path, "w") as f:
        json.dump(data, f)
Example No. 24
def repair_data(input_path, domain_path, output_path):
    print(f'repairing final data {input_path} -> {output_path}')

    df, domain, columns = tools.read_data(input_path,
                                          domain_path,
                                          return_df=True)
    attr_num = df.shape[1]

    data = df.to_numpy(dtype=int)

    for attr in range(attr_num):
        if columns[attr] in [
                'pickup_community_area', 'dropoff_community_area'
        ]:
            mask = data[:, attr] == 0
            print(np.sum(mask))
            data[:, attr][mask] = -1

    df = pd.DataFrame(data, columns=columns, dtype=int)
    df.to_csv(output_path, index=False)
Example No. 25
def tokenize(args):

    assert os.path.isdir(args.raw_corpus_dir)

    if not os.path.isdir(args.tokenized_corpus_dir):
        os.makedirs(args.tokenized_corpus_dir)

    tokenizer = sacrebleu.TOKENIZERS[sacrebleu.DEFAULT_TOKENIZER]

    for file in os.listdir(args.raw_corpus_dir):

        file_path = os.path.join(args.raw_corpus_dir, file)

        idx = file.rfind(".")
        assert idx != -1

        new_file_name = "{}.{}.{}".format(file[:idx], "tok", file[idx+1:])

        data = read_data(file_path)

        data_tok = [tokenizer(sentence) for sentence in data]

        new_file_path = os.path.join(args.tokenized_corpus_dir, new_file_name)
        write_data(data_tok, new_file_path)
Example No. 26
def merge(corpus_path: str):

    train_data_src = []
    train_data_tgt = []

    dev_data_src = []
    dev_data_tgt = []

    test_data_src = []
    test_data_tgt = []

    unused_corpus = {"pt-br_en", "fr-ca_en", "eo_en", "calv_en"}

    for corpus_dir in os.listdir(corpus_path):

        if corpus_dir in unused_corpus:
            continue

        corpus_dir = os.path.join(corpus_path, corpus_dir)

        for corpus_file_name in os.listdir(corpus_dir):

            corpus_file_path = os.path.join(corpus_dir, corpus_file_name)
            print(corpus_file_path)

            data = read_data(corpus_file_path)

            is_en = True

            if not corpus_file_name.endswith(".en"):

                idx = corpus_file_name.rfind(".")
                assert idx != -1

                idx += 1
                lang_identify_token = "".join(["<", corpus_file_name[idx:], ">"])
                data = [" ".join([lang_identify_token, sentence]) for sentence in data]
                is_en = False

            if corpus_file_name.startswith("train"):

                if is_en:
                    train_data_tgt.extend(data)
                else:
                    train_data_src.extend(data)

            elif corpus_file_name.startswith("dev"):

                if is_en:
                    dev_data_tgt.extend(data)
                else:
                    dev_data_src.extend(data)

            elif corpus_file_name.startswith("test"):

                if is_en:
                    test_data_tgt.extend(data)
                else:
                    test_data_src.extend(data)

    assert check(train_data_src, train_data_tgt) and check(dev_data_src, dev_data_tgt) and \
           check(test_data_src, test_data_tgt)

    output_dir = "/data/rrjin/NMT/data/ted_data/corpus"

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    write_data(train_data_src, os.path.join(output_dir, "train_data_src.combine"))
    write_data(train_data_tgt, os.path.join(output_dir, "train_data_tgt.en"))

    write_data(dev_data_src, os.path.join(output_dir, "dev_data_src.combine"))
    write_data(dev_data_tgt, os.path.join(output_dir, "dev_data_tgt.en"))

    write_data(test_data_src, os.path.join(output_dir, "test_data_src.combine"))
    write_data(test_data_tgt, os.path.join(output_dir, "test_data_tgt.en"))
Example No. 27
def evaluation(local_rank, args):
    rank = args.nr * args.gpus + local_rank
    dist.init_process_group(backend="nccl",
                            init_method=args.init_method,
                            rank=rank,
                            world_size=args.world_size)

    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # List[str]
    src_data = read_data(args.test_src_path)

    tgt_prefix_data = None
    if args.tgt_prefix_file_path is not None:
        tgt_prefix_data = read_data(args.tgt_prefix_file_path)

    max_src_len = max(len(line.split()) for line in src_data) + 2
    max_tgt_len = max_src_len * 3
    logging.info("max src sentence length: {}".format(max_src_len))

    src_vocab = Vocab.load(args.src_vocab_path)
    tgt_vocab = Vocab.load(args.tgt_vocab_path)

    padding_value = src_vocab.get_index(src_vocab.mask_token)

    assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

    src_data = convert_data_to_index(src_data, src_vocab)

    dataset = DataPartition(src_data, args.world_size, tgt_prefix_data,
                            args.work_load_per_process).dataset(rank)

    logging.info("dataset size: {}, rank: {}".format(len(dataset), rank))

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=(args.batch_size if args.batch_size else 1),
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        collate_fn=lambda batch: collate_eval(
            batch,
            padding_value,
            batch_first=(True if args.transformer else False)))

    if not os.path.isdir(args.translation_output_dir):
        os.makedirs(args.translation_output_dir)

    if args.beam_size:
        logging.info("Beam size: {}".format(args.beam_size))

    if args.is_prefix:
        args.model_load = args.model_load + "*"

    for model_path in glob.glob(args.model_load):
        logging.info("Load model from: {}, rank: {}".format(model_path, rank))

        if args.transformer:

            s2s = load_transformer(model_path,
                                   len(src_vocab),
                                   max_src_len,
                                   len(tgt_vocab),
                                   max_tgt_len,
                                   padding_value,
                                   training=False,
                                   share_dec_pro_emb=args.share_dec_pro_emb,
                                   device=device)

        else:
            s2s = load_model(model_path, device=device)

        s2s.eval()

        if args.record_time:
            import time
            start_time = time.time()

        pred_data = []

        for data, tgt_prefix_batch in data_loader:
            if args.beam_size:
                pred_data.append(
                    beam_search_decoding(s2s, data.to(device,
                                                      non_blocking=True),
                                         tgt_vocab, args.beam_size, device))
            else:
                pred_data.extend(
                    greedy_decoding(s2s, data.to(device, non_blocking=True),
                                    tgt_vocab, device, tgt_prefix_batch))

        if args.record_time:
            end_time = time.time()
            logging.info("Time spend: {} seconds, rank: {}".format(
                end_time - start_time, rank))

        _, model_name = os.path.split(model_path)

        if args.beam_size:
            translation_file_name_prefix = "{}_beam_size_{}".format(
                model_name, args.beam_size)
        else:
            translation_file_name_prefix = "{}_greedy".format(model_name)

        p = os.path.join(
            args.translation_output_dir,
            "{}_translations.rank{}".format(translation_file_name_prefix,
                                            rank))

        write_data(pred_data, p)

        if args.need_tok:

            # replace '@@ ' with ''
            p_tok = os.path.join(
                args.translation_output_dir,
                "{}_translations_tok.rank{}".format(
                    translation_file_name_prefix, rank))

            tok_command = "sed -r 's/(@@ )|(@@ ?$)//g' {} > {}".format(
                p, p_tok)

            call(tok_command, shell=True)
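
collate_eval is defined elsewhere. A minimal sketch of what a collate function like this is assumed to do (pad variable-length index sequences with padding_value and pass any target prefixes through):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_eval(batch, padding_value, batch_first=False):
    # Hypothetical: each batch item is assumed to be (source_indices, tgt_prefix).
    src = [torch.as_tensor(item[0], dtype=torch.long) for item in batch]
    tgt_prefix = [item[1] for item in batch]
    src = pad_sequence(src, batch_first=batch_first, padding_value=padding_value)
    return src, tgt_prefix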
Example No. 28
parser.add_argument("--need_tok", action="store_true")

parser.add_argument("--batch_size", type=int)

parser.add_argument("--bleu_script_path", required=True)

args, unknown = parser.parse_known_args()

device = args.device

torch.cuda.set_device(device)

src_vocab = Vocab.load(args.src_vocab_path)
tgt_vocab = Vocab.load(args.tgt_vocab_path)

src_data = read_data(args.test_src_path)

# |------ for transformer ------|
max_src_len = max(len(line.split()) for line in src_data) + 2
max_tgt_len = max_src_len * 3
print("max src sentence length: {}".format(max_src_len))
# |------------ end ------------|

src_data = convert_data_to_index(src_data, src_vocab)
src_data = SrcData(src_data)

padding_value = src_vocab.get_index(src_vocab.mask_token)
assert padding_value == tgt_vocab.get_index(tgt_vocab.mask_token)

if args.batch_size:
    assert args.beam_size is None, "batch translation does not support beam search yet"
Example No. 29
def multilingual_bleu_calculation(args):

    assert args.lang_identifier_file_path is not None and args.language_data is not None

    bleu_score_type = args.bleu_score_type

    corpus_bleu = nltk_corpus_bleu if bleu_score_type == "nltk_bleu" else sacre_corpus_bleu

    lang_identifier_list = read_data(args.lang_identifier_file_path)
    lang_identifier = {}

    with open(args.language_data) as f:
        language_dict = json.load(f)

    for i, sentence in enumerate(lang_identifier_list):

        lang_code = sentence

        assert lang_code.startswith("<") and lang_code.endswith(">")

        lang_code = lang_code[1:-1]

        if lang_code not in lang_identifier:
            lang_identifier[lang_code] = [i]
        else:
            lang_identifier[lang_code].append(i)

    show_corpus_statistics(lang_identifier, language_dict)

    reference_data = read_data(args.reference_path)

    if bleu_score_type == "nltk_bleu":
        reference_data = [[sentence.split()] for sentence in reference_data]

    reference_data_per_language = {}

    for k, v in lang_identifier.items():

        reference_data_per_language[k] = [
            reference_data[line_number] for line_number in v
        ]

        if bleu_score_type == "sacrebleu":
            reference_data_per_language[k] = [reference_data_per_language[k]]

    args.translation_path_list.sort(key=cmp)

    bleu_score_dict = {lang: [] for lang in lang_identifier}

    for translation_path in args.translation_path_list:

        print("Translation in: {}\n".format(translation_path))

        translation_data = read_data(translation_path)

        if bleu_score_type == "nltk_bleu":
            translation_data = [
                sentence.split() for sentence in translation_data
            ]

        translation_data_per_language = {}
        for k, v in lang_identifier.items():
            translation_data_per_language[k] = [
                translation_data[line_number] for line_number in v
            ]

        for lang in lang_identifier:

            if bleu_score_type == "nltk_bleu":
                bleu_score = corpus_bleu(
                    reference_data_per_language[lang],
                    translation_data_per_language[lang]) * 100
            else:
                bleu_score = corpus_bleu(translation_data_per_language[lang],
                                         reference_data_per_language[lang])

            print("IOS639-2:{},IOS639-3:{},language name:{},bleu:{}".format(
                lang, language_dict[lang]["ISO639-3"],
                language_dict[lang]["language name"], bleu_score))
            bleu_score_dict[lang].append(bleu_score if bleu_score_type ==
                                         "nltk_bleu" else bleu_score.score)
        print()

    print("Writing data to: {}".format(args.bleu_score_data_path))

    with open(args.bleu_score_data_path, "w") as f:
        json.dump(bleu_score_dict, f)
Example No. 30
train_src = []
train_tgt = []

dev_src = []
dev_tgt = []

test_src = []
test_tgt = []

for directory in corpus_name:

    directory = os.path.join(prefix, directory)

    for file in os.listdir(directory):

        file_path = os.path.join(directory, file)
        data = read_data(file_path)

        if file.startswith("train"):

            if file.endswith(".en"):
                train_tgt.extend(data)
            else:
                train_src.extend(data)

        elif file.startswith("dev"):

            if file.endswith(".en"):
                dev_tgt.extend(data)
            else:
                dev_src.extend(data)