def make_binary_dataset(input_prefix, output_prefix, lng_pair, lang,
                        num_workers):
    """Binarize the ``lang`` text file of ``lng_pair`` into an indexed dataset.

    Loads a ``dictionary.Dictionary``, tokenizes the input file (split into
    ``num_workers`` shards via ``Tokenizer.find_offsets``), and writes
    ``{output_prefix}.{lng_pair}.{lang}.bin`` / ``.idx`` through
    ``indexed_dataset.IndexedDatasetBuilder``.

    Names not defined in this function (``args``, ``dict_path``,
    ``tgt_dict_path``, ``binarize``, ``dictionary``, ``Tokenizer``,
    ``indexed_dataset``) must be available in an enclosing scope.

    Args:
        input_prefix: Input path prefix. The file
            ``{input_prefix}.{lng_pair}.{lang}.tok.bpe`` is tried first,
            then ``{input_prefix}.{lng_pair}.{lang}``.
        output_prefix: Path prefix for the binarized output files.
        lng_pair: Pair tag inserted into input and output file names.
        lang: Language code of the file being binarized.
        num_workers: Total shard count. The main process handles shard 0,
            and a pool of ``num_workers - 1`` processes handles the rest.

    Returns:
        None. Prints a message and returns early if neither input file
        exists.

    Raises:
        Any exception raised by a pool worker is re-raised once the pool
        has been joined, instead of being silently dropped.
    """
    # The target dictionary is used for non-'en' sides unless
    # args.joined_dictionary is set.
    if not args.joined_dictionary and lang != 'en':
        vocab = dictionary.Dictionary.load(tgt_dict_path)
    else:
        vocab = dictionary.Dictionary.load(dict_path)

    print('| [{}] Dictionary: {} types'.format(lang, len(vocab) - 1))
    n_seq_tok = [0, 0]  # [number of sentences, number of tokens]
    replaced = Counter()

    def merge_result(worker_result):
        # Fold one shard's statistics into the running totals. This is
        # called directly for shard 0 and as the pool's success callback.
        replaced.update(worker_result['replaced'])
        n_seq_tok[0] += worker_result['nseq']
        n_seq_tok[1] += worker_result['ntok']

    input_file = f'{input_prefix}.{lng_pair}.{lang}.tok.bpe'
    if not os.path.exists(input_file):
        input_file = f'{input_prefix}.{lng_pair}.{lang}'
        if not os.path.exists(input_file):
            print("| {} not found".format(input_file))
            return
    # NOTE(review): the existence check above runs before the '.e' suffix
    # is appended, so the expert file itself is never checked; confirm.
    if args.expert:
        input_file = input_file + '.e'
    offsets = Tokenizer.find_offsets(input_file, num_workers)

    pool = None
    async_results = []
    if num_workers > 1:
        pool = Pool(processes=num_workers - 1)
        for worker_id in range(1, num_workers):
            fn_without_ext = f"{output_prefix}{worker_id}.{lng_pair}.{lang}"
            async_results.append(pool.apply_async(
                binarize,
                (input_file, vocab, fn_without_ext,
                 offsets[worker_id], offsets[worker_id + 1]),
                callback=merge_result))
        pool.close()

    ds = indexed_dataset.IndexedDatasetBuilder(
        f"{output_prefix}.{lng_pair}.{lang}.bin")
    # The main process binarizes the first shard itself.
    merge_result(
        Tokenizer.binarize(input_file,
                           vocab,
                           ds.add_item,
                           offset=0,
                           end=offsets[1]))
    if num_workers > 1:
        pool.join()
        # apply_async only surfaces a worker's exception through .get().
        # Re-raise it here instead of failing later on a missing shard file.
        for result in async_results:
            result.get()
        for worker_id in range(1, num_workers):
            temp_file_path = f"{output_prefix}{worker_id}.{lng_pair}.{lang}"
            ds.merge_file_(temp_file_path)
            os.remove(indexed_dataset.data_file_path(temp_file_path))
            os.remove(indexed_dataset.index_file_path(temp_file_path))

    ds.finalize(f"{output_prefix}.{lng_pair}.{lang}.idx")

    # Avoid ZeroDivisionError when the input produced no tokens.
    pct_replaced = (100 * sum(replaced.values()) / n_seq_tok[1]
                    if n_seq_tok[1] else 0.0)
    print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
        lang, input_file, n_seq_tok[0], n_seq_tok[1],
        pct_replaced, vocab.unk_word))
예제 #2
0
    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        """Binarize ``{input_prefix}.{lang}`` into an indexed dataset.

        Loads the dictionary returned by ``dict_path(lang)``, tokenizes the
        input file (split into ``num_workers`` shards via
        ``Tokenizer.find_offsets``), and writes the ``.bin`` / ``.idx``
        files named by ``dataset_dest_file(args, output_prefix, lang, ...)``.

        Names not defined in this function (``args``, ``dict_path``,
        ``binarize``, ``dataset_dest_file``, ``dataset_dest_prefix``,
        ``dictionary``, ``Tokenizer``, ``indexed_dataset``) must be
        available in an enclosing scope.

        Args:
            input_prefix: Input path prefix. ``'.' + lang`` is appended
                unless ``lang`` is None.
            output_prefix: Path prefix for the binarized output files.
            lang: Language suffix, or None for an unsuffixed input file.
            num_workers: Total shard count. The main process handles
                shard 0, and a pool of ``num_workers - 1`` processes handles
                the rest.

        Raises:
            Any exception raised by a pool worker is re-raised once the
            pool has been joined, instead of being silently dropped.
        """
        vocab = dictionary.Dictionary.load(dict_path(lang))
        print('| [{}] Dictionary: {} types'.format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]  # [number of sentences, number of tokens]
        replaced = Counter()

        def merge_result(worker_result):
            # Fold one shard's statistics into the running totals. This is
            # called directly for shard 0 and as the pool's success callback.
            replaced.update(worker_result['replaced'])
            n_seq_tok[0] += worker_result['nseq']
            n_seq_tok[1] += worker_result['ntok']

        input_file = '{}{}'.format(input_prefix,
                                   ('.' + lang) if lang is not None else '')
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        # NOTE(review): looks like leftover debug output; kept to preserve
        # stdout.
        print("offsets", offsets)

        pool = None
        async_results = []
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                async_results.append(pool.apply_async(
                    binarize,
                    (args, input_file, vocab, prefix, lang,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result))
            pool.close()

        ds = indexed_dataset.IndexedDatasetBuilder(
            dataset_dest_file(args, output_prefix, lang, 'bin'))
        # The main process binarizes the first shard itself.
        merge_result(
            Tokenizer.binarize(input_file,
                               vocab,
                               ds.add_item,
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            # apply_async only surfaces a worker's exception through .get().
            # Re-raise it here instead of failing later on a missing shard.
            for result in async_results:
                result.get()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))

        # Avoid ZeroDivisionError when the input produced no tokens.
        pct_replaced = (100 * sum(replaced.values()) / n_seq_tok[1]
                        if n_seq_tok[1] else 0.0)
        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, n_seq_tok[0], n_seq_tok[1],
            pct_replaced, vocab.unk_word))
예제 #3
0
def get_offsets(input_file, num_workers):
    """Return ``Tokenizer.find_offsets(input_file, num_workers)`` unchanged.

    Args:
        input_file: Path handed straight to ``Tokenizer.find_offsets``.
        num_workers: Shard count handed straight to
            ``Tokenizer.find_offsets``.
    """
    shard_offsets = Tokenizer.find_offsets(input_file, num_workers)
    return shard_offsets