def _ray_process_aux(self, tasks, results_queue):
    chunks = ichunked(tasks, int(self.chunksize))
    num_initial = int(ray.available_resources()['CPU'])
    with self._get_processing_bar(len(tasks)) as progress_bar:
        futures = [
            self._submit_chunk_ray(c)
            for c in it.islice(chunks, num_initial)
        ]
        while futures:
            (finished, *_), rest = ray.wait(futures, num_returns=1)
            result = ray.get(finished)
            results_queue.put(result)
            progress_bar.update(len(result))
            try:
                chunk = next(chunks)
            except StopIteration:
                ...
            else:
                rest.append(self._submit_chunk_ray(chunk))
            futures = rest
    results_queue.put(None)
def conv_images(self, oid_images: OidImages) -> CocoImages:
    try:
        n_iter = math.ceil(len(oid_images.images) / os.cpu_count())
        oid_images_itr = ichunked(oid_images.images.values(), n_iter)
        futures = []
        with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
            for oid_images_islice in oid_images_itr:
                future = executor.submit(self._conv_image, OidConv.image_id, oid_images_islice)
                futures.append(future)
                OidConv.image_id += n_iter
            OidConv.image_id -= n_iter - (len(oid_images.images) % n_iter)
            concurrent.futures.wait(futures, timeout=None)
        coco_images = CocoImages(oid_images.ds_type, oid_images.images_dir)
        for future in concurrent.futures.as_completed(futures):
            coco_images = self._concat_coco_images(coco_images, future.result())
        self.oid_images = oid_images
        self.coco_images = coco_images
        return coco_images
    except Exception as e:
        logger.error(f'action=conv_images error={e}')
        raise
async def _batch_execution_async(
    batch_func: Callable[..., Awaitable],
    arg_list: Iterable[Any],
    *batch_func_args,
    **batch_func_kwargs,
) -> Sequence[Exception]:
    import tqdm.asyncio  # type: ignore

    exceptions = []
    loop = asyncio.get_event_loop()
    for batch, args_chunk in enumerate(
        more_itertools.ichunked(arg_list, CONCURRENT_TASKS), start=1
    ):
        tasks = [
            loop.create_task(
                batch_func(arg, *batch_func_args, **batch_func_kwargs)
            )
            for arg in args_chunk
        ]
        for coro in tqdm.asyncio.tqdm_asyncio.as_completed(
            tasks, desc=f"Progress batch {batch}", file=sys.stdout
        ):
            try:
                await coro
            except exception.GitError as exc:
                exceptions.append(exc)
    for e in exceptions:
        plug.log.error(str(e))
    return exceptions
def run(self):
    self.setup()
    ids = []
    codes = []
    filename_iter = ichunked(self.gen_filenames(), self.args.batch_size)
    for batch_i, in_filenames in tqdm.tqdm(enumerate(filename_iter)):
        # load up images from a batch of filenames
        batch_imgs = []
        batch_filenames = []
        for in_filename in in_filenames:
            try:
                target_img = Image.open(in_filename).convert('RGB')
            except IOError:
                print(f'Error opening {in_filename}')
                continue
            target_img = self.transform_input_image(target_img)
            batch_imgs.append(target_img)
            batch_filenames.append(os.path.basename(in_filename))
        _, p_code = self.d(torch.stack(batch_imgs).to(self.args.device))
        batch_ids = [os.path.splitext(f)[0] for f in batch_filenames]
        ids += batch_ids
        codes.append(p_code.cpu())
        if self.args.recon:
            recon = self.g(p_code)
            self.save_image(
                recon, f'{self.args.output_prefix}_{batch_i}.png'
            )
    codes = [c.numpy() for c in codes]
    self.save_codes(ids, codes)
def _ray_process(self, tasks):
    tasks = iter(tasks)
    futures = []
    chunks = ichunked(tasks, int(self.chunksize))
    for chunk in chunks:
        chunk = [self._load_task_bin(t) for t in chunk]
        futures.append(self.process_chunk_ray.remote(chunk))
        if len(futures) >= self.num_cpus + 4:
            break
    while futures:
        finished, rest = ray.wait(futures, num_returns=1)
        results = ray.get(finished[0])
        for result in results:
            yield result
        try:
            chunk = next(chunks)
            chunk = [self._load_task_bin(t) for t in chunk]
            rest.append(self.process_chunk_ray.remote(chunk))
        except StopIteration:
            ...
        futures = rest
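# The Ray pipeline above follows a common pattern: prime a fixed-size window of
# in-flight chunks, then top the window up with one new chunk from ichunked()
# each time a future completes. A minimal stdlib-only sketch of that pattern
# (hypothetical process_chunk callable and task iterable, no Ray) might look
# like this; results may arrive out of input order.
import itertools
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from more_itertools import ichunked

def bounded_chunk_map(process_chunk, tasks, chunksize=8, window=4):
    chunks = ichunked(tasks, chunksize)
    with ThreadPoolExecutor(max_workers=window) as pool:
        # Prime a fixed-size window of in-flight chunks.
        futures = {pool.submit(process_chunk, list(c))
                   for c in itertools.islice(chunks, window)}
        while futures:
            done, futures = wait(futures, return_when=FIRST_COMPLETED)
            for fut in done:
                yield from fut.result()
                try:
                    # Top the window back up with the next chunk, if any.
                    futures.add(pool.submit(process_chunk, list(next(chunks))))
                except StopIteration:
                    pass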
def addUnmappedBaseTag(BAM_PATH, NANOPORE_FASTA, BAM_PATH_OUT, IN_DISK):
    global bamFile
    bamFile = pysam.AlignmentFile(BAM_PATH, "rb")
    bamFileOut = pysam.AlignmentFile(BAM_PATH_OUT, "wbu", template=bamFile)
    if not IN_DISK:
        allFasta = FastaContent(NANOPORE_FASTA, False)
        allFastaDict = {}
        for i, x in enumerate(allFasta.iter()):
            allFastaDict[x.name] = x
    else:
        allFastaDict = FastaContent(NANOPORE_FASTA, True)
    allReadGenerate = mlt.chunked(bamFile, 40000)
    splicedReadGenerate = mlt.ichunked(allReadGenerate, 8)
    splicedResult = []
    for singleRawChunk in splicedReadGenerate:
        with ThreadPoolExecutor(max_workers=2) as multiT:
            multiT.submit(outputProcessedRead, bamFileOut, splicedResult)
            splicedResult = multiT.submit(bamProcess, singleRawChunk, allFastaDict).result()
    outputProcessedRead(bamFileOut, splicedResult)
    bamFileOut.close()
    pysam.index(BAM_PATH_OUT)
async def flat(self, ctx, levelname, biome, *blocks: commands.Greedy[Union[int, str]]):
    """Generates a customized flat world.

    Worlds generated by this command are for Bedrock Edition only.
    Specify <blocks> as <block ID:data value> <block count>.
    The data value may be omitted when it is 0.
    See the wiki for biome IDs:
    https://minecraft.gamepedia.com/Biome#Biome_IDs

    Example
        .flat testflat-world plains bedrock:0 1 dirt 2 grass 1
    """
    if len(blocks) % 2 > 0:
        await ctx.send("The block specification is invalid.")
        return
    block_data = []
    for _block, count in ichunked(blocks, 2):
        block = get_block(_block)
        block_data.append({
            "block_name": block.id,
            "block_data": block.data,
            "count": count
        })
    level_data = create_level_data(block_data, biome, levelname)
    mcworld = create_mcworld(level_data)
    blocks_text = "\n".join([
        "Layer {0}: ID: {1[block_name]}, data value: {1[block_data]}, count: {1[count]}"
        .format(i + 1, v) for i, v in enumerate(block_data)
    ])
    await ctx.send(f"World name: {levelname}\nBiome: {biome}\n{blocks_text}",
                   file=discord.File(mcworld, filename="FlatWorld.mcworld"))
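# Illustrative check (hypothetical argument tuple mirroring the command's
# sample usage): ichunked(blocks, 2) walks the variadic arguments as
# (block spec, count) pairs.
from more_itertools import ichunked
blocks = ("bedrock:0", 1, "dirt", 2, "grass", 1)
pairs = [tuple(pair) for pair in ichunked(blocks, 2)]
assert pairs == [("bedrock:0", 1), ("dirt", 2), ("grass", 1)]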
def chunk_upload_tmp(ssh_conn, tmpfile, content, chunksize=1024):
    chunked = more_itertools.ichunked(content, chunksize)
    for c in chunked:
        chunk = ''.join(c)
        cmd = f'echo "{chunk}" >> {tmpfile}'
        logging.info(cmd)
        ssh_conn.exec_command(cmd)
def _process(self, input_pack: DataPack):
    # handle existing entries
    self._process_existing_entries(input_pack)

    batch_size: int = self.configs["infer_batch_size"]
    batches: Iterator[Iterator[Sentence]]
    # Need a copy of the one-pass iterators to support a second loop on
    # them. All other ways around it like using `itertools.tee` and `list`
    # would require extra storage conflicting with the idea of using
    # iterators in the first place. `more_itertools.ichunked` uses
    # `itertools.tee` under the hood but our usage (reading iterators
    # in order) does not cause memory issues.
    batches_copy: Iterator[Iterator[Sentence]]
    if batch_size <= 0:
        batches = iter([input_pack.get(Sentence)])
        batches_copy = iter([input_pack.get(Sentence)])
    else:
        batches = more_itertools.ichunked(input_pack.get(Sentence), batch_size)
        batches_copy = more_itertools.ichunked(input_pack.get(Sentence), batch_size)

    for sentences, sentences_copy in zip(batches, batches_copy):
        inputs: List[Dict[str, str]] = [{
            "sentence": s.text
        } for s in sentences]
        results: Dict[str, List[Dict[str, Any]]] = {
            k: p.predict_batch_json(inputs)
            for k, p in self.predictor.items()
        }
        for i, sent in enumerate(sentences_copy):
            result: Dict[str, List[str]] = {}
            for key in self.predictor:
                if key == "srl":
                    result.update(
                        parse_allennlp_srl_results(
                            results[key][i]["verbs"]))
                else:
                    result.update(results[key][i])
            if "tokenize" in self.configs.processors:
                # creating new tokens and dependencies
                tokens = self._create_tokens(input_pack, sent, result)
                if "depparse" in self.configs.processors:
                    self._create_dependencies(input_pack, tokens, result)
                if "srl" in self.configs.processors:
                    self._create_srl(input_pack, tokens, result)
def get_frames(strand: DNA, orfs: List[str]):
    start_codon_positions = strand.search_for_motif('ATG', base=0)
    for p in start_codon_positions:
        frame = strand.sequence[p:]
        assert frame.startswith('ATG')
        codons_in_frame = set(map("".join, ichunked(frame, 3)))
        # only stop-codons aligned in frame work
        if 'TAG' in codons_in_frame or 'TGA' in codons_in_frame or 'TAA' in codons_in_frame:
            orfs.append(frame)
def generate_auth_token(length: int = 20, chunk_size: int = 5) -> str:
    """
    We use the convenient APIs added in Python 3.6 for generating
    cryptographically strong random numbers suitable for authentication tokens.

    See https://docs.python.org/3/library/secrets.html
    """
    alphabet = string.ascii_uppercase + string.digits
    chars = (secrets.choice(alphabet) for _ in range(length))
    chunks = ("".join(chunk) for chunk in ichunked(chars, chunk_size))
    return "-".join(chunks)
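# Hypothetical usage of the helper above: with the defaults (length=20,
# chunk_size=5) the token is four 5-character groups of A-Z/0-9 joined by
# dashes, e.g. shaped like "K3QZX-7M1PD-AA90B-QQ2RT".
token = generate_auth_token()
assert len(token) == 23  # 20 random characters plus 3 dashes
assert all(len(part) == 5 for part in token.split("-"))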
def fetch(fasta, directory: Path):
    ids = extract_ids(fasta)
    path = filename(directory, ids)
    if path.exists():
        LOGGER.info("Dump file already exists, skipping")
        return path
    with path.open('w') as output:
        chunked = more.ichunked(ids, SIZE)
        for chunk in chunked:
            embl = query_ncbi(chunk)
            shutil.copyfileobj(embl, output)
    validate(path, ids)
    return path
def es_ingest_file(file_path, es_host, es_index_name, batch_size):
    with open(file_path, "rt") as file:
        # Streaming multiple GB's to the ES "/_bulk" endpoint in a single request causes memory issues in ES.
        # Making one request per log line, on the other hand, seems slow.
        file_itor = more_itertools.ichunked(file, batch_size)
        for lines in file_itor:
            es_bulk_command = generate_bulk_index_command(es_index_name, lines)
            headers = {'content-type': 'application/json'}
            response = requests.post("{}/_bulk".format(es_host), data=es_bulk_command, headers=headers)
            if response.status_code == 429:  # Too Many Requests
                logging.error("Response status: 429 - Too Many Requests. Try to lower the batch size or increase Java's available memory.")
            response.raise_for_status()
def translate_to_protein(self) -> Peptide:
    """
    :return: Returns the protein chain encoded by the matrix RNA, using
        1-letter notation. Translation stops at a stop codon.
    """
    # divide into 3-character strings
    codons = map("".join, ichunked(self.sequence, 3))
    # translate into amino acids
    peptide_seq = (GENETIC_CODE[codon] for codon in codons)
    # translation stops when a stop codon is encountered
    peptide_seq = takewhile(lambda amino: amino != 'X', peptide_seq)
    # generator is joined into a string
    peptide_seq = ''.join(peptide_seq)
    protein = Peptide(peptide_seq)
    return protein
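# Illustrative sketch (not part of the original class): ichunked splits a
# sequence into fixed-size sub-iterables, so mapping "".join over it yields
# the 3-letter codons that translate_to_protein looks up in GENETIC_CODE.
from more_itertools import ichunked
codons = list(map("".join, ichunked("AUGGCCUGA", 3)))
assert codons == ["AUG", "GCC", "UGA"]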
def _print(self, testcases, process_id, dryrun, testcases_per_file):
    """
    Used by a single process to print testcases to files.

    Args:
        testcases (iterable): An iterator of testcases.
        process_id (int): The process id.
        dryrun (bool): Whether dryrun mode is enabled.
        testcases_per_file (int, optional): The maximum number of testcases
            that can be printed into a single file.
    """
    chunks = ichunked(testcases, testcases_per_file)
    for i, chunk in enumerate(chunks):
        basename = f'testcase-{self.machine_index}-{process_id}'
        filename = f'tmp-{basename}' if dryrun else f'{basename}-{i}'
        data = [Format.make(self, x) for x in chunk if self.filter(x)]
        with open(join(self.folder_path, filename), 'a') as f:
            f.write(''.join(data))
def crop_and_convert(self) -> None:
    self.crop()
    self.set_pdf_pages()
    paired_page_range = ichunked(self.pdf_pages_lim, self.n_up)
    proc_funcs = []
    for i, pp_islice in enumerate(paired_page_range):
        proc_mthd = self.process_page_pair
        proc_func = partial(proc_mthd, page_pair=tuple(pp_islice), pair_idx=i)
        proc_funcs.append(proc_func)
    if self.multicore:
        n_cores_kwarg = {"n_cores": v for v in [self.cores] if v}
        batch_multiprocess(proc_funcs, show_progress=True, **n_cores_kwarg)
    else:
        sequential_process(proc_funcs, show_progress=True)
    return
def detect_and_align_faces(dataset_folder, destination_folder):
    """"""
    import timing

    mp.set_start_method("spawn")
    # gc.collect()
    dataset_folder = _dataset_generator(dataset_folder, destination_folder)
    # dataset_folder = it.islice(dataset_folder)
    batches = ichunked(dataset_folder, 256)
    _preprocess = partial(_preprocess_pipeline, destination_folder=DESTINATION_FOLDER)
    # _preprocess(dataset_folder)
    for batch in batches:
        # gc.freeze()
        with futures.ProcessPoolExecutor() as executor:
            executor.map(_preprocess, batch)
def parse_game(line: str) -> Tuple[List[game.GameState], int]:
    """
    Parse a Logistello game into a list of game states and a final score.

    The score is an "absolute difference" score from Black's perspective.
    """
    state = game.starting_state()
    states = [state]
    board_str, score_str, _ = line.split()
    for [player_str, *move] in more_itertools.ichunked(board_str[:-1], 3):
        player = game.Player.BLACK if player_str == "+" else game.Player.WHITE
        row, col = game.parse_move(move)  # type: ignore
        if player != state.active_player:
            state = state.apply_pass()
            states.append(state)
        state = state.apply_move(row, col)
        states.append(state)
    return states, int(score_str)
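# Small sketch of the chunking used above, with a made-up Logistello-style
# board string: each move is a 3-character group (player sign plus a
# two-character square), so ichunked(..., 3) walks the game move by move.
import more_itertools
moves = [(p, "".join(mv)) for p, *mv in more_itertools.ichunked("+f5-d6+c3", 3)]
assert moves == [("+", "f5"), ("-", "d6"), ("+", "c3")]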
def minibatch_generator(
    self,
    series: DataSeries,
    batch_size: int,
    metadata: Dict[str, Any],
    should_shuffle: bool,
    drop_incomplete_batches: bool = False
) -> Generator[DefaultDict[str, List[Any]], None, None]:
    """
    Generates minibatches for the given dataset. Each minibatch is expressed as
    a feed dict with string keys. These keys must be translated to placeholder
    tensors before passing the dictionary as an input to Tensorflow.

    Args:
        series: The series to generate batches for.
        batch_size: The minibatch size.
        metadata: Metadata used during tensorization.
        should_shuffle: Whether the data samples should be shuffled.
        drop_incomplete_batches: Whether incomplete batches should be omitted.
            This usually applies exclusively to the final minibatch.
    Returns:
        A generator of feed dicts, each one representing an entire minibatch.
    """
    data_series = self.dataset[series]

    # Load dataset if needed
    if not data_series.is_loaded:
        data_series.load()

    # Create iterator over the data
    data_iterator = data_series.iterate(should_shuffle=should_shuffle, batch_size=batch_size)

    # Set training flag
    is_train = series == DataSeries.TRAIN

    # Generate minibatches
    for minibatch in ichunked(data_iterator, batch_size):
        # Turn minibatch into a feed dict
        feed_dict: DefaultDict[str, List[Any]] = defaultdict(list)
        num_samples = 0
        for sample in minibatch:
            tensorized_sample = self.tensorize(sample, metadata, is_train=is_train)

            # Ensure that there are no NoneType or NaN values in the tensorized sample
            should_include = True
            for tensor in tensorized_sample.values():
                if isinstance(tensor, list) or isinstance(tensor, np.ndarray):
                    tensor_array = np.array(tensor)
                    if np.any(np.isnan(tensor_array)) or np.any(tensor_array == None):
                        should_include = False
                else:
                    if tensor is None or np.isnan(tensor):
                        should_include = False

                if not should_include:
                    break

            # Only include validated samples
            if should_include:
                for key, tensor in tensorized_sample.items():
                    feed_dict[key].append(tensor)
                num_samples += 1

        if drop_incomplete_batches and num_samples < batch_size:
            continue

        yield feed_dict
from itertools import count

from more_itertools import ichunked

all_chunks = ichunked(count(), 4)
print(all_chunks)

c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks)
print(c_1)
print(*c_1)
print(*c_2)

c1, c2 = [next(all_chunks) for _ in range(0, 2)]
print(list(c1))
print(*c2)
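# Expected output of the demo above (object reprs abbreviated and likely
# version-dependent). Note that ichunked buffers chunks that are skipped past,
# so c_1 and c_2 still yield their original items even though c_3 was already
# requested:
#
#   <generator object ichunked at 0x...>
#   <itertools.islice object at 0x...>   (an _IChunk wrapper in newer versions)
#   0 1 2 3
#   4 5 6 7
#   [12, 13, 14, 15]
#   16 17 18 19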
def _depth_first_search(_update, db, branches, parallel, infos, info_init,
                        verbose, mpi_rank, mpi_size, **kwargs):

    # Define parallel core
    def _parallel_core(branches, db=None):
        # Initialize db
        if db is None:
            db = kwargs['db_init']()

        # Convert to list
        branches = list(branches)

        # Initialize infos
        infos = info_init()

        # Explore all branches
        while branches:
            # Get new branches
            (_new_ps, _new_ph), _new_branches = _update(*branches.pop())

            # Collect results
            kwargs['collect'](db, kwargs['transform'](_new_ps), _new_ph)

            # Update branches
            branches.extend(_new_branches)

            # Update infos
            infos['largest_n_branches_in_memory'] = max(
                len(branches), infos['largest_n_branches_in_memory'])

            # Update infos
            infos['n_explored_branches'] += 1

        return db, infos

    # If no parallelization is required, explore branches one by one
    if parallel == 1:
        from more_itertools import ichunked

        # Get number of chunks
        chunk_size = max(1, len(branches) // 100)

        for _bs in tqdm(ichunked(branches, chunk_size),
                        total=len(branches) // chunk_size,
                        desc=f'Mem={virtual_memory().percent}%',
                        disable=not verbose):
            # Update database and infos
            db, _infos = _parallel_core(_bs, db)

            # Update infos
            infos['n_explored_branches'] += _infos['n_explored_branches']
            infos['largest_n_branches_in_memory'] = max(
                _infos['largest_n_branches_in_memory'],
                infos['largest_n_branches_in_memory'])

    # Otherwise, distribute workload among different cores
    else:
        with globalize(_parallel_core) as _parallel_core, Pool(parallel) as pool:
            # Apply async
            _fps = [
                pool.apply_async(_parallel_core, (_branches,))
                for _branches in distribute(kwargs['n_chunks'], branches)
            ]
            _status = [False] * len(_fps)

            with tqdm(total=len(_fps),
                      desc=f'Mem={virtual_memory().percent}%',
                      disable=not verbose) as pbar:

                _pending = len(_fps)
                while _pending:
                    # Wait
                    sleep(kwargs['sleep_time'])

                    # Activate/deactivate
                    if verbose:
                        pbar.disable = int(time()) % mpi_size != mpi_rank

                    # Get virtual memory
                    _vm = virtual_memory()

                    _pending = 0
                    for _x, (_p, _s) in enumerate(zip(_fps, _status)):
                        if not _p.ready():
                            _pending += 1
                        elif not _s:
                            # Collect data
                            _new_db, _infos = _p.get()

                            # Merge datasets
                            kwargs['merge'](db, _new_db)

                            # Clear dataset
                            _new_db.clear()

                            # Update infos
                            infos['n_explored_branches'] += _infos[
                                'n_explored_branches']
                            infos['largest_n_branches_in_memory'] = max(
                                _infos['largest_n_branches_in_memory'],
                                infos['largest_n_branches_in_memory'])

                            # Set status
                            _status[_x] = True

                    # Update pbar
                    if verbose:
                        pbar.set_description(
                            (f'[{mpi_rank}] ' if mpi_size > 1 else '') + \
                            f'Mem={_vm.percent}%, ' + \
                            f'NThreads={infos["n_threads"]}, ' + \
                            f'NCPUs={infos["n_cpus"]}, ' + \
                            f'LoadAvg={getloadavg()[0]/infos["n_cpus"]*100:1.2f}%, ' + \
                            f'NBranches={infos["n_explored_branches"]}'
                        )
                        pbar.n = len(_fps) - _pending
                        pbar.refresh()

                    # Update infos
                    infos['average_virtual_memory (GB)'] = (
                        infos['average_virtual_memory (GB)'][0] + _vm.used / 2**30,
                        infos['average_virtual_memory (GB)'][1] + 1)
                    infos['peak_virtual_memory (GB)'] = max(
                        infos['peak_virtual_memory (GB)'], _vm.used / 2**30)

                    # If memory above threshold, raise error
                    if _vm.percent > kwargs['max_virtual_memory']:
                        raise MemoryError(
                            f'Memory above threshold: {_vm.percent}% > {kwargs["max_virtual_memory"]}%'
                        )

                # Last refresh
                if verbose:
                    pbar.refresh()

            # Check all chunks have been explored
            assert (np.alltrue(_status))
def gen():
    for batch in ichunked(iter, batch_size):
        batch_df = prev_transformer.transform_iter(batch)
        for row in batch_df.itertuples(index=False):
            yield row._asdict()
def split(
    skip: Optional[int] = typer.Option(
        default=None,
        min=0,
        help="Number of games to skip from FILENAME before producing output.",
    ),
    games: Optional[int] = typer.Option(
        default=None,
        min=0,
        help="Number of games to extract from FILENAME."),
    chunked: int = typer.Option(
        default=1,
        min=1,
        help="Number of games exported to each file."),
    filename: Path = typer.Argument(
        ...,
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        help="Input file to be split into separate PGN files.",
    ),
    output_folder: Path = typer.Argument(
        ...,
        exists=True,
        file_okay=False,
        dir_okay=True,
        writable=True,
        help="Output folder for PGN files.",
    ),
    dry_run: bool = typer.Option(
        default=False,
        help="Dry run the operation, thereby not producing any files."),
):
    """Streaming PGN splitter.

    Reads FILENAME and splits it into separate PGN files output to
    OUTPUT_FOLDER. Output files are named 0.pgn, 1.pgn, ..., n.pgn.

    By default all games are exported, but the number of games being exported
    can be limited using --games, and the scanning can be offset using --skip.
    """
    # Find byte offsets for splits between PGNs
    seperators: Iterator[int] = find_byte_boundaries(filename)

    # Offset by skipping 'skip' PGNs
    if skip:
        consume(seperators, skip)

    if games is None:
        seperators = _tqdm_boundaries(filename, seperators)
    else:
        seperators = islice(seperators, games + 1)

    blocks = split_file(filename, seperators)

    if games:
        blocks = tqdm(blocks, total=games, unit="games")

    chunks = ichunked(blocks, chunked)
    for i, chunk in enumerate(chunks):
        if dry_run:
            continue
        with open(f"{output_folder}/{str(i)}.pgn", "wb") as pgn:
            for block in chunk:
                pgn.write(block)
def run(self):
    count = 0  # count number of test cases sent

    def testcases(reports: List[str]) -> Generator[CaseEventType, None, None]:
        exceptions = []
        for report in reports:
            try:
                yield from self.parse_func(report)
            except Exception as e:
                exceptions.append(Exception(
                    "Failed to process a report file: {}".format(report), e))

        if len(exceptions) > 0:
            # defer XML parsing exceptions so that we can send what we can send before we bail out
            raise Exception(exceptions)

    # generator that creates the payload incrementally
    def payload(cases: Generator[TestCase, None, None]) -> Tuple[Dict[str, List], List[Exception]]:
        nonlocal count
        cs = []
        exs = []
        while True:
            try:
                cs.append(next(cases))
            except StopIteration:
                break
            except Exception as ex:
                exs.append(ex)
        count += len(cs)
        return {"events": cs}, exs

    def send(payload: Dict[str, List]) -> None:
        res = client.request("post", "{}/events".format(session_id), payload=payload, compress=True)

        if res.status_code == HTTPStatus.NOT_FOUND:
            if session:
                build, _ = parse_session(session)
                click.echo(click.style(
                    "Session {} was not found. Make sure to run `launchable record session --build {}` before `launchable record tests`"
                    .format(session, build), 'yellow'), err=True)
            elif build_name:
                click.echo(click.style(
                    "Build {} was not found. Make sure to run `launchable record build --name {}` before `launchable record tests`"
                    .format(build_name, build_name), 'yellow'), err=True)

        res.raise_for_status()

    def recorded_result() -> Tuple[int, int, int, float]:
        test_count = 0
        success_count = 0
        fail_count = 0
        duration = float(0)

        for tc in testcases(self.reports):
            test_count += 1
            status = tc.get("status")
            if status == 0:
                fail_count += 1
            elif status == 1:
                success_count += 1
            duration += float(tc.get("duration") or 0)  # sec

        return test_count, success_count, fail_count, duration / 60  # sec to min

    try:
        tc = testcases(self.reports)

        if report_paths:
            # diagnostics mode to just report test paths
            for t in tc:
                print(unparse_test_path(t['testPath']))
            return

        exceptions = []
        for chunk in ichunked(tc, post_chunk):
            p, es = payload(chunk)
            send(p)
            exceptions.extend(es)

        res = client.request("patch", "{}/close".format(session_id))
        res.raise_for_status()

        if len(exceptions) > 0:
            raise Exception(exceptions)
    except Exception as e:
        if os.getenv(REPORT_ERROR_KEY):
            raise e
        else:
            traceback.print_exc()
            return

    if count == 0:
        if len(self.skipped_reports) != 0:
            click.echo(click.style(
                "{} test reports were skipped because they were created before `launchable record build` was run.\n"
                "Make sure to run tests after running `launchable record build`."
                .format(len(self.skipped_reports)), 'yellow'))
            return
        else:
            click.echo(click.style(
                "Looks like tests didn't run? \n"
                "If not, make sure the right files/directories are passed",
                'yellow'))
            return

    file_count = len(self.reports)
    test_count, success_count, fail_count, duration = recorded_result()

    click.echo(
        "Launchable recorded tests for build {} (test session {}) to workspace {}/{} from {} files:\n"
        .format(build_name, test_session_id, org, workspace, file_count))

    header = ["Files found", "Tests found", "Tests passed", "Tests failed",
              "Total duration (min)"]
    rows = [[file_count, test_count, success_count, fail_count,
             "{:0.4f}".format(duration)]]
    click.echo(tabulate(rows, header, tablefmt="github"))

    click.echo(
        "\nVisit https://app.launchableinc.com/organizations/{organization}/workspaces/{workspace}/test-sessions/{test_session_id} to view uploaded test results (or run `launchable inspect tests --test-session-id {test_session_id}`)"
        .format(
            organization=org,
            workspace=workspace,
            test_session_id=test_session_id,
        ))
def part2(text):
    return total(ichunked(chain.from_iterable(zip(*process(text))), 3))
def splitter(iterable: Generator, size: int) -> Iterator[Iterator]:
    return ichunked(iterable, size)
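# Quick sanity check for the wrapper above: the final sub-iterator is simply
# shorter when the input length is not a multiple of `size`.
chunks = [list(c) for c in splitter((i for i in range(7)), 3)]
assert chunks == [[0, 1, 2], [3, 4, 5], [6]]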
def index(self, generator):
    from ance.utils.util import pad_input_ids
    import torch
    import more_itertools
    import pyterrier as pt
    from ance.drivers.run_ann_data_gen import StreamInferenceDoc, load_model, GetProcessingFn
    import ance.drivers.run_ann_data_gen
    import pickle
    import os

    # monkey patch ANCE to use the same TQDM as PyTerrier
    ance.drivers.run_ann_data_gen.tqdm = pt.tqdm

    os.makedirs(self.index_path)

    config, tokenizer, model = _load_model(self.args, self.checkpoint_path)

    docid2docno = []

    def gen_tokenize():
        text_attr = self.text_attr
        kwargs = {}
        if self.num_docs is not None:
            kwargs['total'] = self.num_docs
        for doc in pt.tqdm(generator, desc="Indexing", unit="d", **kwargs) if self.verbose else generator:
            contents = doc[text_attr]
            docid2docno.append(doc["docno"])
            passage = tokenizer.encode(
                contents,
                add_special_tokens=True,
                max_length=self.args.max_seq_length,
            )
            passage_len = min(len(passage), self.args.max_seq_length)
            input_id_b = pad_input_ids(passage, self.args.max_seq_length)
            yield passage_len, input_id_b

    segment = -1
    shard_size = []
    for gengen in more_itertools.ichunked(gen_tokenize(), self.segment_size):
        segment += 1
        print("Segment %d" % segment)

        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            self.args, model, GetProcessingFn(self.args, query=False),
            "passages", gengen, is_query_inference=False)

        dim = passage_embedding.shape[1]
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)

        faiss_file = os.path.join(self.index_path, str(segment) + ".faiss")
        lookup_file = os.path.join(self.index_path, str(segment) + ".docids.pkl")

        faiss.write_index(cpu_index, faiss_file)
        cpu_index = None
        passage_embedding = None

        with pt.io.autoopen(lookup_file, 'wb') as f:
            pickle.dump(passage_embedding2id, f)
        shard_size.append(len(passage_embedding2id))
        passage_embedding2id = None

    with pt.io.autoopen(os.path.join(self.index_path, "shards.pkl"), 'wb') as f:
        pickle.dump(shard_size, f)
        pickle.dump(docid2docno, f)
    return self.index_path