Example #1
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        def merge_result(counter):
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = PathManager.get_local_path(filename)
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            # pair consecutive offsets into (start, end) byte ranges, one chunk per worker
            chunks = zip(offsets, offsets[1:])
            pool = Pool(processes=num_workers)
            results = []
            for (start_offset, end_offset) in chunks:
                results.append(
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                )
            pool.close()
            pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )
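The pattern above (split the file into line-aligned byte chunks, count each chunk in a worker process, then merge the per-worker counters) can be reproduced with the standard library alone. The sketch below assumes a plain UTF-8 text file; naive_find_offsets and count_chunk are illustrative stand-ins, not fairseq's find_offsets or Dictionary internals.

    import os
    from collections import Counter
    from multiprocessing import Pool


    def naive_find_offsets(path, num_chunks):
        """Return num_chunks + 1 byte offsets into `path`, aligned to line boundaries."""
        size = os.path.getsize(path)
        offsets = [0]
        with open(path, "rb") as f:
            for i in range(1, num_chunks):
                f.seek(i * size // num_chunks)
                f.readline()  # advance to the next line boundary
                offsets.append(f.tell())
        offsets.append(size)
        return offsets


    def count_chunk(path, start, end):
        """Count whitespace-separated tokens in the byte range [start, end)."""
        counter = Counter()
        with open(path, "rb") as f:
            f.seek(start)
            while f.tell() < end:
                counter.update(f.readline().decode("utf-8").split())
        return counter


    def count_words(path, num_workers=4):
        offsets = naive_find_offsets(path, num_workers)
        with Pool(processes=num_workers) as pool:
            async_results = [
                pool.apply_async(count_chunk, (path, start, end))
                for start, end in zip(offsets, offsets[1:])
            ]
            chunk_counts = [r.get() for r in async_results]
        total = Counter()
        for c in chunk_counts:
            total.update(c)
        return total

Calling count_words("corpus.txt", num_workers=4) would return a single Counter equivalent to counting the whole file in one pass, because each byte of the file belongs to exactly one chunk.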
Example #2
    def make_binary_alignment_dataset(input_prefix, output_prefix,
                                      num_workers):
        nseq = [0]  # one-element list so the nested merge_result callback can update it

        def merge_result(worker_result):
            nseq[0] += worker_result["nseq"]

        input_file = input_prefix
        offsets = find_offsets(input_file, num_workers)
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            # the parent process binarizes the first chunk itself, so only
            # num_workers - 1 extra processes are needed for the remaining chunks
            pool = Pool(processes=num_workers - 1)
            for worker_id, (start_offset, end_offset) in enumerate(more_chunks,
                                                                   start=1):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        start_offset,
                        end_offset,
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, None, "bin"),
            impl=args.dataset_impl,
        )

        merge_result(
            Binarizer.binarize_alignments(
                input_file,
                utils.parse_alignment,
                lambda t: ds.add_item(t),
                offset=first_chunk[0],
                end=first_chunk[1],
            ))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info("[alignments] {}: parsed {} alignments".format(
            input_file, nseq[0]))
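Example #2 relies on apply_async(..., callback=merge_result) to fold worker results into a mutable accumulator while the parent process handles the first chunk itself. Here is a minimal, standard-library-only sketch of that callback-merge pattern; process_chunk and run are hypothetical names, and the sketch assumes num_workers > 1, as in the branch above.

    from multiprocessing import Pool


    def process_chunk(chunk):
        # stand-in worker: pretend each chunk yields len(chunk) parsed sequences
        return {"nseq": len(chunk)}


    def run(chunks, num_workers=4):
        # assumes num_workers > 1, mirroring the branch in the example above
        nseq = [0]  # mutable container, updated from the callback

        def merge_result(worker_result):
            # callbacks run in the parent process, one call per finished task
            nseq[0] += worker_result["nseq"]

        with Pool(processes=num_workers - 1) as pool:
            for chunk in chunks[1:]:
                pool.apply_async(process_chunk, (chunk,), callback=merge_result)
            # the parent handles the first chunk itself while the workers run
            merge_result(process_chunk(chunks[0]))
            pool.close()
            pool.join()
        return nseq[0]

For instance, run(["aaa", "bb", "cccc"], num_workers=2) returns 9: the two tail chunks are merged via the callback and the head chunk is merged directly in the parent.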
Example #3
    def test_readchunks(self):
        from fairseq.file_chunker_utils import Chunker, find_offsets

        offsets = find_offsets(self._tmpfile, self._num_splits)
        for start, end in zip(offsets, offsets[1:]):
            with Chunker(self._tmpfile, start, end) as lines:
                all_lines = list(lines)
                num_lines = self._num_lines / self._num_splits
                self.assertAlmostEqual(
                    len(all_lines), num_lines, delta=1
                )  # because we split on bytes, a chunk can end up with one more/less line
                self.assertListEqual(
                    all_lines,
                    [self._line_content for _ in range(len(all_lines))])
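The test above iterates a Chunker over each (start, end) pair. As a rough illustration of what such a chunked line reader can look like (not fairseq's actual Chunker implementation), a context manager can seek to start and yield lines until the file position passes end:

    class SimpleChunker:
        """Yield decoded lines from the byte range [start, end) of a UTF-8 file.

        An illustrative stand-in for fairseq's Chunker, not its actual code; it
        assumes start and end fall on line boundaries, as find_offsets produces.
        """

        def __init__(self, path, start, end):
            self.path, self.start, self.end = path, start, end

        def __enter__(self):
            self._f = open(self.path, "rb")
            self._f.seek(self.start)
            return self._lines()

        def __exit__(self, *exc):
            self._f.close()
            return False

        def _lines(self):
            while self._f.tell() < self.end:
                line = self._f.readline()
                if not line:
                    break
                yield line.decode("utf-8")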
Example #4
    def test_find_offsets(self):
        from fairseq.file_chunker_utils import find_offsets

        offsets = find_offsets(self._tmpfile, self._num_splits)
        self.assertEqual(len(offsets), self._num_splits + 1)
        (zero, *real_offsets, last) = offsets  # first offset is 0, last is the total file size
        self.assertEqual(zero, 0)
        for i, o in enumerate(real_offsets):
            self.assertEqual(
                o,
                self._num_bytes + ((i + 1) * self._num_bytes *
                                   self._num_lines / self._num_splits),
            )
        self.assertEqual(last, self._num_bytes * self._num_lines)
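To make the expected-offset formula concrete, here is the same arithmetic with hypothetical fixture values (the real test defines _num_bytes, _num_lines and _num_splits in its setup):

    # hypothetical fixture values, chosen only to make the arithmetic visible
    num_bytes, num_lines, num_splits = 10, 100, 10

    file_size = num_bytes * num_lines                 # 1000 bytes in total
    chunk_bytes = num_bytes * num_lines / num_splits  # 100.0 bytes per split
    expected = (
        [0]
        + [num_bytes + (i + 1) * chunk_bytes for i in range(num_splits - 1)]
        + [file_size]
    )
    # expected == [0, 110.0, 210.0, ..., 910.0, 1000]: num_splits + 1 offsets,
    # and each interior offset sits one full line past the exact split point,
    # matching the formula asserted in the test above.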
Example #5
File: binarizer.py  Project: tma15/fairseq
    def multiprocess_dataset(
        cls,
        input_file: str,
        dataset_impl: str,
        binarizer: Binarizer,
        output_prefix: str,
        vocab_size=None,
        num_workers=1,
    ) -> BinarizeSummary:
        final_summary = BinarizeSummary()

        offsets = find_offsets(input_file, num_workers)
        # find_offsets returns a list of positions [pos1, pos2, pos3, pos4] but we would want pairs:
        # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info
        # we zip the list with itself shifted by one to get all the pairs.
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            worker_results = [
                pool.apply_async(
                    cls._binarize_chunk_and_finalize,
                    args=(
                        binarizer,
                        input_file,
                        start_offset,
                        end_offset,
                        _worker_prefix(
                            output_prefix,
                            worker_id,
                        ),
                        dataset_impl,
                    ),
                    kwds={
                        "vocab_size": vocab_size,
                    }
                    if vocab_size is not None
                    else {},
                )
                for worker_id, (start_offset, end_offset) in enumerate(
                    more_chunks, start=1
                )
            ]

            pool.close()
            pool.join()
            for r in worker_results:
                summ = r.get()
                final_summary.merge(summ)

        # do not close the bin file as we need to merge the worker results in
        final_ds, summ = cls._binarize_file_chunk(
            binarizer,
            input_file,
            offset_start=first_chunk[0],
            offset_end=first_chunk[1],
            output_prefix=output_prefix,
            dataset_impl=dataset_impl,
            vocab_size=vocab_size if vocab_size is not None else None,
        )
        final_summary.merge(summ)

        if num_workers > 1:
            for worker_id in range(1, num_workers):
                # merge the worker outputs
                worker_output_prefix = _worker_prefix(
                    output_prefix,
                    worker_id,
                )
                final_ds.merge_file_(worker_output_prefix)
                try:
                    os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                    os.remove(indexed_dataset.index_file_path(worker_output_prefix))
                except Exception as e:
                    logger.error(
                        f"couldn't remove {worker_output_prefix}.*", exc_info=e
                    )

        #  now we can close the file
        idx_file = indexed_dataset.index_file_path(output_prefix)
        final_ds.finalize(idx_file)
        return final_summary
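The comment in multiprocess_dataset about pairing offsets is easy to verify on made-up numbers: zipping the offset list with itself shifted by one yields one (start, end) pair per chunk, and the first pair is kept for the parent process.

    offsets = [0, 128, 256, 384]  # made-up return value of find_offsets for 3 chunks
    chunks = list(zip(offsets, offsets[1:]))
    # chunks == [(0, 128), (128, 256), (256, 384)]
    (first_chunk, *more_chunks) = chunks
    # first_chunk == (0, 128) is binarized by the parent process itself;
    # more_chunks are dispatched to the pool of num_workers - 1 extra processes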
Example #6
    def make_binary_dataset(vocab, input_prefix, output_prefix, lang,
                            num_workers):
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        n_seq_tok = [0, 0]  # [number of sequences, number of tokens], updated by merge_result
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix,
                                   ("." + lang) if lang is not None else "")
        offsets = find_offsets(input_file, num_workers)
        (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id, (start_offset, end_offset) in enumerate(more_chunks,
                                                                   start=1):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        start_offset,
                        end_offset,
                    ),
                    callback=merge_result,
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
        )
        merge_result(
            Binarizer.binarize(
                input_file,
                vocab,
                lambda t: ds.add_item(t),
                offset=first_chunk[0],
                end=first_chunk[1],
            ))
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            ))
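Examples #2, #5 and #6 all end with the same gather step: each worker writes to a temporary prefix, the parent appends those files onto its own dataset with merge_file_, deletes the temporaries, and finalizes the index. The sketch below is a much-simplified, standard-library-only stand-in for that step; the "<output_prefix><worker_id>" naming is taken from the examples, but real indexed datasets also merge index metadata, which plain byte concatenation ignores.

    import os


    def merge_worker_outputs(output_prefix, num_workers, suffix=".bin"):
        """Append each worker's temporary file onto the parent's output, then delete it.

        A conceptual stand-in for ds.merge_file_ plus the os.remove calls above;
        ds.finalize and the index files are intentionally left out of this sketch.
        """
        final_path = output_prefix + suffix
        with open(final_path, "ab") as final_f:
            for worker_id in range(1, num_workers):
                tmp_path = "{}{}{}".format(output_prefix, worker_id, suffix)
                with open(tmp_path, "rb") as tmp_f:
                    final_f.write(tmp_f.read())
                os.remove(tmp_path)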