def add_file_to_dictionary(filename, dict, tokenize, num_workers):
    """Count the tokens of *filename* and add them to *dict*.

    When ``num_workers > 1`` the file is split into byte ranges that are
    counted by parallel worker processes; each worker's counter is then
    merged into the dictionary in sorted token order so insertion is
    deterministic.
    """

    def absorb(counter):
        # Sort tokens so that symbols are added in a deterministic order.
        for token, count in sorted(counter.items()):
            dict.add_symbol(token, count)

    path = PathManager.get_local_path(filename)
    boundaries = find_offsets(path, num_workers)

    if num_workers <= 1:
        # Single-process path: the whole file is one chunk.
        absorb(
            Dictionary._add_file_to_dictionary_single_worker(
                path, tokenize, dict.eos_word, boundaries[0], boundaries[1]
            )
        )
        return

    workers = Pool(processes=num_workers)
    async_results = [
        workers.apply_async(
            Dictionary._add_file_to_dictionary_single_worker,
            (path, tokenize, dict.eos_word, begin, end),
        )
        for begin, end in zip(boundaries, boundaries[1:])
    ]
    workers.close()
    workers.join()
    for res in async_results:
        absorb(res.get())
def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
    """Binarize an alignment file into an indexed dataset.

    NOTE(review): reads the module-level ``args`` (for ``dataset_impl`` and
    the dataset destination paths) rather than taking it as a parameter.
    """
    # Mutable one-element list so the nested callback can update the count.
    nseq = [0]

    def merge_result(worker_result):
        nseq[0] += worker_result["nseq"]

    input_file = input_prefix
    # Byte offsets partitioning the file into num_workers contiguous chunks.
    offsets = find_offsets(input_file, num_workers)
    # Pair consecutive offsets into (start, end) ranges: the first chunk is
    # binarized in this process, the remaining ones go to worker processes.
    (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
    pool = None
    if num_workers > 1:
        # One fewer worker than num_workers: this process handles first_chunk.
        pool = Pool(processes=num_workers - 1)
        for worker_id, (start_offset, end_offset) in enumerate(more_chunks, start=1):
            # Each worker writes to its own temporary prefix, merged below.
            prefix = "{}{}".format(output_prefix, worker_id)
            pool.apply_async(
                binarize_alignments,
                (
                    args,
                    input_file,
                    utils.parse_alignment,
                    prefix,
                    start_offset,
                    end_offset,
                ),
                callback=merge_result,
            )
        pool.close()
    ds = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl
    )
    # Binarize the first chunk in the main process while the workers run.
    merge_result(
        Binarizer.binarize_alignments(
            input_file,
            utils.parse_alignment,
            lambda t: ds.add_item(t),
            offset=first_chunk[0],
            end=first_chunk[1],
        )
    )
    if num_workers > 1:
        pool.join()
        # Append each worker's temporary output to the main dataset, then
        # remove the temporary .bin/.idx files.
        for worker_id in range(1, num_workers):
            prefix = "{}{}".format(output_prefix, worker_id)
            temp_file_path = dataset_dest_prefix(args, prefix, None)
            ds.merge_file_(temp_file_path)
            os.remove(indexed_dataset.data_file_path(temp_file_path))
            os.remove(indexed_dataset.index_file_path(temp_file_path))
    ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))
    logger.info(
        "[alignments] {}: parsed {} alignments".format(input_file, nseq[0])
    )
def test_readchunks(self):
    """Each chunk yields whole, intact lines covering its byte range."""
    from fairseq.file_chunker_utils import Chunker, find_offsets

    boundaries = find_offsets(self._tmpfile, self._num_splits)
    expected_lines = self._num_lines / self._num_splits
    for begin, end in zip(boundaries, boundaries[1:]):
        with Chunker(self._tmpfile, begin, end) as chunk:
            got = list(chunk)
            # Splitting happens on byte boundaries, so a chunk may end up
            # with one line more or less than the exact average.
            self.assertAlmostEqual(len(got), expected_lines, delta=1)
            self.assertListEqual(got, [self._line_content] * len(got))
def test_find_offsets(self):
    """find_offsets returns num_splits + 1 boundaries from 0 to file size."""
    from fairseq.file_chunker_utils import find_offsets

    offsets = find_offsets(self._tmpfile, self._num_splits)
    # One boundary per split plus the final end-of-file offset.
    self.assertEqual(len(offsets), self._num_splits + 1)
    first, *interior, final = offsets
    self.assertEqual(first, 0)
    for idx, boundary in enumerate(interior):
        self.assertEqual(
            boundary,
            self._num_bytes
            + ((idx + 1) * self._num_bytes * self._num_lines / self._num_splits),
        )
    self.assertEqual(final, self._num_bytes * self._num_lines)
def multiprocess_dataset(
    cls,
    input_file: str,
    dataset_impl: str,
    binarizer: Binarizer,
    output_prefix: str,
    vocab_size=None,
    num_workers=1,
) -> BinarizeSummary:
    """Binarize *input_file* into an indexed dataset, optionally in parallel.

    The file is split into ``num_workers`` byte-range chunks; all but the
    first chunk are binarized by worker processes into temporary datasets,
    the first chunk is binarized in this process, and the worker outputs
    are then merged into the main dataset before it is finalized.

    Args:
        input_file: path of the text file to binarize.
        dataset_impl: indexed-dataset implementation name.
        binarizer: the Binarizer applied to every chunk.
        output_prefix: prefix for the output .bin/.idx files.
        vocab_size: forwarded to the dataset builder when not None.
        num_workers: total number of chunks/processes to use.

    Returns:
        The merged BinarizeSummary from all chunks.
    """
    final_summary = BinarizeSummary()

    offsets = find_offsets(input_file, num_workers)
    # find_offsets returns a list of positions [pos1, pos2, pos3, pos4]; we
    # zip the list with itself shifted by one to get (start, end) pairs:
    # [(pos1, pos2), (pos2, pos3), (pos3, pos4)].
    (first_chunk, *more_chunks) = zip(offsets, offsets[1:])

    pool = None
    if num_workers > 1:
        # num_workers - 1 processes: this process handles the first chunk.
        pool = Pool(processes=num_workers - 1)
        worker_results = [
            pool.apply_async(
                cls._binarize_chunk_and_finalize,
                args=(
                    binarizer,
                    input_file,
                    start_offset,
                    end_offset,
                    _worker_prefix(output_prefix, worker_id),
                    dataset_impl,
                ),
                kwds={"vocab_size": vocab_size} if vocab_size is not None else {},
            )
            for worker_id, (start_offset, end_offset) in enumerate(
                more_chunks, start=1
            )
        ]
        pool.close()
        pool.join()
        for r in worker_results:
            final_summary.merge(r.get())

    # Binarize the first chunk in this process. Do not close the bin file
    # yet: the worker outputs still need to be appended to it below.
    # Fix: was `vocab_size if vocab_size is not None else None`, a tautology.
    final_ds, summ = cls._binarize_file_chunk(
        binarizer,
        input_file,
        offset_start=first_chunk[0],
        offset_end=first_chunk[1],
        output_prefix=output_prefix,
        dataset_impl=dataset_impl,
        vocab_size=vocab_size,
    )
    final_summary.merge(summ)

    if num_workers > 1:
        for worker_id in range(1, num_workers):
            # Merge each worker's output, then best-effort delete its
            # temporary .bin/.idx files.
            worker_output_prefix = _worker_prefix(output_prefix, worker_id)
            final_ds.merge_file_(worker_output_prefix)
            try:
                os.remove(indexed_dataset.data_file_path(worker_output_prefix))
                os.remove(indexed_dataset.index_file_path(worker_output_prefix))
            except Exception as e:
                logger.error(f"couldn't remove {worker_output_prefix}.*", exc_info=e)

    # All chunks are merged; now the dataset can be finalized and closed.
    idx_file = indexed_dataset.index_file_path(output_prefix)
    final_ds.finalize(idx_file)
    return final_summary
def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
    """Binarize a tokenized text file for *lang* into an indexed dataset.

    All but the first byte-range chunk are binarized by worker processes
    into temporary datasets; the first chunk is binarized here, the worker
    outputs are merged in, and summary statistics are logged.

    NOTE(review): reads the module-level ``args`` (``dataset_impl`` and the
    dataset destination paths) rather than taking it as a parameter.
    """
    logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
    # n_seq_tok = [num sequences, num tokens]; mutable so the callback can
    # accumulate across workers.
    n_seq_tok = [0, 0]
    replaced = Counter()  # tokens replaced by the unknown symbol

    def merge_result(worker_result):
        replaced.update(worker_result["replaced"])
        n_seq_tok[0] += worker_result["nseq"]
        n_seq_tok[1] += worker_result["ntok"]

    input_file = "{}{}".format(input_prefix, ("." + lang) if lang is not None else "")
    offsets = find_offsets(input_file, num_workers)
    # Pair consecutive offsets into (start, end) chunks; this process
    # binarizes the first one, worker processes handle the rest.
    (first_chunk, *more_chunks) = zip(offsets, offsets[1:])
    pool = None
    if num_workers > 1:
        pool = Pool(processes=num_workers - 1)
        for worker_id, (start_offset, end_offset) in enumerate(more_chunks, start=1):
            # Each worker writes to its own temporary prefix, merged below.
            prefix = "{}{}".format(output_prefix, worker_id)
            pool.apply_async(
                binarize,
                (
                    args,
                    input_file,
                    vocab,
                    prefix,
                    lang,
                    start_offset,
                    end_offset,
                ),
                callback=merge_result,
            )
        pool.close()
    ds = indexed_dataset.make_builder(
        dataset_dest_file(args, output_prefix, lang, "bin"),
        impl=args.dataset_impl,
        vocab_size=len(vocab),
    )
    # Binarize the first chunk in the main process while the workers run.
    merge_result(
        Binarizer.binarize(
            input_file,
            vocab,
            lambda t: ds.add_item(t),
            offset=first_chunk[0],
            end=first_chunk[1],
        )
    )
    if num_workers > 1:
        pool.join()
        for worker_id in range(1, num_workers):
            prefix = "{}{}".format(output_prefix, worker_id)
            temp_file_path = dataset_dest_prefix(args, prefix, lang)
            ds.merge_file_(temp_file_path)
            os.remove(indexed_dataset.data_file_path(temp_file_path))
            os.remove(indexed_dataset.index_file_path(temp_file_path))
    ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))
    # Fix: guard the replacement-percentage computation against an empty
    # input (ntok == 0) which previously raised ZeroDivisionError.
    ntok = n_seq_tok[1]
    pct_replaced = 100 * sum(replaced.values()) / ntok if ntok > 0 else 0.0
    logger.info(
        "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
            lang,
            input_file,
            n_seq_tok[0],
            ntok,
            pct_replaced,
            vocab.unk_word,
        )
    )