def slice_before(iterable, predicate):
    """Split *iterable* into consecutive chunks produced by ``_take_until``.

    Each chunk is yielded to the caller; whatever the caller leaves unread
    in a chunk is drained before the next chunk is started, so chunks never
    overlap. Iteration ends when the underlying data is exhausted.
    """
    remaining = more_itertools.peekable(iterable)
    # A peekable is truthy while it still has at least one item.
    while bool(remaining):
        chunk = _take_until(remaining, predicate)
        yield chunk
        # Skip any part of the chunk the consumer did not read.
        more_itertools.consume(chunk)
def sweep_entire_train(tensor_train: TensorTrain, inputs: Sequence[Input], **kwargs) -> Iterator[None]:
    """Yield once per full sweep over the whole train.

    Wraps ``sweep(tensor_train, inputs, **kwargs)`` and, before each yield,
    advances it by ``len(tensor_train) - 1`` steps. This is an infinite
    generator; the caller decides when to stop.
    """
    step_iterator = sweep(tensor_train, inputs, **kwargs)
    while True:
        # Length is re-read on every pass, exactly as before.
        steps_in_one_sweep = len(tensor_train) - 1
        consume(step_iterator, steps_in_one_sweep)
        yield
def _flush_body(self):
    """Exhaust ``self.body`` without keeping any of its items.

    Running a generator body to completion lets its finalization code
    execute, such as is required by caching.tee_output().
    """
    for _unused in self.body:
        pass
def routh_row(i_n_minus_2: "Iterable[Basic]",
              i_n_minus_1: "Iterable[Basic]") -> "Iterable[Basic]":
    """Compute the next row of a Routh array from the two rows above it.

    With ``pp = [a0, a2, a4, ...]`` (row n-2) and ``p = [a1, a3, a5, ...]``
    (row n-1), entry k (k >= 1) of the new row is::

        (a1 * pp[k] - a0 * p[k]) / a1

    Every entry pivots on the *leading* elements ``a0`` and ``a1``. When
    ``p`` is one item shorter than ``pp``, the missing ``p[k]`` is an
    implicit zero and the entry reduces to ``pp[k]``. The resulting row has
    ``len(pp) - 1`` entries.

    This is a generator. The size check runs when iteration starts.

    :raises ValueError: if ``p`` is empty, or ``pp`` is neither the same
        length as ``p`` nor exactly one item longer.
    """
    pp = list(i_n_minus_2)
    p = list(i_n_minus_1)
    if not 0 <= len(pp) - len(p) <= 1 or len(p) < 1:
        raise ValueError("pp row should be at most one item "
                         "larger than p row and at least equal in size")
    a0, a1 = pp[0], p[0]
    for k in range(1, len(pp)):
        if k < len(p):
            yield (a1 * pp[k] - a0 * p[k]) / a1
        else:
            # p[k] is an implicit zero: (a1 * pp[k] - a0 * 0) / a1 == pp[k].
            yield pp[k]
def unshare(ax: Axes, axis: tp.Literal['x', 'y', 'xy'] = 'xy',
            locator: tp.Union[tp.Type[Locator], Locator] = AutoLocator,
            formatter: tp.Union[tp.Type[Formatter], Formatter] = NullFormatter):
    """Remove *ax* from its shared x and/or y groups and reset its ticking.

    For each letter in *axis* ('x', 'y' or both):

    * If the axis' current major locator has its ``axis`` attribute pointing
      at this axis, it is re-pointed at the same axis of the first other
      sibling in the share group. Presumably the locator object is shared
      with the siblings, so they keep a valid reference.
    * *ax* is removed from the share group.
    * The axis gets a fresh ``Ticker`` with *locator* and *formatter*. Each
      is instantiated first if a class rather than an instance is given.
    """
    for letter in axis:
        share_group = getattr(ax, f'get_shared_{letter}_axes')()
        own_axis = getattr(ax, f'{letter}axis')
        current_locator = own_axis.major.locator
        if current_locator.axis is own_axis:
            for sibling in share_group.get_siblings(ax):
                if sibling is ax:
                    continue
                current_locator.axis = getattr(sibling, f'{letter}axis')
                break
        share_group.remove(ax)
        own_axis.major = Ticker()
        new_locator = locator() if isinstance(locator, type) else locator
        own_axis.set_major_locator(new_locator)
        new_formatter = formatter() if isinstance(formatter, type) else formatter
        own_axis.set_major_formatter(new_formatter)
def _flush_body(self):
    """Exhaust ``self.body`` without keeping any of its items.

    Running a generator body to completion lets its finalization code
    execute, such as is required by caching.tee_output().
    """
    for _unused in self.body:
        pass
def __call__(self, *args, **kwargs):
    """Queue each positional argument, then start one Process per queued item.

    Each item is passed as the single argument to ``self.wrap`` in a new
    ``Process``. ``kwargs`` is accepted but not used.

    NOTE(review): ``self.sema`` is held only while the processes are being
    started. They are never joined here, so it does not bound how many run
    at once. Also, if ``self.q`` is a ``multiprocessing.Queue``, ``qsize()``
    is approximate and unimplemented on macOS. Confirm both are intended.
    """
    for item in args:
        self.q.put(item)
    with self.sema:
        while self.q.qsize() > 0:
            queued_item = self.q.get()
            worker = Process(target=self.wrap, args=(queued_item,))
            worker.start()
def search_pythonlib(self):
    """Run the per-directory and per-file helpers on each search root.

    For every directory in ``self.dpaths``, only the first (top-level) entry
    produced by :func:`os.walk` is used:

    * its sub-directory names go to ``self._coroutine_print_dpylib``;
    * its file names go to ``self._coroutine_print_pylib``.

    Both returned iterators are then fully drained, directories first.
    Presumably these helpers print the Python-library entries they find.

    A directory that ``os.walk`` cannot list (missing or unreadable) yields
    nothing from the walk. It is skipped rather than letting a bare
    ``StopIteration`` escape to the caller.
    """
    for dpath in self.dpaths:
        top_level = next(os.walk(dpath), None)
        if top_level is None:
            # os.walk swallows listing errors by default and yields nothing.
            continue
        root, dnms, fnms = top_level
        t_diter = self._coroutine_print_dpylib(root, dnms)
        t_fiter = self._coroutine_print_pylib(root, fnms)
        # Drain both iterators, directory results before file results.
        for _ in t_diter:
            pass
        for _ in t_fiter:
            pass
def benchmark(constructor, name, *descriptors, n=1000):
    """Time dataset construction, then time reading its first ``n`` items.

    Both phases run under ``_benchmark`` with labels ending in
    "construction" and "iteration". A separator line of 80 '=' characters
    is printed afterwards.
    """
    construction_labels = (*descriptors, "construction")
    with _benchmark(name, *construction_labels):
        dataset = constructor()
    iteration_labels = (*descriptors, "iteration")
    with _benchmark(name, *iteration_labels):
        more_itertools.consume(iter(dataset), n=n)
    separator = "=" * 80
    print(separator)
def run(cls, args):
    """Load every swarm (timed), report the count, then print the matching ones.

    The matches selected by ``args.filter`` are printed one per line in
    sorted order.
    """
    with timing.Stopwatch() as watch:
        all_swarms = _get_swarms(args)
    tmpl = "Loaded {n_swarms} swarms in {watch.elapsed}"
    print(tmpl.format(n_swarms=len(all_swarms), watch=watch))
    for swarm in sorted(args.filter.matches(all_swarms)):
        print(swarm)
def whole_return(self):
    '''
    Call ``self.func`` once and return its result unchanged.

    Use this instead of ``__next__`` when the result should not be turned
    into an iterator. ``self.iterator`` is drained after the call, the same
    as ``__next__`` does when it refills.
    '''
    result = self.func(*self.pass_args, **self.pass_kargs)
    consume(self.iterator)
    return result
def test_class_instantiation():
    """Every agent type can be built and can roll out one full trajectory."""
    env = ToyLab()
    ppo_agent = PPOAgent(env=env)
    her_sac_agent = HERSACAgent(env=env)
    goal_gan_agent = GoalGANAgent(env=env, agent=ppo_agent)
    for agent in (ppo_agent, her_sac_agent, goal_gan_agent):
        consume(trajectory(pi=agent, env=env))
    # Building the callback must not raise; the instance itself is not used.
    EvaluateCallback(agent=ppo_agent, eval_env=env)
def teardown_class(cls):
    """Delete session files left in ``localDir`` after the base-class teardown."""
    super(cls, cls).teardown_class()
    for entry in localDir.listdir():
        if not entry.basename().startswith(sessions.FileSession.SESSION_PREFIX):
            continue
        entry.remove_p()
def teardown_class(cls):
    """Delete session files left in ``localDir`` after the base-class teardown."""
    super(cls, cls).teardown_class()
    for entry in localDir.listdir():
        if not entry.basename().startswith(sessions.FileSession.SESSION_PREFIX):
            continue
        entry.remove_p()
def main():
    """Back up every documentation version except the last one.

    For each name in ``DIRECTORIES``, the sub-directories of ``docs/<name>``
    (next to this script) are treated as versions. All but the last in
    sorted order are passed to ``backup``.
    """
    docs_dir = Path(__file__).parent / 'docs'
    for directory_name in DIRECTORIES:
        # sorted() already returns a list; no extra list() needed.
        versions = sorted(
            entry for entry in (docs_dir / directory_name).iterdir()
            if entry.is_dir()
        )
        # NOTE(review): Path ordering is lexicographic, so e.g. "10.0" sorts
        # before "9.0". Confirm the version names sort in release order.
        for version in versions[:-1]:
            backup(version)
def build_projects(self, projects: list[str], project_type: str, bundles: list[str]) -> None:
    """Build every project's EPUB in a process pool, showing progress.

    Each project becomes one ``build_epub`` call with positional arguments
    ``(settings, project, EPUBType(project_type), bundles)`` and no keyword
    arguments.
    """
    args_collection: list[tuple[Settings, str, EPUBType, list[str]]] = [
        (self.settings, project, EPUBType(project_type), bundles)
        for project in projects
    ]
    kwargs_collection: list[dict[str, Any]] = [{} for _ in projects]
    results = pool_run(build_epub, args_collection, kwargs_collection,
                       "process", show_progress=True)
    consume(results)
def continuation_lines(lines):
    """Yield the run of continuation-prompt lines at the front of *lines*.

    *lines* is a peekable iterator of ``(line_number, line)`` pairs. Items
    are taken while their left-stripped text matches
    ``continuation_prompt_re``. The first non-matching item is left in
    *lines* for the caller.

    Exhausted input now ends iteration directly. Before, the prompt regex was
    matched against an empty placeholder line, which would loop forever
    yielding ``(-1, "")`` if the pattern can match an empty string.
    """
    while True:
        try:
            line_number, line = lines.peek()
        except StopIteration:
            return
        if not continuation_prompt_re.match(line.lstrip()):
            return
        # Only now actually take the item we peeked at.
        next(lines)
        yield line_number, line
def process_forever(self, timeout=0.2):
    """Process connection data in an endless loop.

    Calls ``process_once(timeout=timeout)`` over and over. This method
    never returns normally.

    Arguments:

        timeout -- Parameter to pass to process_once.
    """
    # Deliberately *not* mutex-locked. Holding a lock here would stop every
    # other thread from ever changing this Reactor's shared state.
    log.debug("process_forever(timeout=%s)", timeout)
    process_once = self.process_once
    while True:
        process_once(timeout=timeout)
class Meta(ABC):
    """Mixin whose marked attributes can be passed as constructor keywords.

    Subclasses mark attributes in their class annotations with
    :meth:`meta_attribute`. Keyword arguments naming such attributes are
    removed before the remaining arguments reach the next ``__new__`` in
    the MRO. They are then set as attributes on the new instance.
    """

    # Unique marker stored in ``Annotated`` metadata to flag meta attributes.
    _meta_annotation = object()

    @classmethod
    def meta_attribute(cls, obj: "_T") -> "_T":
        """Return ``Annotated[obj, <meta marker>]`` for use as an annotation."""
        return Annotated[obj, cls._meta_annotation]

    # Names of the meta attributes; computed for each subclass in
    # __init_subclass__.
    _meta_attributes: set[str]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        hints = get_type_hints(
            cls, include_extras=True, localns={cls.__name__: cls})
        cls._meta_attributes = {
            key for key, val in hints.items()
            if get_origin(val) is Annotated
            # Annotated flattens nesting (Annotated[Annotated[T, x], m] is
            # Annotated[T, x, m]), so look through all metadata, not just
            # the first entry.
            and any(meta is cls._meta_annotation for meta in get_args(val)[1:])
        }

    def __new__(cls, *args, **kwargs):
        meta_kwargs = {
            key: kwargs.pop(key)
            for key in cls._meta_attributes if key in kwargs
        }
        self = super().__new__(cls, *args, **kwargs)
        return self._meta_update(self, **meta_kwargs)

    @classmethod
    def _meta_update(cls, other: "Meta", /, **kwargs):
        """Set each meta attribute present in ``kwargs`` on ``other``.

        Nothing is set unless ``other`` is an instance of ``cls``.
        ``other`` is always returned.
        """
        if isinstance(other, cls):
            for key in other._meta_attributes:
                if key in kwargs:
                    setattr(other, key, kwargs[key])
        return other
def resolve_source_from_match(cls, normalized_parent_source, pattern_match, child_symbol_pattern):
    """Cut a child symbol's source out of its parent's source.

    The parent source is read from ``pattern_match.start()`` onwards and
    split into lines with ``pavo_cristatus_split``. Lines are collected,
    each followed by a newline, until the first line whose indentation
    level is below the child symbol's indentation within the parent.

    :returns: the collected lines joined into one string ("" if none).
    """
    it = iter(normalized_parent_source)
    # Skip everything before the match.
    more_itertools.consume(it, pattern_match.start())
    modified_parent_source = "".join(it)
    lines = pavo_cristatus_split(modified_parent_source)
    expected_indent_level = cls.find_symbol_indent_in_parent_symbol_source(
        normalized_parent_source, child_symbol_pattern)
    # Collect pieces and join once instead of growing a string in the loop.
    resolved_lines = []
    for line in lines:
        if cls.get_indentation_level(line) < expected_indent_level:
            break
        resolved_lines.append(line + "\n")
    return "".join(resolved_lines)
def count_neighbor_newlines(cls, lines: List[str], first: ast.AST, second: ast.AST) -> int:
    """
    Count the logical newlines separating two nodes.

    A node may span several physical lines, so a plain ``lineno`` difference
    is not meaningful. Instead, NEWLINE tokens that occur after ``first``'s
    tokens and up to ``second`` are subtracted from the difference of the
    two nodes' first line numbers.

    :return: number of logical newlines (0 if second node is placed right after first)
    """
    token_stream = cls._tokens_peekable_iter(lines)
    # Skip the tokens up to (and presumably including) the first node.
    mitertools.consume(cls._take_until_node(token_stream, first))
    lineno_gap = cls._get_first_lineno(second) - cls._get_first_lineno(first)
    newline_tokens = 0
    for token in cls._take_until_node(token_stream, second):
        if token.type == tokenize.NEWLINE:
            newline_tokens += 1
    return lineno_gap - newline_tokens
def main() -> None:
    """Command-line entry point: stream lidar scans into ``viewer_3d``.

    Scans come either from a live sensor (``--sensor HOST``) or from a pcap
    file (``--pcap PATH``). For a pcap, metadata is read from ``--meta`` or
    from the pcap path with a ``.json`` extension. ``--start N`` skips the
    first N scans. ``--pause`` starts the viewer paused.

    The scan source is closed in every case, including when skipping
    frames fails or is interrupted.
    """
    import argparse
    import os

    import ouster.pcap as pcap

    descr = (
        "Example visualizer using the open3d library. Visualize either pcap "
        "data (specified using --pcap) or a running sensor (specified using "
        "--sensor). If no metadata file is specified, this will look for a "
        "file with the same name as the pcap with the '.json' extension, or "
        "query it directly from the sensor. Visualizing a running sensor "
        "requires the sensor to be configured and sending lidar data to the "
        "default UDP port (7502) on the host machine. ")

    parser = argparse.ArgumentParser(description=descr)
    parser.add_argument('--pause', action='store_true', help='start paused')
    parser.add_argument('--start', type=int, help='skip to frame number')
    parser.add_argument('--meta', metavar='PATH', help='path to metadata json')

    required = parser.add_argument_group('one of the following is required')
    group = required.add_mutually_exclusive_group(required=True)
    group.add_argument('--sensor', metavar='HOST', help='sensor hostname')
    group.add_argument('--pcap', metavar='PATH', help='path to pcap file')

    args = parser.parse_args()

    if args.sensor:
        scans = client.Scans.stream(args.sensor, metadata=args.meta)
    elif args.pcap:
        pcap_path = args.pcap
        metadata_path = args.meta or os.path.splitext(pcap_path)[0] + ".json"
        with open(metadata_path, 'r') as f:
            metadata = client.SensorInfo(f.read())
        source = pcap.Pcap(pcap_path, metadata)
        scans = client.Scans(source)

    try:
        # Skip frames inside the try so scans.close() also runs if skipping
        # raises or is interrupted with Ctrl-C.
        consume(scans, args.start or 0)
        viewer_3d(scans, paused=args.pause)
    except (KeyboardInterrupt, StopIteration):
        pass
    finally:
        scans.close()
def continuation_lines(lines, indent, prompt_length):
    """Yield leading continuation-prompt lines whose indentation matches *indent*.

    *lines* is a peekable iterator of ``(line_number, line)`` pairs. An item
    is taken while ``continuation_prompt_re`` matches its text and
    ``len(<indent group>) - prompt_length + 5 == indent``. The ``+ 5`` offset
    is carried over unchanged; its meaning is not documented here. The
    first item that fails the test stays in *lines*.

    Exhausted input now ends iteration directly. Before, the regex was
    matched against an empty placeholder line, which could loop forever
    yielding ``(-1, "")`` if the pattern accepts an empty string.
    """
    while True:
        try:
            line_number, line = lines.peek()
        except StopIteration:
            return
        match = continuation_prompt_re.match(line)
        if not match or len(
                match.groupdict()["indent"]) - prompt_length + 5 != indent:
            return
        # Only now actually take the item we peeked at.
        next(lines)
        yield line_number, line
def test_stratified(n: int, replace: bool):
    """``stratified_sample.sample`` returns *n* items quickly, or rejects n < 0.

    A one-second SIGALRM guards against the sampler hanging. The alarm is
    cancelled and the previous SIGALRM handler restored in a ``finally``.
    Otherwise a failing assertion would leave the alarm armed and could fire
    inside a later test.
    """
    with open(fn, 'rb') as fp:
        # The function should finish quickly.
        def abort(*args):
            raise AssertionError('The function shouldn\'t take this long.')

        previous_handler = signal.signal(signal.SIGALRM, abort)
        signal.alarm(1)
        try:
            # We should get the right number of results.
            if n >= 0:
                observed = ilen(stratified_sample.sample(
                    n, fp, replace=replace, give_up_at=100))
                assert observed == n
            else:
                with pytest.raises(ValueError):
                    consume(stratified_sample.sample(n, fp))
        finally:
            signal.alarm(0)
            # signal.signal() returns None when the old handler was not
            # installed from Python; None cannot be re-installed.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)
def __next__(self):
    """Return the next item, refilling from ``self.func`` when the current run ends.

    Items come from ``self.post_iterator``. When it is exhausted:

    * ``self.func(*self.pass_args, **self.pass_kargs)`` is called and
      ``self.iterator`` is drained. If either raises ``self.empty_error``,
      iteration stops with ``StopIteration``.
    * If the result is iterable, it becomes the new ``post_iterator`` and
      its first item is returned. An empty iterable result therefore ends
      iteration, because ``next`` raises ``StopIteration``.
    * If the result is not iterable, it is returned as-is. ``post_iterator``
      is left unchanged (still exhausted), so the following call refills
      again.
    """
    try:
        to_return = next(self.post_iterator)
    except StopIteration:
        try:
            from_func = self.func(*self.pass_args, **self.pass_kargs)
            # Drain self.iterator; presumably this keeps it in step with
            # func's input. TODO confirm.
            consume(self.iterator)
        except self.empty_error:
            raise StopIteration
        try:
            self.post_iterator = iter(from_func)
            to_return = next(self.post_iterator)
        except TypeError:
            # create iterator from non iterator
            # NOTE(review): this also catches a TypeError raised by next()
            # inside an iterable result, which would then be returned whole.
            to_return = from_func
    return to_return
def main(task: str, use_gan: bool, do_train: bool, perform_eval: bool):
    """Build a Panda environment and agent, then either train or roll out forever.

    With ``use_gan`` the HER-SAC agent gets experiment name "goalgan-her-sac"
    and is wrapped in a ``GoalGANAgent``. With ``do_train`` it trains for
    50000 timesteps, plus an evaluation callback on a fresh environment when
    ``perform_eval`` is set. Otherwise the environment is visualized and
    trajectories are rolled out in an endless loop.
    """
    env_params = {"visualize": not do_train}
    if task == Task.REACH:
        env_fn = PandaEnv
    else:
        env_fn = PandaPickAndPlace
    env = env_fn(**env_params)

    agent_params = {"env": env}
    if use_gan:
        agent_params["experiment_name"] = "goalgan-her-sac"
        agent = GoalGANAgent(env=env, agent=HERSACAgent(**agent_params))
    else:
        agent = HERSACAgent(**agent_params)

    if not do_train:
        while True:
            consume(trajectory(agent, env))

    callbacks = []
    if perform_eval:
        callbacks.append(EvaluateCallback(agent=agent, eval_env=env_fn(**env_params)))
    agent.train(timesteps=50000, callbacks=callbacks)
def _fill(self, candidates: Iterable[HostedInstr],
          util_info: BagValDict[ICaseString, InstrState],
          mem_busy: bool) -> InstrMovStatus:
    """Move candidate instructions into this unit sink.

    `self` is the receiving unit sink.
    `candidates` are the candidate instructions to move.
    `util_info` is the unit utilization information.
    `mem_busy` is the memory busy flag.

    ``self._mov_candidate`` is called repeatedly until it returns a value
    equal to ``False``. The returned status object collects the results.
    """
    remaining = iter(candidates)
    status = InstrMovStatus()

    def move_one():
        # Re-evaluated on every attempt: an earlier move in this call may
        # have set status.mem_used.
        return self._mov_candidate(remaining, util_info,
                                   mem_busy or status.mem_used, status)

    # iter(callable, sentinel) keeps calling move_one until it returns False.
    more_itertools.consume(iter(move_one, False))
    return status
def _pretty_construct(graph: Graph):
    """Print the triples of *graph* as a Subject/Predicate/Object table.

    Every term is rendered with ``pretty_print_value``. Nothing is printed
    for an empty graph.
    """
    if not graph:
        return
    table = Table(show_header=True, header_style="bold magenta")
    for heading in ('Subject', 'Predicate', 'Object'):
        table.add_column(heading)
    rendered_rows = [map(pretty_print_value, triple) for triple in graph]
    for cells in rendered_rows:
        table.add_row(*cells)
    Console().print(table)
def generate_app() -> Typer:
    """Build the CLI app from the built-in sub-apps plus plugin-provided ones.

    Plugins in the loaded config's ``'plugins'`` mapping that have a
    ``typer`` attribute are asked for a Typer instance. Every truthy result
    is mounted on the app.
    """
    app = Typer()
    for builtin_app in (sparql.app, context.app):
        app.add_typer(builtin_app)
    config = load_config()
    plugin_apps = [
        plugin.typer()
        for plugin in config['plugins'].values()
        if hasattr(plugin, 'typer')
    ]
    for plugin_app in plugin_apps:
        if plugin_app:
            app.add_typer(plugin_app)
    return app
def pack_projects(self, projects: list[str], project_type: str, compression: int) -> None:
    """Pack each expanded project into ``<packaged dir>/<type>/<project>.<type>.epub``.

    The per-type output directory is wiped and recreated first. Packing
    runs through ``pool_run`` in a process pool with progress shown.
    """
    packaged_type_directory = Path(self.settings.packaged_epubs_directory, project_type)
    shutil.rmtree(packaged_type_directory, ignore_errors=True)
    packaged_type_directory.mkdir(parents=True, exist_ok=True)

    args_collection: list[tuple[Path, Path, int]] = [
        (
            Path(self.settings.expanded_epubs_directory, project_type, project),
            Path(packaged_type_directory, f"{project}.{project_type}.epub"),
            compression,
        )
        for project in projects
    ]
    kwargs_collection: list[dict[str, Any]] = [{} for _ in projects]
    results = pool_run(pack_epub, args_collection, kwargs_collection,
                       "process", show_progress=True)
    consume(results)
def _accept_instr(issue_rec: _IssueInfo, instr_categ: object, input_iter: Iterator[UnitModel],
                  util_info: BagValDict[ICaseString, InstrState],
                  accept_res: _AcceptStatus) -> None:
    """Try to hand the next instruction to one of the input units.

    `issue_rec` is the issue record.
    `instr_categ` is the category of the next instruction.
    `input_iter` iterates over the candidate input processing units.
    `util_info` is the unit utilization information.
    `accept_res` receives the acceptance result.

    ``accept_res.accepted`` is reset to False. ``_accept_in_unit`` is then
    called repeatedly until it returns a value equal to ``True``.
    """
    accept_res.accepted = False

    def try_next_unit():
        return _accept_in_unit(input_iter, instr_categ, accept_res,
                               util_info, issue_rec)

    # iter(callable, sentinel) stops as soon as try_next_unit() == True.
    more_itertools.consume(iter(try_next_unit, True))
def _pretty_print_select_result(select_result: SelectResult):
    """Print a SPARQL query result in style.

    Column headings come from the first row's keys. Each cell is rendered
    with ``pretty_print_value``. Nothing is printed for an empty result.
    """
    if not select_result:
        return
    table = Table(show_header=True, header_style="bold magenta")
    for heading in select_result[0].keys():
        table.add_column(heading)
    rendered_rows = [map(pretty_print_value, row.values()) for row in select_result]
    for cells in rendered_rows:
        table.add_row(*cells)
    Console().print(table)
def continuation_lines(lines, indent):
    """Yield the lines that belong to the indented block after a header line.

    *lines* is a peekable iterator of ``(line_number, line)`` pairs that
    supports ``peek(default)`` and ``prepend(...)``. It is presumably a
    ``more_itertools.peekable``; confirm against callers. ``take_while`` is
    assumed to remove and return leading items while the predicate holds.

    Steps:

    1. Take, in order: non-blank "option" lines, then blank lines, then
       lines starting with "@" (decorators).
    2. If the next line matches ``prompt_re``, push those lines back and
       raise ``RuntimeError("ipython prompt detected")``.
    3. Otherwise yield them, then keep yielding lines (plus any blank lines
       before them) while they are indented deeper than *indent*. Blank
       lines before the first shallower line are pushed back.

    NOTE(review): if the input ends right after step 1, the lines taken in
    step 1 are neither yielded nor pushed back. Blank lines at the very end
    of the input are dropped the same way. Confirm this is intended.
    """
    options = tuple(take_while(lines, lambda x: x[1].strip()))
    newlines = tuple(take_while(lines, lambda x: not x[1].strip()))
    decorator_lines = tuple(
        take_while(lines, lambda x: x[1].lstrip().startswith("@")))
    # (0, None) is the "input exhausted" default.
    _, next_line = lines.peek((0, None))
    if next_line is None:
        return
    if prompt_re.match(next_line):
        # Restore everything taken above before signalling the caller.
        lines.prepend(*options, *newlines, *decorator_lines)
        raise RuntimeError("ipython prompt detected")
    yield from options
    yield from newlines
    yield from decorator_lines
    while True:
        newlines = tuple(take_while(lines, lambda x: not x[1].strip()))
        try:
            line_number, line = lines.peek()
        except StopIteration:
            break
        current_indent = len(line) - len(line.lstrip())
        if current_indent <= indent:
            # put back the newlines, if any
            lines.prepend(*newlines)
            break
        yield from newlines
        # consume the line
        more_itertools.consume(lines, n=1)
        yield line_number, line
def distributors_by_film_and_ticket_type_report(
        self, start_date, end_date, distributor_id="", film_id="",
        exclude_complimentaries=False,
        new_page_for_each=NewPageForEach.nothing,
        detail_level=DetailLevel.showtime_by_ticket_type,
        multi_feature_revenue=MultiFeatureRevenue.full_revenue_per_film,
        site_id=DEFAULT_SITE_ID):
    """Fetch and parse the "distributors by film and ticket type" report.

    The workbook's "Sheet1" is read as nested sections. Each level starts
    with a header row and ends with a row whose first cell reads
    "<header> total":

        distributor - film  >  site - screen  >  show date  >  showtime
        >  ticket-type rows

    Blank rows after a distributor total are skipped.

    :returns: dict keyed by ``(site_name, film_name)``. Each value has
        ``site_name``, ``film_name``, ``distributor_name`` and ``showtimes``.
        Each showtime dict has ``screen_name``, ``showtime`` (parsed with
        dateutil) and ``tickets``, which maps a ticket-type name to its
        ``name`` plus the figures read from the report columns. When the
        same (site, film) pair appears in several site sections (e.g.
        several screens), all their showtimes are collected together.
    """
    dbfattr = self._report_workbook(
        Reports.distributors_by_film_and_ticket_type,
        dict(
            P193_From=start_date.strftime("%Y-%m-%d"),
            P193_To=end_date.strftime("%Y-%m-%d"),
            P194=site_id,
            P199=distributor_id,
            P198=film_id,
            P195="Y" if exclude_complimentaries else "N",
            P196=new_page_for_each.value,
            P197=detail_level.value,
            P1251=multi_feature_revenue.value,
        )
    )
    param_keys = [
        'REPORT DATE RANGE',
        'DISTRIBUTOR',
        'FILM',
        'MULTI/DOUBLE FEATURE',
    ]
    report_keys = [
        'SALES', 'REFUNDS', 'ADMITS', 'GROSS PRICE', 'NET PRICE',
        'NET TOTAL', 'TAX TOTAL', 'GROSS TOTAL'
    ]
    # Output key in each ticket dict -> report column heading.
    ticket_fields = (
        ('sales', 'SALES'),
        ('refunds', 'REFUNDS'),
        ('admits', 'ADMITS'),
        ('gross_price', 'GROSS PRICE'),
        ('net_price', 'NET PRICE'),
        ('net_total', 'NET TOTAL'),
        ('tax_total', 'TAX TOTAL'),
        ('gross_total', 'GROSS TOTAL'),
    )
    engagements = dict()
    with dbfattr as wb:
        ws = wb["Sheet1"]
        param_key_offsets = self._get_offsets(ws.rows[3], param_keys)
        # NOTE(review): parsed but not currently used or returned.
        param_key_values = self._get_cell_values(ws.rows[4], param_key_offsets)
        report_key_offsets = self._get_offsets(ws.rows[9], report_keys)
        rows_it = more_itertools.peekable(ws.rows)
        # The first 11 rows are the report header.
        more_itertools.consume(rows_it, 11)
        while rows_it.peek(None) is not None:
            distrib_row = next(rows_it)
            _log_row_type("DISTRIBUTOR", distrib_row)
            distrib_name, film_name = distrib_row[0].value.split(" - ", 1)
            end_film_value = "{0} total".format(film_name)
            end_distrib_value = "{0} total".format(distrib_name)
            while rows_it.peek()[0].value != end_film_value:
                site_screen_row = next(rows_it)
                _log_row_type("SITE & SCREEN", site_screen_row)
                site_screen_value = site_screen_row[0].value
                site_name, screen_name = site_screen_value.split(" - ", 1)
                end_site_value = "{0} total".format(site_screen_value)
                showtimes = []
                while rows_it.peek()[0].value != end_site_value:
                    showdate_row = next(rows_it)
                    _log_row_type("SHOWDATE", showdate_row)
                    showdate_value = showdate_row[0].value
                    end_showdate_value = "{0} total".format(showdate_value)
                    while rows_it.peek()[0].value != end_showdate_value:
                        showtime_row = next(rows_it)
                        _log_row_type("SHOWTIME", showtime_row)
                        showtime_value = showtime_row[0].value
                        full_showtime_value = "{0} {1}".format(showdate_value, showtime_value)
                        end_showtime_value = "{0} total".format(showtime_value)
                        full_showtime_dt = dateutil.parser.parse(full_showtime_value)
                        tickets = dict()
                        while rows_it.peek()[0].value != end_showtime_value:
                            tt_row = next(rows_it)
                            _log_row_type("TICKET TYPE", tt_row)
                            tt_name = tt_row[0].value
                            if tt_name is not None:
                                ticket = dict(name=tt_name)
                                for field, column in ticket_fields:
                                    ticket[field] = self._get_cell_value(
                                        tt_row, report_key_offsets, column)
                                tickets[tt_name] = ticket
                        end_showtime_row = next(rows_it)
                        _log_row_type("END SHOWTIME", end_showtime_row)
                        showtimes.append(dict(
                            screen_name=screen_name,
                            showtime=full_showtime_dt,
                            tickets=tickets,
                        ))
                    end_showdate_row = next(rows_it)
                    _log_row_type("END SHOWDATE", end_showdate_row)
                engagement_key = (site_name, film_name)
                engagement = engagements.get(engagement_key)
                if engagement is None:
                    engagements[engagement_key] = dict(
                        site_name=site_name,
                        film_name=film_name,
                        distributor_name=distrib_name,
                        showtimes=showtimes,
                    )
                else:
                    # Same site and film again (e.g. another screen): keep
                    # these showtimes instead of silently dropping them.
                    engagement["showtimes"].extend(showtimes)
                end_site_row = next(rows_it)
                _log_row_type("END SITE", end_site_row)
            end_film_row = next(rows_it)
            _log_row_type("END FILM", end_film_row)
            if rows_it.peek()[0].value == end_distrib_value:
                end_distrib_row = next(rows_it)
                _log_row_type("END DISTRIBUTOR", end_distrib_row)
                # consume any blank rows (every cell value falsy). A bare
                # filter() object is always truthy in Python 3, so test with
                # any() instead.
                while True:
                    post_distrib_row = rows_it.peek(None)
                    if post_distrib_row is None or any(c.value for c in post_distrib_row):
                        break
                    blank_row = next(rows_it)
                    _log_row_type("BLANK", blank_row)
    return engagements
def test_negative_consume(self):
    """Check that consuming a negative count raises ValueError"""
    numbers = (x for x in range(10))
    with self.assertRaises(ValueError):
        mi.consume(numbers, -1)
def test_sanity(self):
    """Consuming 3 items leaves the iterator positioned at the 4th"""
    numbers = (x for x in range(10))
    mi.consume(numbers, 3)
    following = next(numbers)
    self.assertEqual(3, following)
def test_null_consume(self):
    """Consuming zero items leaves the iterator untouched"""
    numbers = (x for x in range(10))
    mi.consume(numbers, 0)
    following = next(numbers)
    self.assertEqual(0, following)
def test_total_consume(self):
    """Without a count, consume exhausts the iterator completely"""
    numbers = (x for x in range(10))
    mi.consume(numbers)
    with self.assertRaises(StopIteration):
        next(numbers)