def make_progress() -> Progress:
    """Return a ``Progress`` populated with sample tasks and a fake clock.

    The clock returns 0.0, 1.0, 2.0, ... (advancing by one per read), the
    console renders into an in-memory buffer, and auto-refresh is off, so
    any rendering driven from the returned object is deterministic.
    """
    # Single-element list used as a mutable cell for the fake clock.
    clock = [0.0]

    def fake_time():
        # Hand out the current tick, then advance it for the next read.
        now = clock[0]
        clock[0] = now + 1
        return now

    console = Console(
        file=io.StringIO(),
        force_terminal=True,
        color_system="truecolor",
        width=80,
        legacy_windows=False,
    )
    progress = Progress(console=console, get_time=fake_time, auto_refresh=False)

    progress.add_task("foo")
    bar_task = progress.add_task("bar", total=30)
    progress.advance(bar_task, 16)
    progress.add_task("baz", visible=False)

    # Add a task and drop it straight away so removal is exercised.
    progress.remove_task(progress.add_task("egg"))

    late_task = progress.add_task("foo2", completed=50, start=False)
    progress.stop_task(late_task)
    progress.start_task(late_task)
    progress.update(
        late_task, total=200, advance=50, completed=200, visible=True, refresh=True
    )
    progress.stop_task(late_task)
    return progress
def make_progress() -> Progress:
    """Return a ``Progress`` writing to an in-memory console with sample tasks.

    Creates a plain task, a 30-step task advanced by 16, a hidden task, and
    a task that is started late and then forced to ``completed=200`` of 200.
    A throw-away task is added and removed to exercise ``remove_task``.
    """
    console = Console(file=io.StringIO(), force_terminal=True)
    progress = Progress(console=console)
    progress.add_task("foo")
    # ``add_task``'s second positional parameter is ``start``, not ``total``;
    # pass the total by keyword so this bar really has 30 steps.
    task2 = progress.add_task("bar", total=30)
    progress.advance(task2, 16)
    progress.add_task("baz", visible=False)
    task4 = progress.add_task("egg")
    progress.remove_task(task4)
    task4 = progress.add_task("foo2", completed=50, start=False)
    progress.start_task(task4)
    progress.update(
        task4, total=200, advance=50, completed=200, visible=True, refresh=True
    )
    return progress
def _bar(
    self,
    progress: Progress,
    description: str,
    total: Optional[float],
) -> Iterator[ProgressBar]:
    """Yield a ``ProgressBar`` wrapping a freshly added task on *progress*.

    With ``total=None`` the task is added without being started, giving an
    indeterminate bar; otherwise it is added with the given total. The live
    display is refreshed once the task exists and again after it has been
    removed. Removal happens in ``finally``, so it also runs on error.
    """
    # Indeterminate progress bar: leave the task unstarted.
    add_kwargs = {"start": False} if total is None else {"total": total}
    taskid = progress.add_task(description, **add_kwargs)
    self._update_live()
    try:
        yield ProgressBar(progress, taskid)
    finally:
        progress.remove_task(taskid)
        self._update_live()
def run():
    """
    Run the Console Interface for Unsilence

    Parses the command line, detects silent intervals (with a progress bar),
    prints a time estimate, optionally asks for confirmation, then renders
    and concatenates the output media.

    :return: None
    """
    # Hide tracebacks from end users unless --debug is given (below).
    sys.tracebacklimit = 0

    args = parse_arguments()
    console = Console()

    if args.debug:
        sys.tracebacklimit = 1000

    if args.output_file.exists() and not args.non_interactive_mode:
        if not choice_dialog(console, "File already exists. Overwrite?", default=False):
            return

    args_dict = vars(args)

    def select_options(keys):
        # Forward only those options that were actually parsed.
        return {key: args_dict[key] for key in keys if key in args_dict}

    argument_dict_for_silence_detect = select_options([
        "silence_level",
        "silence_time_threshold",
        "short_interval_threshold",
        "stretch_time",
    ])

    argument_dict_for_renderer = select_options([
        "audio_only",
        "audible_speed",
        "silent_speed",
        "audible_volume",
        "silent_volume",
        "drop_corrupted_intervals",
        "threads",
    ])

    progress = Progress()

    continual = Unsilence(args.input_file)

    with progress:
        def update_task(current_task):
            # Build a callback that mirrors (current, total) into a rich task.
            def handler(current_val, total):
                progress.update(current_task, total=total, completed=current_val)

            return handler

        silence_detect_task = progress.add_task("Calculating Intervals...", total=1)

        continual.detect_silence(
            on_silence_detect_progress_update=update_task(silence_detect_task),
            **argument_dict_for_silence_detect
        )

        # Pause the live display so the estimate and prompt print cleanly.
        progress.stop()
        progress.remove_task(silence_detect_task)

        print()

        estimated_time = continual.estimate_time(args.audible_speed, args.silent_speed)
        console.print(pretty_time_estimate(estimated_time))

        print()

        if not args.non_interactive_mode:
            if not choice_dialog(console, "Continue with these options?", default=True):
                return

        progress.start()
        rendering_task = progress.add_task("Rendering Intervals...", total=1)
        concat_task = progress.add_task("Combining Intervals...", total=1)

        continual.render_media(
            args.output_file,
            on_render_progress_update=update_task(rendering_task),
            on_concat_progress_update=update_task(concat_task),
            **argument_dict_for_renderer
        )

        progress.stop()

    console.print("\n[green]Finished![/green] :tada:")
    print()
class RichProgressBar(ProgressBarBase):
    """Lightning progress bar that renders train/validation/test with rich.

    One rich task is shown per stage. Each task displays
    ``completed/total`` batches. The module's ``get_progress_bar_dict()``
    values are passed as task fields, presumably rendered by
    ``BarDictColumn``.
    """

    def __init__(self):
        super().__init__()
        self.pg = Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.completed}/{task.total}",
            BarDictColumn(),
            transient=True,
            expand=True,
            refresh_per_second=2,
        )
        # rich task ids of the bars currently shown; None means "no bar".
        self._train_id = self._val_id = self._test_id = None

    def _discard_task(self, task_id):
        """Remove *task_id* from the display if it is set; return ``None``.

        Callers write ``self._x = self._discard_task(self._x)``. This resets
        the stored id, so an already-removed task is never removed again.
        A second removal would raise ``KeyError`` in rich.
        """
        if task_id is not None:
            self.pg.remove_task(task_id)
        return None

    def disable(self):
        self.pg.disable = True

    def enable(self):
        self.pg.disable = False

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        self.pg.start()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._test_id, completed=self.test_batch_idx, **bardic)
        if self.test_batch_idx >= self.total_test_batches:
            self.pg.stop_task(self._test_id)

    def on_test_epoch_start(self, trainer, pl_module) -> None:
        super().on_test_epoch_start(trainer, pl_module)
        self._test_id = self._discard_task(self._test_id)
        self._test_id = self.pg.add_task(
            "epoch T%03d" % trainer.current_epoch,
            total=self.total_test_batches,
        )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._train_id, completed=self.train_batch_idx, **bardic)
        if self.train_batch_idx >= self.total_train_batches:
            self.pg.stop_task(self._train_id)

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        # Reset both ids: if validation is skipped this epoch, the next
        # epoch must not try to remove the same validation task again.
        self._train_id = self._discard_task(self._train_id)
        self._val_id = self._discard_task(self._val_id)
        self._train_id = self.pg.add_task(
            "epoch T%03d" % trainer.current_epoch,
            total=self.total_train_batches,
        )

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._val_id, completed=self.val_batch_idx, **bardic)
        if self.val_batch_idx >= self.total_val_batches:
            self.pg.stop_task(self._val_id)

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        super().on_validation_epoch_start(trainer, pl_module)
        # Validation may start again before the next training epoch begins
        # (e.g. several validation runs per epoch); replace the old bar
        # instead of leaking it.
        self._val_id = self._discard_task(self._val_id)
        self._val_id = self.pg.add_task(
            "epoch V%03d" % trainer.current_epoch,
            total=self.total_val_batches,
        )

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        self.pg.stop()

    def print(self, *args, **kwargs):
        self.pg.console.print(*args, **kwargs)
class RichOutput(AbstractChecksecOutput):
    """Checksec output backend that renders results as ``rich`` tables.

    ELF and PE results are accumulated in two separate tables and shown by
    :meth:`print`. Three ``rich`` progress displays report enumeration,
    per-file processing, and result rendering.
    """

    def __init__(self, libc_detected: bool = False):
        """Init Rich Console and Table"""
        super().__init__(libc_detected)

        # init ELF table
        self.table_elf = Table(title="Checksec Results: ELF", expand=True)
        self.table_elf.add_column("File", justify="left", header_style="")
        for column in ("NX", "PIE", "Canary", "Relro", "RPATH", "RUNPATH", "Symbols"):
            self.table_elf.add_column(column, justify="center")
        # FORTIFY columns can only be filled in when a libc was detected
        if self._libc_detected:
            for column in ("FORTIFY", "Fortified", "Fortifiable", "Fortify Score"):
                self.table_elf.add_column(column, justify="center")

        # init PE table
        self.table_pe = Table(title="Checksec Results: PE", expand=True)
        self.table_pe.add_column("File", justify="left", header_style="")
        for column in (
            "NX",
            "Canary",
            "ASLR",
            "Dynamic Base",
            "High Entropy VA",
            "SEH",
            "SafeSEH",
            "Force Integrity",
            "Control Flow Guard",
            "Isolation",
            "Authenticode",
        ):
            self.table_pe.add_column(column, justify="center")

        # init console
        self.console = Console()

        # build progress bar
        self.process_bar = Progress(
            TextColumn("[bold blue]Processing...", justify="left"),
            BarColumn(bar_width=None),
            "{task.completed}/{task.total}",
            "•",
            "[progress.percentage]{task.percentage:>3.1f}%",
            console=self.console,
        )
        self.display_res_bar = Progress(
            BarColumn(bar_width=None),
            TextColumn("[bold blue]{task.description}", justify="center"),
            BarColumn(bar_width=None),
            console=self.console,
            transient=True,
        )
        self.enumerate_bar = Progress(
            TextColumn("[bold blue]Enumerating...", justify="center"),
            BarColumn(bar_width=None),
            console=self.console,
            transient=True,
        )
        self.process_task_id = None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # cleanup the Rich progress bars (each is set to None once stopped)
        if self.enumerate_bar is not None:
            self.enumerate_bar.stop()
        if self.process_bar is not None:
            self.process_bar.stop()
        if self.display_res_bar is not None:
            self.display_res_bar.stop()

    def enumerating_tasks_start(self):
        # start progress bar
        self.enumerate_bar.start()
        self.enumerate_bar.add_task("Enumerating", start=False)

    def enumerating_tasks_stop(self, total: int):
        super().enumerating_tasks_stop(total)
        self.enumerate_bar.stop()
        self.enumerate_bar = None

    def processing_tasks_start(self):
        # init progress bar
        self.process_bar.start()
        self.process_task_id = self.process_bar.add_task("Checking", total=self.total)

    @staticmethod
    def _good_if_set(flag) -> str:
        """Render a feature whose presence is desirable: green Yes / red No."""
        return "[green]Yes" if flag else "[red]No"

    @staticmethod
    def _bad_if_set(flag) -> str:
        """Render a feature whose presence is undesirable: red Yes / green No."""
        return "[red]Yes" if flag else "[green]No"

    def _elf_row(self, checksec: ELFChecksecData) -> List[str]:
        """Build the ELF table cells (all columns except "File")."""
        row_res: List[str] = [self._good_if_set(checksec.nx)]

        pie = checksec.pie
        if pie == PIEType.No:
            pie_res = f"[red]{pie.name}"
        elif pie == PIEType.DSO:
            pie_res = f"[yellow]{pie.name}"
        else:
            pie_res = "[green]Yes"
        row_res.append(pie_res)

        row_res.append(self._good_if_set(checksec.canary))

        relro = checksec.relro
        if relro == RelroType.No:
            relro_res = f"[red]{relro.name}"
        elif relro == RelroType.Partial:
            relro_res = f"[yellow]{relro.name}"
        else:
            relro_res = f"[green]{relro.name}"
        row_res.append(relro_res)

        row_res.append(self._bad_if_set(checksec.rpath))
        row_res.append(self._bad_if_set(checksec.runpath))
        row_res.append(self._bad_if_set(checksec.symbols))

        # fortify results depend on having a Libc available
        if self._libc_detected:
            fortified_count = checksec.fortified
            row_res.append(self._good_if_set(checksec.fortify_source))
            row_res.append(
                "[red]No" if fortified_count == 0 else f"[green]{fortified_count}"
            )
            # The "Fortifiable" cell must test the fortifiable count itself
            # (it previously tested ``fortified_count`` by mistake).
            fortifiable_count = checksec.fortifiable
            row_res.append(
                "[red]No" if fortifiable_count == 0 else f"[green]{fortifiable_count}"
            )
            score = checksec.fortify_score
            if score == 0:
                colour = "red"
            elif score == 100:
                colour = "green"
            else:
                colour = "yellow"
            row_res.append(f"[{colour}]{score}")
        return row_res

    def _pe_row(self, checksec: PEChecksecData) -> List[str]:
        """Build the PE table cells (all columns except "File")."""
        # this is only relevant is binary is 64 bits
        if checksec.machine == MACHINE_TYPES.AMD64:
            entropy_va_res = self._good_if_set(checksec.high_entropy_va)
        else:
            entropy_va_res = "/"
        # only relevant if 32 bits
        if checksec.machine == MACHINE_TYPES.I386:
            safe_seh_res = self._good_if_set(checksec.safe_seh)
        else:
            safe_seh_res = "/"
        return [
            self._good_if_set(checksec.nx),
            self._good_if_set(checksec.canary),
            self._good_if_set(checksec.aslr),
            self._good_if_set(checksec.dynamic_base),
            entropy_va_res,
            self._good_if_set(checksec.seh),
            safe_seh_res,
            self._good_if_set(checksec.force_integrity),
            self._good_if_set(checksec.guard_cf),
            self._good_if_set(checksec.isolation),
            self._good_if_set(checksec.authenticode),
        ]

    def add_checksec_result(self, filepath: Path, checksec: Union[ELFChecksecData, PEChecksecData]):
        """Append one file's results to the ELF or PE table.

        :raises NotImplementedError: for any other checksec result type.
        """
        logging.debug("result for %s: %s", filepath, checksec)
        if isinstance(checksec, ELFChecksecData):
            self.table_elf.add_row(str(filepath), *self._elf_row(checksec))
        elif isinstance(checksec, PEChecksecData):
            self.table_pe.add_row(str(filepath), *self._pe_row(checksec))
        else:
            raise NotImplementedError

    def checksec_result_end(self):
        """Update progress bar"""
        self.process_bar.update(self.process_task_id, advance=1)

    def print(self):
        """Stop the processing bar and print every non-empty results table."""
        self.process_bar.stop()
        self.process_bar = None

        for table, description in (
            (self.table_elf, "Displaying Results: ELF ..."),
            (self.table_pe, "Displaying Results: PE ..."),
        ):
            if table.row_count > 0:
                with self.display_res_bar:
                    task_id = self.display_res_bar.add_task(description, start=False)
                    self.console.print(table)
                    self.display_res_bar.remove_task(task_id)

        self.display_res_bar.stop()
        self.display_res_bar = None
class BatchExecutor:
    """Drive one attempt of a batch bake to completion.

    The main loop (:meth:`_run`) repeatedly asks the flow graph for ready
    nodes. For each one it skips disabled tasks, embeds actions, or starts
    jobs (building images first when needed). It then polls started jobs
    and image builds, records every status change in storage, and renders
    one rich progress bar per (sub)flow.
    """

    transient_progress: bool = False
    run_builder_job: Callable[..., Awaitable[str]] = staticmethod(  # type: ignore
        start_image_build
    )

    def __init__(
        self,
        console: Console,
        flow: RunningBatchFlow,
        bake: Bake,
        attempt: Attempt,
        client: RetryReadNeuroClient,
        storage: AttemptStorage,
        bake_storage: BakeStorage,
        project_storage: ProjectStorage,
        *,
        polling_timeout: Optional[float] = 1,
        project_role: Optional[str] = None,
    ) -> None:
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TextColumn("[progress.remaining]{task.elapsed:.0f} sec"),
            console=console,
            auto_refresh=False,
            transient=self.transient_progress,
            redirect_stderr=True,
        )
        self._top_flow = flow
        self._bake = bake
        self._attempt = attempt
        self._client = client
        self._storage = storage
        self._project_storage = project_storage
        self._bake_storage = bake_storage
        self._graphs: Graph[RunningBatchBase[BaseBatchContext]] = Graph(flow.graph, flow)
        self._tasks_mgr = BakeTasksManager()
        self._is_cancelling = False
        self._project_role = project_role
        self._is_projet_role_created = False

        self._run_builder_job = self.run_builder_job

        # A note about default value:
        # AS: I have no idea what timeout is better;
        # too short value bombards servers,
        # too long timeout makes the waiting longer than expected
        # The missing events subsystem would be great for this task :)
        self._polling_timeout = polling_timeout
        # Maps a flow prefix (root or embedded action id) to its rich task.
        self._bars: Dict[FullID, TaskID] = {}

    @classmethod
    @asynccontextmanager
    async def create(
        cls,
        console: Console,
        bake_id: str,
        client: Client,
        storage: Storage,
        *,
        polling_timeout: Optional[float] = 1,
        project_role: Optional[str] = None,
    ) -> AsyncIterator["BatchExecutor"]:
        """Load bake, attempt and flow for *bake_id* and yield an executor.

        Uses the bake's last attempt when one is recorded; otherwise it
        fetches the last attempt from storage.
        """
        storage = storage.with_retry_read()

        console.log("Fetch bake data")
        bake_storage = storage.bake(id=bake_id)
        bake = await bake_storage.get()
        if bake.last_attempt:
            attempt = bake.last_attempt
            attempt_storage = bake_storage.attempt(id=bake.last_attempt.id)
        else:
            attempt_storage = bake_storage.last_attempt()
            attempt = await attempt_storage.get()

        console.log("Fetch configs metadata")
        flow = await get_running_flow(bake, client, bake_storage, attempt.configs_meta)

        console.log(f"Execute attempt #{attempt.number}")

        ret = cls(
            console=console,
            flow=flow,
            bake=bake,
            attempt=attempt,
            client=RetryReadNeuroClient(client),
            storage=attempt_storage,
            bake_storage=bake_storage,
            project_storage=storage.project(id=bake.project_id),
            polling_timeout=polling_timeout,
            project_role=project_role,
        )
        ret._start()
        try:
            yield ret
        finally:
            ret._stop()

    def _start(self) -> None:
        pass

    def _stop(self) -> None:
        pass

    async def _refresh_attempt(self) -> None:
        """Re-read the attempt from storage (used to detect cancellation)."""
        self._attempt = await self._storage.get()

    def _only_completed_needs(self, needs: Mapping[str, ast.NeedsLevel]) -> AbstractSet[str]:
        """Return the ids of *needs* that require the COMPLETED level."""
        return {
            task_id
            for task_id, level in needs.items()
            if level == ast.NeedsLevel.COMPLETED
        }

    # Exact task/action contexts helpers

    async def _get_meta(self, full_id: FullID) -> TaskMeta:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(prefix, self._only_completed_needs(flow.graph[tid]))
        state = self._tasks_mgr.build_state(prefix, await flow.state_from(tid))
        return await flow.get_meta(tid, needs=needs, state=state)

    async def _get_task(self, full_id: FullID) -> Task:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(prefix, self._only_completed_needs(flow.graph[tid]))
        state = self._tasks_mgr.build_state(prefix, await flow.state_from(tid))
        return await flow.get_task(prefix, tid, needs=needs, state=state)

    async def _get_local(self, full_id: FullID) -> LocalTask:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(prefix, self._only_completed_needs(flow.graph[tid]))
        return await flow.get_local(tid, needs=needs)

    async def _get_action(self, full_id: FullID) -> RunningBatchActionFlow:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(prefix, self._only_completed_needs(flow.graph[tid]))
        return await flow.get_action(tid, needs=needs)

    async def _get_image(self, bake_image: BakeImage) -> Optional[ImageCtx]:
        """Return an image context for *bake_image*, or ``None``.

        Collects the definitions whose defining (sub)flow is already in the
        graph. It warns when those definitions differ, and warns when none
        is available yet (the image is referenced before its defining
        action has run).
        """
        actions = []
        full_id_to_image: Dict[FullID, ImageCtx] = {}
        for yaml_def in bake_image.yaml_defs:
            *prefix, image_id = yaml_def
            actions.append(".".join(prefix))
            try:
                flow = self._graphs.get_graph_data(tuple(prefix))
            except KeyError:
                pass
            else:
                full_id_to_image[yaml_def] = flow.images[image_id]

        # Return any context
        image_ctx = next(iter(full_id_to_image.values()), None)

        if not all(image_ctx == it for it in full_id_to_image.values()):
            locations_str = "\n".join(
                f" - {'.'.join(full_id)}" for full_id in full_id_to_image.keys()
            )
            self._progress.log(
                f"[red]Warning:[/red] definitions for image with"
                f" ref [b]{bake_image.ref}[/b] differ:\n"
                f"{locations_str}"
            )

        if not full_id_to_image:
            actions_str = ", ".join(f"'{it}'" for it in actions)
            self._progress.log(
                f"[red]Warning:[/red] image with ref "
                f"[b]{bake_image.ref}[/b] is referred "
                + (
                    f"before action {actions_str} "
                    "that holds its definition was able to run"
                    if len(actions) == 1
                    else f"before actions {actions_str} "
                    "that hold its definition were able to run"
                )
            )
        return image_ctx

    # Graph helpers

    async def _embed_action(self, full_id: FullID) -> None:
        """Embed an action's subgraph and add a progress bar for it."""
        action = await self._get_action(full_id)
        self._graphs.embed(full_id, action.graph, action)
        # Each node contributes up to 3 steps: pending, running, finished.
        task_id = self._progress.add_task(
            ".".join(full_id),
            completed=0,
            total=self._graphs.get_size(full_id) * 3,
        )
        self._bars[full_id] = task_id

    def _advance_progress_bar(self, old_task: Optional[StorageTask], task: StorageTask) -> None:
        """Advance the parent flow's bar by the task's status-step delta."""

        def _state_to_progress(status: TaskStatus) -> int:
            if status.is_pending:
                return 1
            if status.is_running:
                return 2
            return 3

        if old_task:
            advance = _state_to_progress(task.status) - _state_to_progress(old_task.status)
        else:
            advance = _state_to_progress(task.status)
        progress_bar_id = self._bars[task.yaml_id[:-1]]
        self._progress.update(progress_bar_id, advance=advance, refresh=True)

    def _update_graph_coloring(self, task: StorageTask) -> None:
        """Mirror the task's status into the graph (running / done marks)."""
        if task.status.is_running or task.status.is_finished:
            self._graphs.mark_running(task.yaml_id)
        if task.status.is_finished:
            self._graphs.mark_done(task.yaml_id)

    async def _log_task_status_change(self, task: StorageTask) -> None:
        """Log a human-readable line (plus outputs) for the task's status."""
        flow = self._graphs.get_meta(task.yaml_id)
        in_flow_id = task.yaml_id[-1]
        if await flow.is_action(in_flow_id):
            log_args = ["Action"]
        elif await flow.is_local(in_flow_id):
            log_args = ["Local action"]
        else:
            log_args = ["Task"]
        log_args.append(fmt_id(task.yaml_id))
        if task.raw_id:
            log_args.append(fmt_raw_id(task.raw_id))
        log_args += ["is", fmt_status(task.status)]
        if task.status.is_finished:
            if task.outputs:
                log_args += ["with following outputs:"]
            else:
                log_args += ["(no outputs)"]
        self._progress.log(*log_args)
        if task.outputs:
            for key, value in task.outputs.items():
                self._progress.log(f" {key}: {value}")

    async def _create_task(
        self,
        yaml_id: FullID,
        raw_id: Optional[str],
        status: Union[TaskStatusItem, TaskStatus],
        outputs: Optional[Mapping[str, str]] = None,
        state: Optional[Mapping[str, str]] = None,
    ) -> StorageTask:
        """Persist a new task and update the bar, graph, log and manager."""
        task = await self._storage.create_task(yaml_id, raw_id, status, outputs, state)
        self._advance_progress_bar(None, task)
        self._update_graph_coloring(task)
        await self._log_task_status_change(task)
        self._tasks_mgr.update(task)
        return task

    async def _update_task(
        self,
        yaml_id: FullID,
        *,
        outputs: Union[Optional[Mapping[str, str]], Type[_Unset]] = _Unset,
        state: Union[Optional[Mapping[str, str]], Type[_Unset]] = _Unset,
        new_status: Optional[Union[TaskStatusItem, TaskStatus]] = None,
    ) -> StorageTask:
        """Persist changes to a known task and update bar, graph, log, manager."""
        old_task = self._tasks_mgr.tasks[yaml_id]
        task = await self._storage.task(id=old_task.id).update(
            outputs=outputs,
            state=state,
            new_status=new_status,
        )
        self._advance_progress_bar(old_task, task)
        self._update_graph_coloring(task)
        await self._log_task_status_change(task)
        self._tasks_mgr.update(task)
        return task

    # New tasks processers

    async def _skip_task(self, full_id: FullID) -> None:
        log.debug(f"BatchExecutor: skipping task {full_id}")
        await self._create_task(
            yaml_id=full_id,
            status=TaskStatus.SKIPPED,
            raw_id=None,
            outputs={},
            state={},
        )

    async def _process_local(self, full_id: FullID) -> None:
        raise ValueError("Processing of local actions is not supported")

    async def _process_action(self, full_id: FullID) -> None:
        log.debug(f"BatchExecutor: processing action {full_id}")
        await self._create_task(
            yaml_id=full_id,
            status=TaskStatus.PENDING,
            raw_id=None,
        )
        await self._embed_action(full_id)

    async def _process_task(self, full_id: FullID) -> None:
        """Start a ready task, reuse a fresh cache entry, or defer it.

        The task is deferred (nothing happens) while any enclosing flow
        already has ``max_parallel`` unfinished tasks.
        """
        log.debug(f"BatchExecutor: processing task {full_id}")
        task = await self._get_task(full_id)

        # Check is task fits max_parallel
        for n in range(1, len(full_id) + 1):
            node = full_id[:n]
            prefix = node[:-1]
            if self._tasks_mgr.count_unfinished(prefix) >= task.strategy.max_parallel:
                return

        cache_strategy = task.cache.strategy
        storage_task: Optional[StorageTask] = None
        if cache_strategy == ast.CacheStrategy.DEFAULT:
            try:
                cache_entry = await self._project_storage.cache_entry(
                    task_id=full_id, batch=self._bake.batch, key=task.caching_key
                ).get()
            except ResourceNotFound:
                pass
            else:
                now = datetime.now(timezone.utc)
                # ``life_span`` is a duration in seconds; ``timedelta``'s
                # first positional argument is *days*, so pass it by keyword.
                eol = cache_entry.created_at + timedelta(seconds=task.cache.life_span)
                if now < eol:
                    storage_task = await self._create_task(
                        yaml_id=full_id,
                        raw_id=cache_entry.raw_id,
                        status=TaskStatus.CACHED,
                        outputs=cache_entry.outputs,
                        state=cache_entry.state,
                    )

        if storage_task is None:
            await self._start_task(full_id, task)

    async def _load_previous_run(self) -> None:
        """Rebuild graph colouring, bars and task manager from stored tasks."""
        log.debug("BatchExecutor: loading previous run")
        # Loads tasks that previous executor run processed.
        # Do not handles fail fast option.

        root_id = ()
        task_id = self._progress.add_task(
            f"<{self._bake.batch}>",
            completed=0,
            total=self._graphs.get_size(root_id) * 3,
        )
        self._bars[root_id] = task_id

        tasks = {task.yaml_id: task async for task in self._storage.list_tasks()}

        def _set_task_info(task: StorageTask) -> None:
            self._update_graph_coloring(task)
            self._advance_progress_bar(None, task)
            self._tasks_mgr.update(task)

        # Rebuild data in graph and task_mgr
        while tasks.keys() != self._tasks_mgr.tasks.keys():
            log.debug("BatchExecutor: rebuilding data...")
            for full_id, ctx in self._graphs.get_ready_with_meta():
                if full_id not in tasks or full_id in self._tasks_mgr.tasks:
                    continue
                task = tasks[full_id]
                if await ctx.is_action(full_id[-1]):
                    await self._embed_action(full_id)
                    if task.status.is_pending:
                        _set_task_info(task)
                else:
                    _set_task_info(task)
            for full_id in self._graphs.get_ready_to_mark_running_embeds():
                task = tasks[full_id]
                if task.status.is_finished:
                    self._graphs.mark_running(full_id)
                elif task.status.is_running:
                    _set_task_info(task)
            for full_id in self._graphs.get_ready_to_mark_done_embeds():
                task = tasks[full_id]
                if task.status.is_finished:
                    _set_task_info(task)

    async def _should_continue(self) -> bool:
        return self._graphs.is_active

    async def run(self) -> TaskStatus:
        """Execute the attempt and return its final status.

        Interruption or cancellation yields CANCELLED; any other exception
        is logged and yields FAILED. In both cases unfinished jobs are
        cancelled first.
        """
        with self._progress:
            try:
                result = await self._run()
            except (KeyboardInterrupt, asyncio.CancelledError):
                await self._cancel_unfinished()
                result = TaskStatus.CANCELLED
            except Exception as exp:
                await self.log_error(exp)
                await self._cancel_unfinished()
                result = TaskStatus.FAILED
        await self._finish_run(result)
        return result

    async def log_error(self, exp: Exception) -> None:
        if isinstance(exp, EvalError):
            self._progress.log(f"[red][b]ERROR:[/b] {exp}[/red]")
        else:
            # Looks like this is some bug in our code, so we should print stacktrace
            # for easier debug
            self._progress.console.print_exception(show_locals=True)
            self._progress.log(
                "[red][b]ERROR:[/b] Some unknown error happened. Please report "
                "an issue to https://github.com/neuro-inc/neuro-flow/issues/new "
                "with traceback printed above.[/red]"
            )

    async def _run(self) -> TaskStatus:
        """Main polling loop; return the accumulated result of all tasks."""
        await self._load_previous_run()

        job_id = os.environ.get("NEURO_JOB_ID")
        if job_id:
            await self._storage.update(executor_id=job_id)

        if self._attempt.result.is_finished:
            # If attempt is already terminated, just clean up tasks
            # and exit
            await self._cancel_unfinished()
            return self._attempt.result

        await self._storage.update(result=TaskStatus.RUNNING)

        while await self._should_continue():
            _mgr = self._tasks_mgr

            def _fmt_debug(task: StorageTask) -> str:
                return ".".join(task.yaml_id) + f" - {task.status}"

            log.debug(
                f"BatchExecutor: main loop start:\n"
                f"tasks={', '.join(_fmt_debug(it) for it in _mgr.tasks.values())}\n"
            )
            for full_id, flow in self._graphs.get_ready_with_meta():
                tid = full_id[-1]
                if full_id in self._tasks_mgr.tasks:
                    continue  # Already started
                meta = await self._get_meta(full_id)
                if not meta.enable or (
                    self._is_cancelling and meta.enable is not AlwaysT()
                ):
                    # Make task started and immediatelly skipped
                    await self._skip_task(full_id)
                    continue
                if await flow.is_local(tid):
                    await self._process_local(full_id)
                elif await flow.is_action(tid):
                    await self._process_action(full_id)
                else:
                    await self._process_task(full_id)

            ok = await self._process_started()
            await self._process_running_builds()

            # Check for cancellation
            if not self._is_cancelling:
                await self._refresh_attempt()

                if not ok or self._attempt.result == TaskStatus.CANCELLED:
                    await self._cancel_unfinished()

            await asyncio.sleep(self._polling_timeout or 0)

        return self._accumulate_result()

    async def _finish_run(self, attempt_status: TaskStatus) -> None:
        """Store the final status (if not stored yet) and print a summary."""
        if not self._attempt.result.is_finished:
            await self._storage.update(result=attempt_status)
        self._progress.print(
            Panel(f"[b]Attempt #{self._attempt.number}[/b] {fmt_status(attempt_status)}"),
            justify="center",
        )
        if attempt_status != TaskStatus.SUCCEEDED:
            self._progress.log(
                "[blue b]Hint:[/blue b] you can restart bake starting "
                "from first failed task "
                "by the following command:\n"
                f"[b]neuro-flow restart {self._bake.id}[/b]"
            )

    async def _cancel_unfinished(self) -> None:
        """Enter cancelling mode and kill running jobs and image builds.

        Tasks whose ``enable`` is the ``AlwaysT()`` marker are not killed.
        """
        self._is_cancelling = True
        for task, raw_id in self._tasks_mgr.list_unfinished_raw_tasks():
            task_ctx = await self._get_task(task.yaml_id)
            if task_ctx.enable is not AlwaysT():
                self._progress.log(f"Task {fmt_id(task.yaml_id)} is being killed")
                await self._client.job_kill(raw_id)
        async for image in self._bake_storage.list_bake_images():
            if image.status == ImageStatus.BUILDING:
                assert image.builder_job_id
                await self._client.job_kill(image.builder_job_id)

    async def _store_to_cache(self, task: StorageTask) -> None:
        """Create a cache entry for a successfully finished, cacheable task."""
        log.debug(f"BatchExecutor: storing to cache {task.yaml_id}")
        task_ctx = await self._get_task(task.yaml_id)
        cache_strategy = task_ctx.cache.strategy
        if (
            cache_strategy == ast.CacheStrategy.NONE
            or task.status != TaskStatus.SUCCEEDED
            or task.raw_id is None
            or task.outputs is None
            or task.state is None
        ):
            return

        await self._project_storage.create_cache_entry(
            key=task_ctx.caching_key,
            batch=self._bake.batch,
            task_id=task.yaml_id,
            raw_id=task.raw_id,
            outputs=task.outputs,
            state=task.state,
        )

    async def _process_started(self) -> bool:
        """Poll started jobs and finish embedded actions whose graph is done.

        Returns ``False`` as soon as a non-successful task or action has
        ``fail_fast`` set, otherwise ``True``.
        """
        log.debug("BatchExecutor: processing started")
        # Process tasks
        for task, raw_id in self._tasks_mgr.list_unfinished_raw_tasks():
            assert task.raw_id, "list_unfinished_raw_tasks should list raw tasks"
            job_descr = await self._client.job_status(task.raw_id)
            log.debug(f"BatchExecutor: got description {job_descr} for task {task.yaml_id}")
            if (
                job_descr.status in {JobStatus.RUNNING, JobStatus.SUCCEEDED}
                and task.status.is_pending
            ):
                assert job_descr.history.started_at, "RUNNING jobs should have started_at"
                task = await self._update_task(
                    task.yaml_id,
                    new_status=TaskStatusItem(
                        when=job_descr.history.started_at,
                        status=TaskStatus.RUNNING,
                    ),
                )
            if job_descr.status in TERMINATED_JOB_STATUSES:
                log.debug(f"BatchExecutor: processing logs for task {task.yaml_id}")
                # Feed the whole job log through CmdProcessor to collect
                # the outputs/states it reports.
                async with CmdProcessor() as proc:
                    async for chunk in self._client.job_logs(raw_id):
                        async for line in proc.feed_chunk(chunk):
                            pass
                    async for line in proc.feed_eof():
                        pass
                log.debug(
                    f"BatchExecutor: finished processing logs for task {task.yaml_id}"
                )
                assert job_descr.history.finished_at, "TERMINATED jobs should have finished_at"
                task = await self._update_task(
                    task.yaml_id,
                    new_status=TaskStatusItem(
                        when=job_descr.history.finished_at,
                        status=TaskStatus(job_descr.status),
                    ),
                    outputs=proc.outputs,
                    state=proc.states,
                )
                await self._store_to_cache(task)
                task_meta = await self._get_meta(task.yaml_id)
                if task.status != TaskStatus.SUCCEEDED and task_meta.strategy.fail_fast:
                    return False
        # Process batch actions
        for full_id in self._graphs.get_ready_to_mark_running_embeds():
            log.debug(f"BatchExecutor: marking action {full_id} as ready")
            await self._update_task(full_id, new_status=TaskStatus.RUNNING)

        for full_id in self._graphs.get_ready_to_mark_done_embeds():
            log.debug(f"BatchExecutor: marking action {full_id} as done")
            # done action, make it finished
            ctx = await self._get_action(full_id)
            results = self._tasks_mgr.build_needs(full_id, ctx.graph.keys())
            res_ctx = await ctx.calc_outputs(results)

            task = await self._update_task(
                full_id,
                new_status=TaskStatus(res_ctx.result),
                outputs=res_ctx.outputs,
                state={},
            )

            task_id = self._bars[full_id]
            self._progress.remove_task(task_id)

            task_meta = await self._get_meta(full_id)
            if task.status != TaskStatus.SUCCEEDED and task_meta.strategy.fail_fast:
                return False
        return True

    def _accumulate_result(self) -> TaskStatus:
        """CANCELLED or FAILED as soon as one task has it, else SUCCEEDED."""
        for task in self._tasks_mgr.tasks.values():
            if task.status == TaskStatus.CANCELLED:
                return TaskStatus.CANCELLED
            elif task.status == TaskStatus.FAILED:
                return TaskStatus.FAILED

        return TaskStatus.SUCCEEDED

    async def _is_image_in_registry(self, remote_image: RemoteImage) -> bool:
        try:
            await self._client.image_tag_info(remote_image)
        except ResourceNotFound:
            return False
        return True

    async def _start_task(self, full_id: FullID, task: Task) -> Optional[StorageTask]:
        """Run the task's job, first building its bake image when required.

        Returns ``None`` while the image build is pending or in progress, so
        the task is retried on a later polling iteration.
        """
        log.debug(f"BatchExecutor: checking should we build image for {full_id}")
        remote_image = self._client.parse_remote_image(task.image)
        log.debug(f"BatchExecutor: image name is {remote_image}")
        if remote_image.cluster_name is None:  # Not a neuro registry image
            return await self._run_task(full_id, task)

        image_storage = self._bake_storage.bake_image(ref=task.image)
        try:
            bake_image = await image_storage.get()
        except ResourceNotFound:
            # Not defined in the bake
            return await self._run_task(full_id, task)

        image_ctx = await self._get_image(bake_image)
        if image_ctx is None:
            # The image referred before definition
            return await self._run_task(full_id, task)

        if not image_ctx.force_rebuild and await self._is_image_in_registry(remote_image):
            if bake_image.status == ImageStatus.PENDING:
                await image_storage.update(status=ImageStatus.CACHED)
            return await self._run_task(full_id, task)

        if bake_image.status == ImageStatus.PENDING:
            await self._start_image_build(bake_image, image_ctx, image_storage)
            return None  # wait for next pulling interval
        elif bake_image.status == ImageStatus.BUILDING:
            return None  # wait for next pulling interval
        else:
            # Image is already build (maybe with an error, just try to run a job)
            return await self._run_task(full_id, task)

    async def _run_task(self, full_id: FullID, task: Task) -> StorageTask:
        """Start the job for *task* and record it as a PENDING storage task."""
        log.debug(f"BatchExecutor: starting job for {full_id}")
        preset_name = task.preset
        if preset_name is None:
            preset_name = next(iter(self._client.config_presets))

        envs = self._client.parse_envs([f"{k}={v}" for k, v in task.env.items()])

        volumes_parsed = self._client.parse_volumes(task.volumes)
        volumes = list(volumes_parsed.volumes)

        http_auth = task.http_auth
        if http_auth is None:
            http_auth = HTTPPort.requires_auth

        job = await self._client.job_start(
            shm=True,
            tty=False,
            image=self._client.parse_remote_image(task.image),
            preset_name=preset_name,
            entrypoint=task.entrypoint,
            command=task.cmd,
            http=HTTPPort(task.http_port, http_auth) if task.http_port else None,
            env=envs.env,
            secret_env=envs.secret_env,
            volumes=volumes,
            working_dir=str(task.workdir) if task.workdir else None,
            secret_files=list(volumes_parsed.secret_files),
            disk_volumes=list(volumes_parsed.disk_volumes),
            name=task.name,
            tags=list(task.tags),
            description=task.title,
            life_span=task.life_span,
            schedule_timeout=task.schedule_timeout,
            pass_config=bool(task.pass_config),
        )
        await self._add_resource(job.uri)
        return await self._create_task(
            yaml_id=full_id,
            raw_id=job.id,
            status=TaskStatusItem(
                when=job.history.created_at or datetime.now(timezone.utc),
                status=TaskStatus.PENDING,
            ),
        )

    async def _start_image_build(
        self,
        bake_image: BakeImage,
        image_ctx: ImageCtx,
        image_storage: BakeImageStorage,
    ) -> None:
        """Launch a ``neuro-extras image build`` job and mark the image BUILDING.

        :raises Exception: when the build context or dockerfile is unknown.
        """
        log.debug(f"BatchExecutor: starting image build for {bake_image}")
        context = bake_image.context_on_storage or image_ctx.context
        dockerfile_rel = bake_image.dockerfile_rel or image_ctx.dockerfile_rel

        if context is None or dockerfile_rel is None:
            if context is None and dockerfile_rel is None:
                error = "context and dockerfile not specified"
            elif context is None and dockerfile_rel is not None:
                error = "context not specified"
            else:
                error = "dockerfile not specified"
            raise Exception(f"Failed to build image '{bake_image.ref}': {error}")

        cmd = ["neuro-extras", "image", "build"]
        cmd.append(f"--file={dockerfile_rel}")
        for arg in image_ctx.build_args:
            cmd.append(f"--build-arg={arg}")
        for vol in image_ctx.volumes:
            cmd.append(f"--volume={vol}")
        for k, v in image_ctx.env.items():
            cmd.append(f"--env={k}={v}")
        if image_ctx.force_rebuild:
            cmd.append("--force-overwrite")
        if image_ctx.build_preset is not None:
            cmd.append(f"--preset={image_ctx.build_preset}")
        cmd.append(str(context))
        cmd.append(str(bake_image.ref))
        builder_job_id = await self._run_builder_job(*cmd)
        await image_storage.update(
            builder_job_id=builder_job_id,
            status=ImageStatus.BUILDING,
        )
        self._progress.log("Image", fmt_id(bake_image.ref), "is", ImageStatus.BUILDING)
        self._progress.log(" builder job:", fmt_raw_id(builder_job_id))

    async def _process_running_builds(self) -> None:
        """Poll BUILDING images and mark them BUILT or BUILD_FAILED."""
        log.debug("BatchExecutor: checking running build")
        async for image in self._bake_storage.list_bake_images():
            log.debug(f"BatchExecutor: processing image {image}")
            image_storage = self._bake_storage.bake_image(id=image.id)
            if image.status == ImageStatus.BUILDING:
                assert image.builder_job_id
                descr = await self._client.job_status(image.builder_job_id)
                log.debug(f"BatchExecutor: got job description {descr}")
                if descr.status == JobStatus.SUCCEEDED:
                    await image_storage.update(status=ImageStatus.BUILT)
                    self._progress.log("Image", fmt_id(image.ref), "is", ImageStatus.BUILT)
                elif descr.status in TERMINATED_JOB_STATUSES:
                    await image_storage.update(status=ImageStatus.BUILD_FAILED)
                    self._progress.log(
                        "Image", fmt_id(image.ref), "is", ImageStatus.BUILD_FAILED
                    )

    async def _create_project_role(self, project_role: str) -> None:
        """Create *project_role* once; tolerate missing rights and duplicates."""
        if self._is_projet_role_created:
            return
        log.debug(f"BatchExecutor: creating project role {project_role}")
        try:
            await self._client.user_add(project_role)
        except AuthorizationError:
            log.debug(
                f"BatchExecutor: AuthorizationError for create"
                f" project role {project_role}"
            )
            # We have no permissions to create role --
            # assume that this is shared project and
            # current user is not the owner
        except IllegalArgumentError as e:
            if "already exists" not in str(e):
                raise
        self._is_projet_role_created = True

    async def _add_resource(self, uri: URL) -> None:
        """Share *uri* with the project role (WRITE), if a role is configured."""
        log.debug(f"BatchExecutor: adding resource {uri}")
        project_role = self._project_role
        if project_role is None:
            return
        await self._create_project_role(project_role)
        permission = Permission(uri, Action.WRITE)
        try:
            await self._client.user_share(project_role, permission)
        except ValueError:
            self._progress.log(f"[red]Cannot share [b]{uri!s}[/b]")
def processor(
    site_data: dict,
    timeout: int,
    use_proxy: bool,
    result_printer: ResultPrinter,
    task_id: TaskID,
    progress: Progress,
) -> dict:
    """
    The main processor for mori: query the API described by ``site_data``.

    The request is retried up to 5 times while ``get_response`` reports an
    ``error_text``; a hidden "(retry)" task in ``progress`` is made visible
    and advanced on each retry.

    Args:
        site_data: a dictionary from site_data_list, read from the json file.
        timeout: time in seconds to wait before timing out the request.
        use_proxy: do not use a proxy when this value is False. The proxy
            field in the config file must also have a value for a proxy to
            be used.
        result_printer: when the processor finishes, ``result_printer`` is
            invoked to output the result.
        task_id: id of the main progress task; when the processor finishes,
            that task is advanced by 1.
        progress: the main progress display; may be falsy to disable it.

    Returns:
        Dictionary containing results for the report.
            'name': api name in configuration,
            'url': api url in configuration,
            'base_url': api base_url in configuration,
            'resp_text': raw response.text from url; if it is 500 characters
                or longer it won't be displayed on the console, and you can
                add --xls to see the details in the *.xls file.
            'status_code': response status_code,
            'time(s)': time in seconds spent on the request,
            'error_text': error_text,
            'exception_text': exception_text (not set on the error path),
            'check_result': 'OK', 'Damage' or 'Unknown'. Default is 'Unknown'.
            'traceback': instance of Traceback.
            'check_results': check_results, the result of each regex.
            'remark': free-form remark field.
    """
    rel_result, result, monitor_id = {}, {}, None
    session = requests.Session()
    max_retries = 5
    if progress:
        # Hidden per-site task, shown only once a retry actually happens.
        monitor_id = progress.add_task(
            f'{site_data["name"]} (retry)', visible=False, total=max_retries
        )
    check_result, check_results = "Damage", []
    r, resp_text = None, ""
    try:
        for retries in range(max_retries):
            traceback, r, resp_text = None, None, ""
            error_text, exception_text, check_results = "", "", {}
            check_result = "Unknown"
            headers = {
                "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.12; "
                "rv:55.0) Gecko/20100101 Firefox/55.0",
            }
            if isinstance(site_data.get("headers"), dict):
                headers.update(site_data["headers"])
            Proxy.set_proxy_url(
                site_data.get("proxy"),
                site_data.get("strict_proxy"),
                use_proxy,
                headers,
            )
            if site_data.get("antispider") and site_data.get("data"):
                try:
                    import importlib

                    package = importlib.import_module(
                        "antispider." + site_data["antispider"]
                    )
                    Antispider = getattr(package, "Antispider")
                    site_data["data"], headers = Antispider(
                        site_data["data"], headers
                    ).processor()
                except Exception as _e:
                    # NOTE(review): these flags persist on the caller's
                    # site_data dict beyond this call.
                    site_data["single"] = True
                    site_data["exception_text"] = _e
                    site_data["traceback"] = Traceback()
                    raise Exception("antispider failed") from _e
            try:
                proxies = Proxy.get_proxy()
            except Exception as _e:
                site_data["single"] = True
                site_data["exception_text"] = _e
                site_data["traceback"] = Traceback()
                raise Exception("all of six proxies can`t be used") from _e
            if site_data.get("single"):
                error_text = site_data["error_text"]
                exception_text = site_data["exception_text"]
                traceback = site_data["traceback"]
            else:
                (
                    r,
                    error_text,
                    exception_text,
                    check_results,
                    check_result,
                    traceback,
                    resp_text,
                ) = get_response(session, site_data, headers, timeout, proxies)
            if error_text and retries + 1 < max_retries:
                if progress:
                    progress.update(
                        monitor_id, advance=1, visible=True, refresh=True
                    )
                continue
            # A requests Response is falsy for 4xx/5xx, so test its type
            # instead of its truthiness.
            has_response = isinstance(r, requests.Response)
            result = {
                "name": site_data["name"],
                "url": site_data["url"],
                "base_url": site_data.get("base_url", ""),
                "resp_text": resp_text
                if len(resp_text) < 500
                else "too long, and you can add --xls "
                "to see detail in *.xls file",
                "status_code": getattr(r, "status_code", "failed"),
                "time(s)": float(r.elapsed.total_seconds())
                if has_response
                else -1.0,
                "error_text": str(error_text),
                "exception_text": exception_text,
                "check_result": check_result,
                "traceback": traceback,
                "check_results": check_results,
                "remark": site_data.get("remark", ""),
            }
            rel_result = dict(result)
            # The untruncated text goes to the returned copy only.
            rel_result["resp_text"] = resp_text
            break
    except Exception as error:
        has_response = isinstance(r, requests.Response)
        result = {
            "name": site_data["name"],
            "url": site_data["url"],
            "base_url": site_data.get("base_url", ""),
            "resp_text": resp_text
            if len(resp_text) < 500
            else "too long, and you can add --xls "
            "to see detail in *.xls file",
            "status_code": getattr(r, "status_code", "failed")
            if has_response
            else "none",
            "time(s)": float(r.elapsed.total_seconds()) if has_response else -1.0,
            "error_text": str(error) or "site handler error",
            "check_result": check_result,
            "traceback": Traceback(),
            "check_results": check_results,
            "remark": site_data.get("remark", ""),
        }
        rel_result = dict(result)
    finally:
        session.close()
    if result_printer:
        if progress:
            progress.update(task_id, advance=1, refresh=True)
        result_printer.printer(result)
    if progress and monitor_id is not None:
        progress.remove_task(monitor_id)
    return rel_result
def fit(epochs: int, lr: float, model: nn.Module, loss_fn, opt, train_dl, val_dl,
        debug_run: bool = False, debug_num_batches: int = 5):
    """Train ``model`` for ``epochs`` epochs while showing rich progress bars.

    Each epoch runs one training pass over ``train_dl`` and one evaluation
    pass over ``val_dl`` (under ``torch.no_grad``). It then prints one row
    with these columns: epoch number, loss of the last training batch, mean
    validation loss, epoch time, train time and validation time.

    Args:
        epochs: number of epochs to run.
        lr: learning rate of the SGD optimizer created here.
        model: the network; it is moved to the module-level ``device``.
        loss_fn: loss function, passed through to ``loss_batch``.
        opt: NOTE(review): ignored. An SGD optimizer (momentum 0.9) is always
            built from ``model.parameters()`` and ``lr``.
        train_dl: training loader (must support ``len()`` and iteration).
        val_dl: validation loader (must support ``len()`` and iteration).
        debug_run: if True, only the first ``debug_num_batches`` batches of
            each loader are used.
        debug_num_batches: batch limit used when ``debug_run`` is True.
    """
    model.to(device)
    # Kept as-is: the passed ``opt`` is replaced (see NOTE(review) above).
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    header = [
        'epoch', 'train_loss', 'val loss', 'time', 'train time', 'val_time'
    ]
    progress_bar = Progress()
    progress_bar.start()
    try:
        progress_bar.print(' '.join(f"{value:>9}" for value in header))
        main_job = progress_bar.add_task('fit...', total=epochs)
        for epoch in range(epochs):
            progress_bar.update(main_job,
                                description=f"ep {epoch + 1} of {epochs}")

            # ---- training pass ----
            model.train()
            start_time = time.time()
            len_train_dl = len(train_dl)
            train_job = progress_bar.add_task('train', total=len_train_dl)
            loss = None  # loss of the last training batch (None: no batch ran)
            for batch_num, batch in enumerate(train_dl):
                progress_bar.update(
                    train_job, description=f"batch {batch_num}/{len_train_dl}")
                if debug_run and batch_num == debug_num_batches:
                    break
                loss = loss_batch(model, loss_fn, batch)
                loss.backward()
                opt.step()
                opt.zero_grad()
                progress_bar.update(train_job, advance=1)
            train_time = time.time() - start_time

            # ---- validation pass ----
            model.eval()
            val_job = progress_bar.add_task('validate...', total=len(val_dl))
            valid_losses = []
            with torch.no_grad():
                for batch_num, batch in enumerate(val_dl):
                    if debug_run and batch_num == debug_num_batches:
                        break
                    valid_losses.append(loss_batch(model, loss_fn, batch).item())
                    progress_bar.update(val_job, advance=1)
            # Mean over the batches actually evaluated (NaN if none ran).
            valid_loss = (sum(valid_losses) / len(valid_losses)
                          if valid_losses else float("nan"))
            epoch_time = time.time() - start_time

            train_loss = loss.item() if loss is not None else float("nan")
            row = [
                f"{epoch + 1}",
                f"{train_loss:0.4f}",
                f"{valid_loss:0.4f}",
                f"{format_time(epoch_time)}",
                f"{format_time(train_time)}",
                f"{format_time(epoch_time - train_time)}",
            ]
            progress_bar.print(' '.join(f"{value:>9}" for value in row))
            progress_bar.update(main_job, advance=1)
            progress_bar.remove_task(train_job)
            progress_bar.remove_task(val_job)
        progress_bar.remove_task(main_job)
    finally:
        # Always restore the terminal, even if training raises.
        progress_bar.stop()
def do_one_ping(
    host,
    progress: Progress,
    cc: Optional[str] = None,
    seq_offset=0,
    count=8,
    interval=0.1,
    timeout=2,
    id=PID,
    source=None,
    **kwargs,
) -> PingResult:
    """
    Send ``count`` ICMP echo requests to ``host`` and collect round-trip times.

    A progress task named after ``host`` is shown while pinging. It is
    removed afterwards, even if an error occurs. Failed replies (any
    ``ICMPLibError``) are logged and are simply absent from ``times``.

    :param host: hostname, FQDN or IP address; only the first resolved
        address is used.
    :param progress: rich progress display that receives a per-host task.
    :param cc: passed through unchanged to the returned ``PingResult``.
    :param seq_offset: added to every ICMP sequence number.
    :param count: number of echo requests to send.
    :param interval: seconds to wait between consecutive requests.
    :param timeout: seconds to wait for each reply.
    :param id: ICMP identifier used for the requests.
    :param source: source address the socket is bound to.
    :param kwargs: extra keyword arguments forwarded to ``ICMPRequest``.
    :returns: a ``PingResult``. Its ``times`` holds
        ``(reply.time - request.time) * 1000`` per successful reply
        (milliseconds, assuming icmplib timestamps are in seconds).
    :raises NameLookupError: If you pass a hostname or FQDN in parameters
        and it does not exist or cannot be resolved.
    :raises SocketPermissionError: If the privileges are insufficient to
        create the socket.
    :raises SocketAddressError: If the source address cannot be assigned
        to the socket.
    :raises ICMPSocketError: If another error occurs. See the
        `ICMPv4Socket` or `ICMPv6Socket` class for details.
    """
    address = resolve(host)
    if isinstance(address, list):
        address = address[0]
    log = logging.getLogger("rich")
    # on Linux `privileged` must be True
    if is_ipv6_address(address):
        sock = ICMPv6Socket(address=source, privileged=True)
    else:
        sock = ICMPv4Socket(address=source, privileged=True)
    times = []
    try:
        # Add the task only once the socket exists, so a socket error
        # cannot leave a stale progress bar behind.
        task_id = progress.add_task(host, total=count)
        try:
            for sequence in range(count):
                # Wait between requests even after a failed reply, so errors
                # do not turn into a burst of back-to-back pings.
                if sequence > 0:
                    sleep(interval)
                progress.update(task_id, advance=1)
                request = ICMPRequest(
                    destination=address,
                    id=id,
                    sequence=sequence + seq_offset,
                    **kwargs,
                )
                try:
                    sock.send(request)
                    reply = sock.receive(request, timeout)
                    reply.raise_for_status()
                    times.append((reply.time - request.time) * 1000)
                except ICMPLibError as e:
                    log.error(f"接收 {host} Ping 返回信息失败: {e}")
        finally:
            progress.remove_task(task_id=task_id)
    finally:
        sock.close()
    log.info(f"{host} Ping 检测已经完成")
    return PingResult(host=host, cc=cc, count=count, times=times)
class DeviantArtDownloader:
    """Download DeviantArt deviations for a tag, with rich progress bars.

    Each download gets its own progress task. An extra "All" task
    aggregates the bytes of all downloads.
    """

    def __init__(self, client_id, client_secret):
        import threading

        self.api = Api(client_id, client_secret)
        self.progress = Progress(
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            DownloadColumn(),
            TransferSpeedColumn(),
            "[bold blue]{task.fields[filename]}",
        )
        # Aggregate task; its total grows as each worker learns its size.
        self.all_t = self.progress.add_task('All', filename='All', start=False)
        self.total_length = 0
        # download_worker runs on pool threads; this guards total_length.
        self._total_lock = threading.Lock()

    def download_worker(self, task_id, url, path):
        """Stream ``url`` into ``path``, advancing ``task_id`` and "All".

        Runs on a worker thread. Returns ``task_id`` so the caller can
        remove the finished task.
        """
        with open(path, 'wb') as f, GET(url, stream=True) as rq:
            # 0 when the server sends no Content-Length header.
            length = int(rq.headers.get('Content-Length', 0))
            self.progress.start_task(task_id)
            self.progress.update(task_id, total=length)
            # Read-modify-write of the shared total must not interleave
            # between threads, or totals get lost or applied out of order.
            with self._total_lock:
                self.total_length += length
                self.progress.update(self.all_t, total=self.total_length)
            for chunk in rq.iter_content(chunk_size=4096):
                f.write(chunk)
                self.progress.update(task_id, advance=len(chunk))
                self.progress.update(self.all_t, advance=len(chunk))
        return task_id

    def search_content(self, tag, max_items=-1):
        """Yield deviations tagged ``tag``, following the API's pagination.

        At most ``max_items`` items are yielded. A value <= 0 means no limit.
        """
        n_items = 0
        offset = 0
        while True:
            data = self.api.browse('tags', tag=tag, offset=offset)
            for item in data['results']:
                yield item
                n_items += 1
                if 0 < max_items <= n_items:
                    return
            if not data['has_more']:
                break
            offset = data['next_offset']

    @staticmethod
    def _make_filename(item):
        """Build a file name: last path part of the page URL + source extension."""
        src = item.content['src']
        ext = splitext(urlparse(src).path)[1]
        return splitpath(item.url)[1] + ext

    def download(self, tag, out_dir='.', max_items=-1, max_workers=8,
                 list_path=None):
        """Download up to ``max_items`` deviations tagged ``tag`` into ``out_dir``.

        At most ``max_workers`` downloads run at once. If ``list_path`` is
        given, every deviation URL found is appended to it, including ones
        without downloadable content. An exception raised by a download is
        re-raised here once that download is reaped.
        """
        from concurrent.futures import FIRST_COMPLETED, wait

        if not exists(out_dir):
            mkdir(out_dir)
        with self.progress, ThreadPoolExecutor(max_workers=max_workers) as pool:
            self.progress.start_task(self.all_t)
            pending = {}  # in-flight future -> its progress task id

            def reap(finished):
                # Drop the progress task of each finished download, then
                # surface its exception (if any).
                for future in finished:
                    self.progress.remove_task(pending.pop(future))
                    future.result()

            for item in self.search_content(tag, max_items):
                if list_path:
                    with open(list_path, 'a') as flist:
                        flist.write(item.url + '\n')
                if not item.content:
                    continue
                filename = join(out_dir, self._make_filename(item))
                task_id = self.progress.add_task(
                    'download', filename=item.title, start=False)
                future = pool.submit(self.download_worker, task_id,
                                     item.content['src'], filename)
                pending[future] = task_id
                if len(pending) >= max_workers:
                    # Block until at least one slot frees up.
                    done, _ = wait(pending, return_when=FIRST_COMPLETED)
                    reap(done)
            # Reap the downloads that are still in flight.
            reap(wait(pending).done)
def video_thread_func(
    device: torch.device,
    num_lock: int,
    multi_gpu: bool,
    input: Path,
    start_frame: int,
    end_frame: int,
    num_frames: int,
    progress: Progress,
    task_upscaled_id: TaskID,
    ai_upscaled_path: Path,
    fps: int,
    quality: float,
    ffmpeg_params: str,
    deinterpaint: DeinterpaintOptions,
    diff_mode: bool,
    ssim: bool,
    min_ssim: float,
    chunk_size: int,
    padding_size: int,
    scale: int,
    upscale: Upscale,
    config: configparser.ConfigParser,
    scenes_ini: Path,
):
    """Upscale frames ``start_frame``..``end_frame`` (inclusive) of ``input`` into one scene video.

    The scene is written to ``ai_upscaled_path/<start>_<end>.mp4``, with
    frame numbers zero-padded to the width of ``num_frames``. Its section in
    ``config`` gets ``upscaled``, ``duplicated_frames`` and ``average_fps``,
    and ``config`` is then saved to ``scenes_ini``.

    Modes:

    * normal (``diff_mode`` False): if ``are_same_imgs`` says a frame equals
      the previous one, the previous upscaled frame is reused instead of
      upscaling again. Such runs are logged as duplicated frames.
    * diff mode: frames are compared to a base frame with ``get_diff_frame``.
      Changed regions smaller than the full frame are queued; only those
      regions are later upscaled and pasted onto the upscaled base
      (``get_frame``). A frame whose diff region is not smaller than the
      full frame flushes the queue and becomes the new base.

    Every written frame advances ``task_upscaled_id`` and a per-scene
    progress task; the per-scene task is removed on exit.

    When ``multi_gpu`` is True, ``upscale.devices[device][num_lock]`` is
    released on exit, including when an error occurs. Presumably this is a
    lock the caller acquired for this thread (NOTE(review): confirm).
    """
    log = logging.getLogger()
    pad_width = len(str(num_frames))

    def frame_str(idx: int) -> str:
        # Frame numbers are zero-padded to the width of num_frames.
        return str(idx).zfill(pad_width)

    start_frame_str = frame_str(start_frame)
    end_frame_str = frame_str(end_frame)
    scene_key = f"{start_frame_str}_{end_frame_str}"

    video_reader = None
    task_scene_id = None
    try:
        video_reader = imageio.get_reader(str(input.absolute()))
        # Wall-clock timer: process_time() would sum the CPU time of every
        # thread in the process and exclude time spent sleeping/blocked.
        start_time = time.perf_counter()
        video_reader.set_image_index(start_frame - 1)

        task_scene_desc = f'Scene [green]"{start_frame_str}_{end_frame_str}"[/]'
        if multi_gpu and len(upscale.devices) > 1:
            if device.type == "cuda":
                device_name = torch.cuda.get_device_name(device.index)
            else:
                device_name = "CPU"
            task_scene_desc += f" ({device_name})"
        task_scene_id = progress.add_task(
            description=task_scene_desc,
            total=end_frame - start_frame + 1,
            completed=0,
            refresh=True,
        )

        video_writer_params = {"quality": quality, "macro_block_size": None}
        if ffmpeg_params:
            if "-crf" in ffmpeg_params:
                # Drop imageio's quality when -crf is given (presumably both
                # control the encoder quality).
                del video_writer_params["quality"]
            video_writer_params["output_params"] = ffmpeg_params.split()
        video_writer = imageio.get_writer(
            str(ai_upscaled_path.joinpath(f"{scene_key}.mp4").absolute()),
            fps=fps,
            **video_writer_params,
        )

        def upscale_image(img):
            return upscale.image(img, device, multi_gpu_release_device=False)

        def write_frame(frame_ai) -> None:
            # Append one upscaled frame and advance both progress tasks.
            video_writer.append_data(frame_ai)
            progress.advance(task_upscaled_id)
            progress.advance(task_scene_id)

        def flush_diff_group(base_frame, diffs) -> None:
            # Upscale and write base_frame, then write every queued frame,
            # rebuilt from the upscaled base. A None entry means "identical
            # to the base"; otherwise only the diff region is upscaled and
            # pasted onto the upscaled base.
            base_ai = upscale_image(base_frame)
            write_frame(base_ai)
            for frame_diff in diffs:
                if frame_diff is None:
                    frame_ai = base_ai
                else:
                    frame_diff.frame = upscale_image(frame_diff.frame)
                    frame_ai = get_frame(
                        base_ai, frame_diff, scale, chunk_size, padding_size)
                write_frame(frame_ai)

        def log_duplicates(first_idx: int, last_idx: int, count: int) -> None:
            log.info(
                f"Detected {count} duplicated frame{'' if count==1 else 's'} ({frame_str(first_idx)}-{frame_str(last_idx)})"
            )

        last_frame = None
        last_frame_ai = None
        current_frame = None  # diff mode: base frame of the pending group
        frames_diff: List[Optional[FrameDiff]] = []
        duplicated_frames = 0
        total_duplicated_frames = 0
        start_duplicated_frame = start_frame
        try:
            for current_frame_idx in range(start_frame, end_frame + 1):
                frame = video_reader.get_next_data()
                if deinterpaint is not None:
                    # Paint every other row (even or odd ones) solid green.
                    first_row = 0 if deinterpaint == DeinterpaintOptions.even else 1
                    for i in range(first_row, frame.shape[0], 2):
                        frame[i:i + 1] = (0, 255, 0)  # (B, G, R)

                if not diff_mode:
                    if last_frame is not None and are_same_imgs(
                            last_frame, frame, ssim, min_ssim):
                        # Same as the previous frame: reuse its upscaled version.
                        frame_ai = last_frame_ai
                        if duplicated_frames == 0:
                            start_duplicated_frame = current_frame_idx - 1
                        duplicated_frames += 1
                    else:
                        frame_ai = upscale_image(frame)
                        if duplicated_frames != 0:
                            log_duplicates(start_duplicated_frame,
                                           current_frame_idx - 1,
                                           duplicated_frames)
                            total_duplicated_frames += duplicated_frames
                            duplicated_frames = 0
                    write_frame(frame_ai)
                    last_frame = frame
                    last_frame_ai = frame_ai
                elif current_frame is None:
                    current_frame = frame
                else:
                    frame_diff = get_diff_frame(current_frame, frame, chunk_size,
                                                padding_size, ssim, min_ssim)
                    if frame_diff is None:
                        # the frame is equal to current_frame, the best scenario!!!
                        frames_diff.append(frame_diff)
                    else:
                        h_diff, w_diff, _ = frame_diff.frame.shape
                        h, w, _ = current_frame.shape
                        if w * h > w_diff * h_diff:
                            # TODO difference of size > 20%
                            frames_diff.append(frame_diff)
                        else:
                            # The diff is as large as the frame: flush the
                            # group and start a new one based on this frame.
                            flush_diff_group(current_frame, frames_diff)
                            current_frame = frame
                            frames_diff = []

            if diff_mode and current_frame is not None:
                # Flush the last pending group (frames_diff may be empty).
                flush_diff_group(current_frame, frames_diff)
            if duplicated_frames != 0:
                # The trailing run of duplicates extends to the scene's last frame.
                log_duplicates(start_duplicated_frame, end_frame, duplicated_frames)
                total_duplicated_frames += duplicated_frames
        finally:
            video_writer.close()

        task_scene = next(task for task in progress.tasks if task.id == task_scene_id)
        config.set(scene_key, "upscaled", "True")
        config.set(scene_key, "duplicated_frames", f"{total_duplicated_frames}")
        finished_speed = task_scene.finished_speed or task_scene.speed or 0.01
        config.set(scene_key, "average_fps", f"{finished_speed:.2f}")
        with open(scenes_ini, "w") as configfile:
            config.write(configfile)
        elapsed = dt.timedelta(seconds=time.perf_counter() - start_time)
        log.info(
            f"Frames from {start_frame_str} to {end_frame_str} upscaled in {precisedelta(elapsed)}"
        )
        if total_duplicated_frames > 0:
            total_frames = end_frame - (start_frame - 1)
            # Average time per upscaled frame (0.04 s assumed per duplicated
            # frame), multiplied by the number of duplicates.
            seconds_saved = (
                ((1 / finished_speed * total_frames) - (total_duplicated_frames * 0.04))
                / (total_frames - total_duplicated_frames)
                * total_duplicated_frames
            )
            log.info(
                f"Total number of duplicated frames from {start_frame_str} to {end_frame_str}: {total_duplicated_frames} (saved ≈ {precisedelta(dt.timedelta(seconds=seconds_saved))})"
            )
    finally:
        # Release everything even on failure; an unreleased device lock
        # would block other threads waiting for this device.
        if video_reader is not None:
            video_reader.close()
        if task_scene_id is not None:
            progress.remove_task(task_scene_id)
        if multi_gpu:
            upscale.devices[device][num_lock].release()