Example #1
def main(args):
    with codecs.open(args.data, 'r', 'utf8') as f:
        data = json.load(f)
    print('-- Loaded data from %s' % args.data)

    processor = FewRelProcessor(
        args.roberta,
        args.dataset_impl,
        args.append_eos,
    )
    pbar = tqdm.tqdm(
        total=sum([len(v) for _, v in data.items()]),
        desc='Processing Wiki',
        bar_format=TRAINING_TQDM_BAD_FORMAT,
    )

    vocab = Dictionary.load(os.path.join(args.roberta, 'roberta.base', 'dict.txt'))
    dataset_builder = indexed_dataset.make_builder(
        args.output + '.text.bin',
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )
    relations_builder = indexed_dataset.make_builder(
        args.output + '.relations.bin',
        impl=args.dataset_impl,
        vocab_size=None,
    )
    processor.initializer()
    annotations_list = list()
    total_length, num_sentences = 0, 0
    for relation_type_id, (_, samples) in enumerate(data.items()):
        for ids_tensor, _annotations_list in map(processor, samples):
            dataset_builder.add_item(ids_tensor)
            relations_builder.add_item(torch.IntTensor([relation_type_id]))
            _annotations_list[:, 0] += total_length
            _annotations_list[:, 1] += total_length
            _annotations_list[:, 2] += num_sentences
            _annotations_list[:, 3] += num_sentences
            num_sentences += 1
            total_length += len(ids_tensor)
            annotations_list.append(_annotations_list)

            pbar.update()
    pbar.close()

    dataset_builder.finalize(args.output + '.text.idx')
    relations_builder.finalize(args.output + '.relations.idx')
    annotations_list = np.concatenate(annotations_list)
    np.save(args.output + '.annotations', annotations_list)
Example #2
    def test_huffman_compresses(self):
        data = make_data()
        builder = make_code_builder(data)
        coder = builder.build_code()

        with TemporaryDirectory() as dirname:
            prefix = os.path.join(dirname, "huffman")
            build_dataset(prefix, data, coder)

            prefix_mmap = os.path.join(dirname, "mmap")
            mmap_builder = indexed_dataset.make_builder(
                indexed_dataset.data_file_path(prefix_mmap),
                "mmap",
                vocab_size=len(POPULATION),
            )
            dictionary = Dictionary()
            for c in POPULATION:
                dictionary.add_symbol(c)
            dictionary.finalize()
            for sentence in data:
                mmap_builder.add_item(dictionary.encode_line(" ".join(sentence)))
            mmap_builder.finalize(indexed_dataset.index_file_path(prefix_mmap))

            huff_size = os.stat(indexed_dataset.data_file_path(prefix)).st_size
            mmap_size = os.stat(indexed_dataset.data_file_path(prefix_mmap)).st_size
            self.assertLess(huff_size, mmap_size)
Example #3
    def make_binary_tag_dataset(input_prefix, output_prefix, lang,
                                num_workers):
        logger.info("Adding tag indexes")
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            n_seq_tok[0] += worker_result["nseq"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        assert num_workers == 1

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
        )
        merge_result(
            Binarizer.binarize_tag(input_file,
                                   lambda t: ds.add_item(t),
                                   offset=0,
                                   end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
Example #4
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        creates a dataset builder and append binarized items to it. This function does not
        finalize the builder, this is useful if you want to do other things with your bin file
        like appending/merging other files
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary
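The returned builder is deliberately left open, so the caller still has to merge and finalize it. Below is a minimal, hedged sketch of that follow-up step; `binarizer`, `vocab`, "train.txt" and the "train.chunk0" prefix are illustrative assumptions, and the helper is shown as if it were a free function even though it is defined inside a class above.

import os

# Hedged sketch only: binarize one chunk, then finalize it ourselves.
ds, summary = _binarize_file_chunk(
    binarizer,                                   # assumed Binarizer instance
    filename="train.txt",                        # hypothetical input file
    offset_start=0,
    offset_end=os.path.getsize("train.txt"),     # read the whole file
    output_prefix="train.chunk0",                # hypothetical output prefix
    dataset_impl="mmap",
    vocab_size=len(vocab),
)
# further shards could be appended here, e.g. ds.merge_file_(other_prefix)
ds.finalize(indexed_dataset.index_file_path("train.chunk0"))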
Example #5
    def make_binary_dataset(vocab, input_prefix, output_prefix, src_lang,
                            tgt_lang, lang, num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}.{}-{}.{}".format(input_prefix, src_lang, tgt_lang,
                                          lang)
        if args.model:
            input_file += ".tok"
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = multiprocessing.Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (args, input_file, vocab, prefix, src_lang, tgt_lang, lang,
                     offsets[worker_id], offsets[worker_id + 1]),
                    callback=merge_result)
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, src_lang, tgt_lang, lang, "bin"),
                                          impl=args.dataset_impl,
                                          vocab_size=len(vocab))
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, src_lang,
                                                     tgt_lang, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(
            dataset_dest_file(args, output_prefix, src_lang, tgt_lang, lang,
                              "idx"))

        print("| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
            100 * sum(replaced.values()) / n_seq_tok[1],
            vocab.unk_word,
        ))
Example #6
def binarize(args,
             filename,
             vocab,
             output_prefix,
             lang,
             offset,
             end,
             append_eos=True):
    ds = indexed_dataset.make_builder(dataset_dest_file(
        args, output_prefix, lang, "bin"),
                                      impl=args.dataset_impl,
                                      vocab_size=len(vocab))

    def consumer(tensor):
        ds.add_item(tensor)

    tk = tokenize_smiles if args.file_format == 'smiles' else tokenize_line
    res = Binarizer.binarize(filename,
                             vocab,
                             consumer,
                             tokenize=tk,
                             append_eos=append_eos,
                             offset=offset,
                             end=end)
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res
Example #7
def save_items(items, prefix, vocab_size):
  bin_fn = "%s.bin" % prefix
  idx_fn = "%s.idx" % prefix
  builder = make_builder(bin_fn, "mmap", vocab_size=vocab_size)
  print("builder: " + str(builder))
  for item in items: builder.add_item(item)
  builder.finalize(idx_fn)
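A small, hedged usage sketch for `save_items`; the token-id tensors, output prefix, and vocab size below are made-up placeholders.

import torch

# two hypothetical token-id sentences; vocab_size lets the mmap builder
# pick a compact integer dtype for them
items = [torch.IntTensor([4, 5, 6]), torch.IntTensor([7, 8])]
save_items(items, "/tmp/example", vocab_size=100)
# expected result: /tmp/example.bin and /tmp/example.idx on disk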
Example #8
def binarize(args,
             filename,
             vocab,
             output_prefix,
             lang,
             offset,
             end,
             append_eos=True):
    ds = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, lang, "bin"),
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )

    def consumer(tensor):
        ds.add_item(tensor)

    res = Binarizer.binarize(filename,
                             vocab,
                             consumer,
                             append_eos=append_eos,
                             offset=offset,
                             end=end)
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res
Example #9
    def make_binary_alignment_dataset(input_prefix, output_prefix,
                                      num_workers):
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = find_offsets(input_file, num_workers)
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id, (start_offset, end_offset) in enumerate(more_chunks,
                                                                   start=1):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        start_offset,
                        end_offset,
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=first_chunk[0],
                end=first_chunk[1],
            ))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(
            input_file, nseq[0]))
Example #10
def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
                                      impl=args.dataset_impl, vocab_size=None)

    def consumer(tensor):
        ds.add_item(tensor)

    res = Binarizer.binarize_alignments(filename, parse_alignment, consumer, offset=offset,
                                        end=end)
    ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
    return res
Example #11
    def make_binary_da_dataset(input_prefix, output_prefix, lang, num_workers,
                               da_mapping):
        logger.info("Adding domain indexes")
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            n_seq_tok[0] += worker_result["nseq"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        # TODO: an error occurs if num_workers > 1:
        #  No such file or directory: 'data-bin/iwslt14.tokenized.de-en/train.da1.en-de.en.idx'
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_da,
                    (
                        args,
                        input_file,
                        prefix,
                        lang,
                        da_mapping,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
        )
        merge_result(
            Binarizer.binarize_da(input_file,
                                  lambda t: ds.add_item(t),
                                  da_mapping,
                                  offset=0,
                                  end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
Example #12
    def make_binary_alignment_dataset(input_prefix, output_prefix,
                                      num_workers):
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=0,
                end=offsets[1],
            ))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        print("| [alignments] {}: parsed {} alignments".format(
            input_file, nseq[0]))
Example #13
    def make_binary_sent_doc_dataset(input_prefix, output_prefix, lang,
                                     num_workers):
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(binarize_sent_doc,
                                 (args, input_file, prefix, lang,
                                  offsets[worker_id], offsets[worker_id + 1]),
                                 callback=merge_result)
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(
            args, output_prefix, lang, "bin"),
                                          impl=args.dataset_impl)
        merge_result(
            Binarizer.binarize_sent_doc(input_file,
                                        lambda t: ds.add_item(t),
                                        offset=0,
                                        end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print("| [{}] {}: {} sents, {} tokens".format(
            lang,
            input_file,
            n_seq_tok[0],
            n_seq_tok[1],
        ))
Example #14
def binarize(args,
             filename,
             vocab,
             output_prefix,
             lang,
             offset,
             end,
             append_eos=True):
    ##dataset_impl=mmap, ds -> MMapIndexedDatasetBuilder
    ##dataset_impl=lazy, ds -> IndexedDatasetBuilder
    ds = indexed_dataset.make_builder(dataset_dest_file(
        args, output_prefix, lang, "bin"),
                                      impl=args.dataset_impl,
                                      vocab_size=len(vocab))

    def consumer(tensor):
        ## The input tensor is just the text string converted into an id sequence via the dictionary
        ## dataset_impl=mmap, MMapIndexedDatasetBuilder.add_item: writes the input tensor directly to the file
        ## dataset_impl=lazy, IndexedDatasetBuilder.add_item: writes the input tensor to the file and updates sizes, data_offsets, dim_offsets
        ds.add_item(tensor)

    ## Read the contents of filename between offset and end, convert each text string into an id sequence with the dictionary, and write it into ds via the consumer function
    res = Binarizer.binarize(filename,
                             vocab,
                             consumer,
                             append_eos=append_eos,
                             offset=offset,
                             end=end)

    ## Store the data written into ds in a temporary file at the corresponding path; output_prefix contains the worker_id to distinguish the temporary files of different workers
    ## mmap, ds.finalize: calls MMapIndexedDataset.write to store three tensors:
    ##      the number of training examples, the size of each example tensor, and the pointer (position) of each example
    ## lazy, ds.finalize: IndexedDatasetBuilder.finalize directly writes dim_offsets, data_offsets, sizes
    #       data_offsets: the end position of each tensor in the binary file (the previous tensor's end is this tensor's start)
    #       sizes: the values of each dim of every tensor's shape
    #       dim_offsets: the end position of each tensor's shape within self.sizes (the end of the previous tensor's shape is the start of this one's)
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    return res
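As a rough counterpart to the notes above, here is a hedged sketch of reading a finalized .bin/.idx pair back; the "data-bin/train.de" prefix is a placeholder, and it assumes fairseq's `indexed_dataset.make_dataset` reader.

# prefix = path without the .bin/.idx extension
dataset = indexed_dataset.make_dataset("data-bin/train.de", impl="mmap")
print(len(dataset))   # number of stored examples
print(dataset[0])     # token-id tensor of the first example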
Example #15
    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        output_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        input_prefix += '.bert' if isinstance(vocab, BertTokenizer) else ''
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        merge_result(
            Binarizer.binarize(input_file,
                               vocab,
                               lambda t: ds.add_item(t),
                               offset=0,
                               end=offsets[1]))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            ))
Example #16
def main(args):
    # if args.split in ['train', 'dev']:
    #     data_path = os.path.join(args.root_dir, 'SemEval2010_task8_training', 'TRAIN_FILE.TXT')
    #     rand_indices = np.random.permutation(8000)
    #     train_indices, dev_indices = rand_indices[:6500], rand_indices[6500:]
    # elif args.split == 'test':
    #     data_path = os.path.join(args.root_dir, 'SemEval2010_task8_testing_keys', 'TEST_FILE_FULL.TXT')

    if args.split == 'train':
        data_path = os.path.join(args.root_dir, 'SemEval2010_task8_training',
                                 'TRAIN_FILE.TXT')
    elif args.split == 'dev':
        data_path = os.path.join(args.root_dir,
                                 'SemEval2010_task8_testing_keys',
                                 'TEST_FILE_FULL.TXT')
    else:
        raise NotImplementedError

    with open(data_path, 'r') as f:
        data = f.read().splitlines()

    data = [x for x in data if x != '']
    texts = [x.split('\t')[1][1:-1] for x in data[0::3]]
    relation_types = [x.replace(' ', '') for x in data[1::3]]
    unique_relation_types = sorted(list(set(relation_types)))
    unique_relation_types.remove('Other')
    unique_relation_types.append('Other')

    # if args.split == 'train':
    #     texts = list(itemgetter(*train_indices)(texts))
    #     relation_types = list(itemgetter(*train_indices)(relation_types))
    # elif args.split == 'dev':
    #     texts = list(itemgetter(*dev_indices)(texts))
    #     relation_types = list(itemgetter(*dev_indices)(relation_types))

    print('-- Loaded data from %s' % data_path)

    processor = SemEval2010Task8Processor(
        args.roberta_dir,
        args.dataset_impl,
        args.append_eos,
    )
    pbar = tqdm.tqdm(
        total=len(texts),
        desc='Processing Wiki',
        bar_format=TRAINING_TQDM_BAD_FORMAT,
    )

    output_dir = os.path.join(args.root_dir, 'bin')
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    split = args.split if args.split != 'dev' else 'valid'
    vocab = Dictionary.load(
        os.path.join(args.roberta_dir, 'roberta.base', 'dict.txt'))
    dataset_builder = indexed_dataset.make_builder(
        os.path.join(output_dir, split + '.text.bin'),
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )
    relations_builder = indexed_dataset.make_builder(
        os.path.join(output_dir, split + '.relations.bin'),
        impl=args.dataset_impl,
        vocab_size=None,
    )
    processor.initializer()
    annotations_list = list()
    total_length, num_sentences = 0, 0
    for i in range(len(texts)):
        text, relation_type_id = [texts[i]], unique_relation_types.index(
            relation_types[i])
        for ids_tensor, _annotations_list in map(processor, text):
            dataset_builder.add_item(ids_tensor)
            relations_builder.add_item(torch.IntTensor([relation_type_id]))
            _annotations_list[:, 0] += total_length
            _annotations_list[:, 1] += total_length
            _annotations_list[:, 2] += num_sentences
            _annotations_list[:, 3] += num_sentences
            num_sentences += 1
            total_length += len(ids_tensor)
            annotations_list.append(_annotations_list)

            pbar.update()
    pbar.close()

    dataset_builder.finalize(os.path.join(output_dir, split + '.text.idx'))
    relations_builder.finalize(
        os.path.join(output_dir, split + '.relations.idx'))
    annotations_list = np.concatenate(annotations_list)
    np.save(os.path.join(output_dir, split + '.annotations'), annotations_list)
Example #17
def main(args):
    data_path = os.path.join(args.root_dir, args.split+'.txt')
    with open(data_path, 'r') as f:
        data = f.read().splitlines()
    data = [x for x in data if x != '']
    texts = [x.split('\t')[1][1:-1] for x in data[0::2]]
    relation_types = [x.replace(' ', '') for x in data[1::2]]
    unique_relation_types = sorted(list(set(relation_types)))
    unique_relation_types.remove('no_relation')
    unique_relation_types.append('no_relation')
    print('-- Loaded data from %s' % data_path)

    processor = KBP37Processor(
        args.roberta_dir,
        args.dataset_impl,
        args.append_eos,
        args.max_positions
    )
    pbar = tqdm.tqdm(
        total=len(texts),
        desc='Processing Wiki',
        bar_format=TRAINING_TQDM_BAD_FORMAT,
    )

    output_dir = os.path.join(args.root_dir, 'bin')
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    split = args.split if args.split != 'dev' else 'valid'
    vocab = Dictionary.load(os.path.join(args.roberta_dir, 'roberta.base', 'dict.txt'))
    dataset_builder = indexed_dataset.make_builder(
        os.path.join(output_dir, split+'.text.bin'),
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )
    relations_builder = indexed_dataset.make_builder(
        os.path.join(output_dir, split+'.relations.bin'),
        impl=args.dataset_impl,
        vocab_size=None,
    )
    processor.initializer()
    annotations_list = list()
    total_length, num_sentences = 0, 0
    for i in range(len(texts)):
        text, relation_type_id = [texts[i]], unique_relation_types.index(relation_types[i])
        for ids_tensor, _annotations_list in map(processor, text):
            if ids_tensor is None:
                continue
            dataset_builder.add_item(ids_tensor)
            relations_builder.add_item(torch.IntTensor([relation_type_id]))
            _annotations_list[:, 0] += total_length
            _annotations_list[:, 1] += total_length
            _annotations_list[:, 2] += num_sentences
            _annotations_list[:, 3] += num_sentences
            num_sentences += 1
            total_length += len(ids_tensor)
            annotations_list.append(_annotations_list)

            pbar.update()
    pbar.close()

    dataset_builder.finalize(os.path.join(output_dir, split+'.text.idx'))
    relations_builder.finalize(os.path.join(output_dir, split+'.relations.idx'))
    annotations_list = np.concatenate(annotations_list)
    np.save(os.path.join(output_dir, split+'.annotations'), annotations_list)
Example #18
    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers, avoid_tokenize=False):
        if vocab is not None:
            print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        else:
            print('| Using None Dictionary and only string split is performed.')

        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                # TODO: worker > 1 is not working for map dataset
                if args.input_mapping is True:
                    raise NotImplementedError("Worker > 1 is not implemented for map dataset yet.")
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                        avoid_tokenize,
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab) if vocab is not None else -1,
        )
        merge_result(
            Binarizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1], avoid_tokenize=avoid_tokenize,
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        if vocab is not None:
            unk = vocab.unk_word if hasattr(vocab, 'unk_word') else vocab.unk_token
        else:
            unk = ""
        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                unk,
            )
        )
Example #19
def make_binary_stack(args, target_vocab, input_prefix, output_prefix, eos_idx, pad_idx, mask_predicates=False, allow_unk=False, tokenize=None):

    assert tokenize

    # involved files
    # for debug
    input_senteces = input_prefix + '.en'
    input_actions = input_prefix + '.actions'

    # The AMR state machine always expects rules
    if args.machine_type == 'AMR':

        assert args.machine_rules and os.path.isfile(args.machine_rules), \
            f'Missing {args.machine_rules}'

        # Read rules
        train_rule_stats = read_rule_stats(args.machine_rules)
        actions_by_stack_rules = train_rule_stats['possible_predicates']
    else:
        actions_by_stack_rules = None
        
    action_indexer = get_action_indexer(target_vocab.symbols)

    # initialize indices for each of variables
    # memory (stack, buffer, dead) (position in memory)
    stack_buffer_names = ['memory', 'memory_pos']
    # FIXME: These values are hard-coded elsewhere in code
    state_indices = [3, 4, 5]
    assert eos_idx not in state_indices, "Can not reuse EOS index"
    assert pad_idx not in state_indices, "Can not reuse PAD index"
    indexed_data = {}
    for name in stack_buffer_names:
        indexed_data[name] = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, name, "bin"),
            impl=args.dataset_impl,
        )

    if mask_predicates:
        # mask of target predictions 
        masks_path = dataset_dest_file(args, output_prefix, 'target_masks', "bin")
        indexed_target_masks = indexed_dataset.make_builder(
            masks_path,
            impl=args.dataset_impl,
        )

        # active indices 
        active_logits_path = dataset_dest_file(args, output_prefix, 'active_logits', "bin")
        indexed_active_logits = indexed_dataset.make_builder(
            active_logits_path,
            impl=args.dataset_impl,
        )

    # Returns function that generates initialized state machines given
    # sentence 
    get_new_state_machine = machine_generator(actions_by_stack_rules, entity_rules=args.entity_rules)

    num_sents = 0
    missing_actions = Counter()
    with open(input_actions, 'r') as fid_act, \
         open(input_senteces, 'r') as fid_sent:

        # Loop over sentences
        for sentence in tqdm(fid_sent):

            # Get actions, tokens
            sent_tokens = tokenize(sentence)
            sent_actions = tokenize(fid_act.readline())

            # initialize state machine batch of size 1
            state_machine = get_new_state_machine(
                sent_tokens,
                machine_type=args.machine_type
            )

            # collect target and source masks
            sent_data = {}
            for name in stack_buffer_names:
                sent_data[name] = []

            shape = (len(sent_actions), len(target_vocab.symbols))
            logits_mask = np.zeros(shape)
            active_logits = set()
            for action_idx, gold_action in enumerate(sent_actions):

                # active logits for this action
                if mask_predicates:

                    # Get total valid actions by expanding base ones
                    valid_actions, invalid_actions = state_machine.get_valid_actions()
                    valid_action_idx = (
                        action_indexer(valid_actions) 
                        - action_indexer(invalid_actions)
                    )

                    # if action is missing add it and count it
                    if gold_action in target_vocab.symbols:
                        gold_action_index = target_vocab.symbols.index(gold_action) 
                    else:
                        gold_action_index = target_vocab.symbols.index('<unk>') 
                    if gold_action_index not in valid_action_idx:
                        valid_action_idx.add(gold_action_index)
                        missing_actions.update([gold_action])

                    # if length 1 add pad to avoid deltas during learning
                    if len(valid_action_idx) == 1:
                        valid_action_idx.add(pad_idx)

                    # append number of nodes to regain matrix
                    logits_mask[action_idx, list(valid_action_idx)] = 1
                    active_logits |= valid_action_idx

                # stack and buffer 
                memory, memory_pos = get_word_states(
                    state_machine,
                    sent_tokens,
                    indices=state_indices
                )

                # word states
                sent_data['memory'].append(torch.Tensor(memory))
                # note we use position 0 for reduced words
                sent_data['memory_pos'].append(
                    torch.Tensor(memory_pos)
                )

                # Update machine
                state_machine.applyAction(gold_action)

            for name in stack_buffer_names:
                # note that data needs to be stored as a 1d array
                indexed_data[name].add_item(
                    torch.stack(sent_data[name]).view(-1)
                )

            # valid nodes
            if mask_predicates:
                active_logits = list(active_logits)
                # reduce size to active items
                logits_mask = logits_mask[:, active_logits]
                indexed_target_masks.add_item(
                    torch.Tensor(logits_mask).view(-1)
                )

                # active indices 
                indexed_active_logits.add_item(torch.Tensor(
                    active_logits
                ))

            # update number of sents
            num_sents += 1
            if not num_sents % 100:
                print("\r%d sentences" % num_sents, end = '')

        print("")

    # close indexed data files
    for name in stack_buffer_names:
        output_file_idx = dataset_dest_file(args, output_prefix, name, "idx")
        indexed_data[name].finalize(output_file_idx)
    # close valid action mask
    if mask_predicates:
        target_mask_idx = dataset_dest_file(args, output_prefix, 'target_masks', "idx")
        indexed_target_masks.finalize(target_mask_idx)

        # active indices 
        active_logits_idx = dataset_dest_file(args, output_prefix, 'active_logits', "idx")
        indexed_active_logits.finalize(active_logits_idx)

    # inform about missing actions
    if missing_actions: 
        print(yellow_font("There were missing actions"))
        print(missing_actions)
Example #20
def make_binary_bert_features(args, input_prefix, output_prefix, eos_idx, pad_idx, tokenize):

    # Load pretrained embeddings extractor
    pretrained_embeddings = PretrainedEmbeddings(
        args.pretrained_embed,
        args.bert_layers
    )

    # will store pre-extracted BERT layer
    indexed_data = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, 'en.bert', "bin"),
        impl=args.dataset_impl,
        dtype=np.float32
    )

    # will store wordpieces and wordpiece to word mapping
    indexed_wordpieces = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, 'en.wordpieces', "bin"),
        impl=args.dataset_impl,
    )

    indexed_wp2w = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, 'en.wp2w', "bin"),
        impl=args.dataset_impl,
    )

    num_sents = 0
    input_file = input_prefix + '.en'

    with open(input_file, 'r') as fid:
        for sentence in fid:

            # we only have tokenized data so we feed whitespace separated
            # tokens
            sentence = " ".join(tokenize(str(sentence).rstrip()))

            # extract embeddings, average them per token and return
            # wordpieces anyway
            word_features, worpieces_roberta, word2piece = \
                pretrained_embeddings.extract(sentence)

            # note that data needs to be stored as a 1d array. Also check
            # that the number of words matches the embedding size
            assert word_features.shape[1] == len(sentence.split())
            indexed_data.add_item(word_features.cpu().view(-1))

            # just store the wordpiece indices, ignore BOS/EOS tokens
            indexed_wordpieces.add_item(worpieces_roberta)
            indexed_wp2w.add_item(
                get_scatter_indices(word2piece, reverse=True)
            )

            # update number of sents
            num_sents += 1
            if not num_sents % 100:
                print("\r%d sentences" % num_sents, end = '')
        print("")

    # close indexed data files
    indexed_data.finalize(
        dataset_dest_file(args, output_prefix, 'en.bert', "idx")
    )

    indexed_wordpieces.finalize(
        dataset_dest_file(args, output_prefix, 'en.wordpieces', "idx")
    )
    indexed_wp2w.finalize(
        dataset_dest_file(args, output_prefix, 'en.wp2w', "idx")
    )
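Since each BERT feature item above is flattened to 1-D before add_item, here is a hedged sketch of restoring its shape at load time; the "data-bin/train.en.bert" prefix and the 768 feature dimension (roberta-base style) are assumptions.

feat_dim = 768  # assumed hidden size
bert_data = indexed_dataset.make_dataset("data-bin/train.en.bert", impl="mmap")
word_features = bert_data[0].view(-1, feat_dim)   # (num_words, feat_dim)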
Example #21
def main(args):
    data_path = os.path.join(args.root_dir, 'json', args.split + '.json')
    with codecs.open(data_path, 'r', 'utf8') as f:
        data = json.load(f)
    print('-- Loaded data from %s' % data_path)

    relations_path = os.path.join(args.root_dir, 'gold', args.split + '.gold')
    with open(relations_path, 'r') as f:
        relation_types = f.read().splitlines()
    unique_relation_types = sorted(list(set(relation_types)))
    unique_relation_types.remove('no_relation')
    unique_relation_types.append('no_relation')

    processor = TACREDProcessor(args.roberta_dir, args.dataset_impl,
                                args.append_eos, args.max_positions)
    pbar = tqdm.tqdm(
        total=len(data),
        desc='Processing Wiki',
        bar_format=TRAINING_TQDM_BAD_FORMAT,
    )

    output_dir = os.path.join(args.root_dir, 'bin')
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    split = args.split if args.split != 'dev' else 'valid'
    vocab = Dictionary.load(
        os.path.join(args.roberta_dir, 'roberta.base', 'dict.txt'))
    dataset_builder = indexed_dataset.make_builder(
        os.path.join(output_dir, split + '.text.bin'),
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )
    relations_builder = indexed_dataset.make_builder(
        os.path.join(output_dir, split + '.relations.bin'),
        impl=args.dataset_impl,
        vocab_size=None,
    )
    processor.initializer()
    annotations_list = list()
    total_length, num_sentences = 0, 0
    for sample in data:
        tokens = [sample['token']]
        annot = [{
            1: {
                'start': sample['subj_start'],
                'end': sample['subj_end'] + 1
            },
            2: {
                'start': sample['obj_start'],
                'end': sample['obj_end'] + 1
            }
        }]
        relation_type_id = unique_relation_types.index(sample['relation'])
        # relation_type_id = unique_relation_types.index(sample['relation']) - 1
        # if relation_type_id == -1:
        #     continue
        for ids_tensor, _annotations_list in map(processor, tokens, annot):
            if ids_tensor is None:
                continue
            dataset_builder.add_item(ids_tensor)
            relations_builder.add_item(torch.IntTensor([relation_type_id]))
            _annotations_list[:, 0] += total_length
            _annotations_list[:, 1] += total_length
            _annotations_list[:, 2] += num_sentences
            _annotations_list[:, 3] += num_sentences
            num_sentences += 1
            total_length += len(ids_tensor)
            annotations_list.append(_annotations_list)

            pbar.update()
    pbar.close()

    dataset_builder.finalize(os.path.join(output_dir, split + '.text.idx'))
    relations_builder.finalize(
        os.path.join(output_dir, split + '.relations.idx'))
    annotations_list = np.concatenate(annotations_list)
    np.save(os.path.join(output_dir, split + '.annotations'), annotations_list)
Example #22
    def __call__(self, path):
        global vocab
        global entities
        num_annotations, num_sentences, num_documents = 0, 0, 0
        total_length = 0
        num_filtered_xao = 0
        num_filtered_by_candidate_set, num_filtered_by_human_annotations, num_filtered_by_self_overlaps = 0, 0, 0
        num_filtered_by_crossing_sentence_boundaries, num_filtered_solo_annotion_in_sentence = 0, 0
        num_filtered_by_entity_vocab = 0

        empty_line_tensor = vocab.encode_line(line='',
                                              append_eos=self.append_eos)
        assert len(empty_line_tensor) == int(self.append_eos)

        if self.entity_vocab is None:
            annotation_entities = Counter()
        else:
            output_prefix = self.generate_tmp_filename()
            dataset_builder = indexed_dataset.make_builder(
                output_prefix + '.text.bin',
                impl=self.dataset_impl,
                vocab_size=len(vocab),
            )
            annotations_list = list()

        with codecs.open(path, 'r', 'utf8') as f:
            for line in f:
                article = json.loads(line[:-1])
                annotations = article['el']
                article[
                    'annotations'], _num_filtered_xao = self.fix_annotations(
                        article['annotations'])
                num_filtered_xao += _num_filtered_xao
                annotations, _num_filtered_xao = self.fix_annotations(
                    annotations)
                num_filtered_xao += _num_filtered_xao
                annotations, _num_filtered_by_candidate_set = self.filter_by_candidate_set(
                    article, annotations)
                annotations, _num_filtered_by_human_annotations = self.filter_by_human_annotations(
                    article, annotations)
                annotations, _num_filtered_by_self_overlaps = self.filter_by_self_overlaps(
                    annotations)
                annotations = article['annotations'] + annotations
                if self.entity_vocab is not None:
                    annotations, _num_filtered_by_entity_vocab = self.filter_by_entity_vocab(
                        annotations)
                    num_filtered_by_entity_vocab += _num_filtered_by_entity_vocab
                num_filtered_by_candidate_set += _num_filtered_by_candidate_set
                num_filtered_by_human_annotations += _num_filtered_by_human_annotations
                num_filtered_by_self_overlaps += _num_filtered_by_self_overlaps

                nlcs = NCLS(*get_intervals(annotations))
                text = article['text'].replace(u'\xa0', u' ')
                offset = 0

                for sentence, offset in self.split_into_sentences(text):
                    sentence_begin = offset
                    sentence_end = offset + len(sentence)
                    assert sentence == text[sentence_begin:sentence_end]

                    annotations_per_sentence = []
                    for annotation_id in nlcs.find_overlap(
                            sentence_begin, sentence_end):
                        annotation = annotations[annotation_id[2]]
                        start, end = get_start_end(annotation)
                        if sentence_begin <= start and end <= sentence_end:
                            annotations_per_sentence.append(annotation)
                        else:
                            num_filtered_by_crossing_sentence_boundaries += 1
                    num_unique_entities = len(
                        set([
                            annotation['uri']
                            for annotation in annotations_per_sentence
                        ]))
                    if num_unique_entities < self.min_entities_per_sentence:
                        num_filtered_solo_annotion_in_sentence += 1
                        continue
                    num_annotations += len(annotations_per_sentence)

                    if self.entity_vocab is None:
                        annotation_entities.update([
                            annotation['uri']
                            for annotation in annotations_per_sentence
                        ])
                    else:
                        annotations_per_sentence = self.set_local_offsets(
                            offset, annotations_per_sentence)
                        fixed_sentence, annotations_per_sentence = self.strip_whitespaces(
                            sentence, annotations_per_sentence)
                        fixed_sentence, annotations_per_sentence = self.strip_double_whitespaces(
                            fixed_sentence, annotations_per_sentence)
                        fixed_sentence, annotations_per_sentence = self.add_margin_to_annotations(
                            fixed_sentence, annotations_per_sentence)
                        annotations_per_sentence = self.get_word_based_offsets(
                            fixed_sentence, annotations_per_sentence)
                        ids, annotations_per_sentence = self.apply_gt2_bpe(
                            fixed_sentence, annotations_per_sentence)

                        ids_tensor = vocab.encode_line(
                            line=' '.join(ids), append_eos=self.append_eos)
                        assert len(ids_tensor) == len(ids) + int(
                            self.append_eos)
                        dataset_builder.add_item(ids_tensor)
                        annotations_list.extend([[
                            x['start_word'] + total_length,
                            x['end_word'] + total_length, num_sentences,
                            num_documents,
                            int(entities[x['uri']])
                        ] for x in annotations_per_sentence])
                        total_length += len(ids_tensor)
                    num_sentences += 1

                if self.entity_vocab is not None:
                    dataset_builder.add_item(empty_line_tensor)
                    total_length += len(empty_line_tensor)
                    num_sentences += 1
                    num_documents += 1

        if self.entity_vocab is not None:
            dataset_builder.finalize(output_prefix + '.text.idx')
            annotations_list = np.array(annotations_list, dtype=np.int64)

        return (
            annotation_entities
            if self.entity_vocab is None else output_prefix,
            annotations_list if self.entity_vocab is not None else None,
            total_length if self.entity_vocab is not None else 0,
            num_documents,
            num_sentences,
            num_annotations,
            num_filtered_by_candidate_set,
            num_filtered_by_human_annotations,
            num_filtered_by_self_overlaps,
            num_filtered_by_crossing_sentence_boundaries,
            num_filtered_solo_annotion_in_sentence,
            num_filtered_xao,
            num_filtered_by_entity_vocab,
        )
Example #23
def main(args):
    assert os.path.isdir(args.tmp)
    input_files = sorted([
        path for data in args.data.split(',')
        for path in glob.glob(os.path.expanduser(data))
    ])
    print('-- Found %d files' % len(input_files))
    build_entity_vocab_mode = not os.path.exists(args.entity_vocab)
    print('-- Build entity vocab mode: %s' %
          ('ON' if build_entity_vocab_mode else 'OFF'))

    processor = WikiProcessor(
        args.roberta,
        args.limit_set_of_entities,
        args.min_entities_per_sentence,
        args.dataset_impl,
        args.tmp,
        args.append_eos,
        args.entity_vocab if not build_entity_vocab_mode else None,
    )
    num_documents, num_sentences = 0, 0
    num_annotations, num_filtered_by_candidate_set, num_filtered_by_human_annotations, num_filtered_by_self_overlaps = 0, 0, 0, 0
    num_filtered_xao, num_filtered_by_crossing_sentence_boundaries, num_filtered_solo_annotion_in_sentence = 0, 0, 0
    num_filtered_by_entity_vocab = 0
    total_length = 0

    pbar = tqdm.tqdm(
        total=len(input_files),
        desc='Processing Wiki',
        bar_format=TRAINING_TQDM_BAD_FORMAT,
    )
    pbar.set_postfix({
        's': num_sentences,
        'd': num_documents,
        'ann': num_annotations,
        'f_ed': num_filtered_by_candidate_set,
        'f_h_overlap': num_filtered_by_human_annotations,
        'f_self_overlap': num_filtered_by_self_overlaps,
        'f_cross_s_bd': num_filtered_by_crossing_sentence_boundaries,
        'f_solo_s': num_filtered_solo_annotion_in_sentence,
        'f_xao': num_filtered_xao,
        'f_vocab': num_filtered_by_entity_vocab,
    })

    if build_entity_vocab_mode:
        entities = Counter()
    else:
        vocab = Dictionary.load('/data/urikz/nki/roberta/dict.txt')
        dataset_builder = indexed_dataset.make_builder(
            args.output + '.text.bin',
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        entities = load_entities(args.entity_vocab)
        annotations_list = list()

    if args.nworkers == 1:
        processor.initializer()
        for output, _annotations_list, _total_length, _num_documents, s, x, y, z, w, v, u, t, q in map(
                processor, input_files):
            if build_entity_vocab_mode:
                entities.update(output)
            else:
                dataset_builder.merge_file_(output + '.text')
                _annotations_list[:, 0] += total_length
                _annotations_list[:, 1] += total_length
                _annotations_list[:, 2] += num_sentences
                _annotations_list[:, 3] += num_documents
                annotations_list.append(_annotations_list)
            total_length += _total_length
            num_documents += _num_documents
            num_sentences += s
            num_annotations += x
            num_filtered_by_candidate_set += y
            num_filtered_by_human_annotations += z
            num_filtered_by_self_overlaps += w
            num_filtered_by_crossing_sentence_boundaries += v
            num_filtered_solo_annotion_in_sentence += u
            num_filtered_xao += t
            num_filtered_by_entity_vocab += q
            pbar.set_postfix({
                's': num_sentences,
                'd': num_documents,
                'ann': num_annotations,
                'f_ed': num_filtered_by_candidate_set,
                'f_h_overlap': num_filtered_by_human_annotations,
                'f_self_overlap': num_filtered_by_self_overlaps,
                'f_cross_s_bd': num_filtered_by_crossing_sentence_boundaries,
                'f_solo_s': num_filtered_solo_annotion_in_sentence,
                'f_xao': num_filtered_xao,
                'f_vocab': num_filtered_by_entity_vocab,
            })
            pbar.update()
    else:
        with mp.Pool(processes=args.nworkers,
                     initializer=processor.initializer) as pool:
            for output, _annotations_list, _total_length, _num_documents, s, x, y, z, w, v, u, t, q in pool.imap_unordered(
                    processor, input_files):
                if build_entity_vocab_mode:
                    entities.update(output)
                else:
                    dataset_builder.merge_file_(output + '.text')
                    _annotations_list[:, 0] += total_length
                    _annotations_list[:, 1] += total_length
                    _annotations_list[:, 2] += num_sentences
                    _annotations_list[:, 3] += num_documents
                    annotations_list.append(_annotations_list)
                total_length += _total_length
                num_documents += _num_documents
                num_sentences += s
                num_annotations += x
                num_filtered_by_candidate_set += y
                num_filtered_by_human_annotations += z
                num_filtered_by_self_overlaps += w
                num_filtered_by_crossing_sentence_boundaries += v
                num_filtered_solo_annotion_in_sentence += u
                num_filtered_xao += t
                num_filtered_by_entity_vocab += q
                pbar.set_postfix({
                    's': num_sentences,
                    'd': num_documents,
                    'ann': num_annotations,
                    'f_ed': num_filtered_by_candidate_set,
                    'f_h_overlap': num_filtered_by_human_annotations,
                    'f_self_overlap': num_filtered_by_self_overlaps,
                    'f_cross_s_bd':
                    num_filtered_by_crossing_sentence_boundaries,
                    'f_solo_s': num_filtered_solo_annotion_in_sentence,
                    'f_xao': num_filtered_xao,
                    'f_vocab': num_filtered_by_entity_vocab,
                })
                pbar.update()
    pbar.close()

    if build_entity_vocab_mode:
        counter = 0
        with codecs.open(args.entity_vocab, 'w', 'utf8') as f:
            for entity_and_count in entities.most_common():
                if (args.entity_count_threshold is None
                        or entity_and_count[1] >= args.entity_count_threshold):
                    counter += 1
                    f.write('%s %d\n' %
                            (entity_and_count[0], entity_and_count[1]))
        print('-- Successfully saved %d entities (out of %d) to %s' % (
            counter,
            len(entities),
            args.entity_vocab,
        ))
    else:
        dataset_builder.finalize(args.output + '.text.idx')
        annotations_list = np.concatenate(annotations_list)
        np.save(args.output + '.annotations', annotations_list)
Example #24
def main(args):
    data_path = os.path.join(args.root_dir, args.split + '.json')
    annotation_path = os.path.join(args.root_dir,
                                   args.split + '_annotations.json')
    with codecs.open(data_path, 'r', 'utf8') as f:
        data = json.load(f)["Data"]

    with codecs.open(annotation_path, 'r', 'utf8') as f:
        annotations = json.load(f)
    print('-- Loaded data from %s' % data_path)

    entity_dict = EntityDictionary.load(args.entity_path)

    processor = TriviaQAProcessor(args.roberta_dir, entity_dict,
                                  args.dataset_impl, args.append_eos,
                                  args.max_positions)
    pbar = tqdm.tqdm(
        total=len(data),
        desc='Processing Wiki',
        bar_format=TRAINING_TQDM_BAD_FORMAT,
    )

    output_dir = os.path.join(args.root_dir, 'bin')
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    split = args.split if args.split != 'dev' else 'valid'
    vocab = Dictionary.load(
        os.path.join(args.roberta_dir, 'roberta.base', 'dict.txt'))
    question_builder = indexed_dataset.make_builder(
        os.path.join(output_dir, split + '.questions_entities' + '.bin'),
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )
    answer_entities = []
    processed_annotations = []

    processor.initializer()
    num_questions, valid_questions, no_annotation, not_in_dict, processing_problem = 0, 0, 0, 0, 0
    for i, sample in enumerate(data):
        num_questions += 1
        pbar.update()

        question = sample["Question"]
        answer = sample["Answer"]
        annotation = copy.deepcopy(annotations[i])

        entity_name = None

        # Only use questions with an entity for the answer
        if "MatchedWikiEntityName" in answer:
            entity_name = answer["MatchedWikiEntityName"]
        else:
            continue

        if annotation is None:
            no_annotation += 1
            continue

        entity_name = entity_name.replace(' ', '_')

        # assert entity_name in entity_dict
        if entity_name not in entity_dict:
            not_in_dict += 1
            continue
        entity_id = entity_dict.index(entity_name)

        ids_tensor, processed_annotation = processor.process(
            question, annotation)
        if ids_tensor is None:
            processing_problem += 1
            continue
        processed_annotations.append(processed_annotation)
        answer_entities.append(entity_id)

        question_builder.add_item(ids_tensor)
        valid_questions += 1

    pbar.close()

    question_builder.finalize(
        os.path.join(output_dir, split + '.questions_entities' + '.idx'))
    np.save(os.path.join(output_dir, split + '.answer_entities'),
            answer_entities)

    processed_annotation_path = os.path.join(
        output_dir, args.split + '.processed_annotations.json')
    with codecs.open(processed_annotation_path, 'w', 'utf8') as f:
        json.dump(processed_annotations, f, indent=4)