Example #1
import io

from rich.console import Console
from rich.progress import Progress


def make_progress() -> Progress:
    _time = 0.0

    def fake_time():
        nonlocal _time
        try:
            return _time
        finally:
            _time += 1

    console = Console(
        file=io.StringIO(),
        force_terminal=True,
        color_system="truecolor",
        width=80,
        legacy_windows=False,
    )
    progress = Progress(console=console, get_time=fake_time, auto_refresh=False)
    task1 = progress.add_task("foo")
    task2 = progress.add_task("bar", total=30)
    progress.advance(task2, 16)
    task3 = progress.add_task("baz", visible=False)
    task4 = progress.add_task("egg")
    progress.remove_task(task4)
    task4 = progress.add_task("foo2", completed=50, start=False)
    progress.stop_task(task4)
    progress.start_task(task4)
    progress.update(
        task4, total=200, advance=50, completed=200, visible=True, refresh=True
    )
    progress.stop_task(task4)
    return progress
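A fixture like this is typically exercised by rendering once and then inspecting the captured output. A minimal sketch, assuming only the public rich API (Progress.refresh, Console.file):

progress = make_progress()
with progress:
    progress.refresh()  # draw all tasks once into the StringIO console
output = progress.console.file.getvalue()  # captured ANSI text
assert "foo" in output and "bar" in output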
Example #2
import io

from rich.console import Console
from rich.progress import Progress


def make_progress() -> Progress:
    console = Console(file=io.StringIO(), force_terminal=True)
    progress = Progress(console=console)
    task1 = progress.add_task("foo")
    task2 = progress.add_task("bar", 30)
    progress.advance(task2, 16)
    task3 = progress.add_task("baz", visible=False)
    task4 = progress.add_task("egg")
    progress.remove_task(task4)
    task4 = progress.add_task("foo2", completed=50, start=False)
    progress.start_task(task4)
    progress.update(
        task4, total=200, advance=50, completed=200, visible=True, refresh=True
    )
    return progress
Example #3
    def _bar(
        self,
        progress: Progress,
        description: str,
        total: Optional[float],
    ) -> Iterator[ProgressBar]:
        if total is None:
            # Indeterminate progress bar
            taskid = progress.add_task(description, start=False)
        else:
            taskid = progress.add_task(description, total=total)
        self._update_live()

        try:
            yield ProgressBar(progress, taskid)
        finally:
            progress.remove_task(taskid)
            self._update_live()
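The try/yield/finally shape strongly suggests this method is wrapped with contextlib.contextmanager elsewhere in the class. A self-contained sketch of the same pattern, outside any class (managed_task is an illustrative name, not the project's API):

from contextlib import contextmanager
from typing import Iterator, Optional

from rich.progress import Progress, TaskID


@contextmanager
def managed_task(progress: Progress, description: str,
                 total: Optional[float] = None) -> Iterator[TaskID]:
    """Add a task on entry and remove it on exit, even if the body raises."""
    if total is None:
        # Indeterminate progress bar, as in the snippet above
        task_id = progress.add_task(description, start=False)
    else:
        task_id = progress.add_task(description, total=total)
    try:
        yield task_id
    finally:
        progress.remove_task(task_id)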
Example #4
def run():
    """
    Run the Console Interface for Unsilence
    :return: None
    """
    sys.tracebacklimit = 0

    args = parse_arguments()
    console = Console()

    if args.debug:
        sys.tracebacklimit = 1000

    if args.output_file.exists() and not args.non_interactive_mode:
        if not choice_dialog(
                console, "File already exists. Overwrite?", default=False):
            return

    args_dict = vars(args)

    argument_list_for_silence_detect = [
        "silence_level", "silence_time_threshold", "short_interval_threshold",
        "stretch_time"
    ]

    argument_dict_for_silence_detect = {
        key: args_dict[key]
        for key in argument_list_for_silence_detect if key in args_dict.keys()
    }

    argument_list_for_renderer = [
        "audio_only", "audible_speed", "silent_speed", "audible_volume",
        "silent_volume", "drop_corrupted_intervals", "threads"
    ]

    argument_dict_for_renderer = {
        key: args_dict[key]
        for key in argument_list_for_renderer if key in args_dict.keys()
    }

    progress = Progress()

    continual = Unsilence(args.input_file)

    with progress:

        def update_task(current_task):
            def handler(current_val, total):
                progress.update(current_task,
                                total=total,
                                completed=current_val)

            return handler

        silence_detect_task = progress.add_task("Calculating Intervals...",
                                                total=1)

        continual.detect_silence(
            on_silence_detect_progress_update=update_task(silence_detect_task),
            **argument_dict_for_silence_detect)

        progress.stop()
        progress.remove_task(silence_detect_task)

        print()

        estimated_time = continual.estimate_time(args.audible_speed,
                                                 args.silent_speed)
        console.print(pretty_time_estimate(estimated_time))

        print()

        if not args.non_interactive_mode:
            if not choice_dialog(
                    console, "Continue with these options?", default=True):
                return

        progress.start()
        rendering_task = progress.add_task("Rendering Intervals...", total=1)
        concat_task = progress.add_task("Combining Intervals...", total=1)

        continual.render_media(
            args.output_file,
            on_render_progress_update=update_task(rendering_task),
            on_concat_progress_update=update_task(concat_task),
            **argument_dict_for_renderer)

        progress.stop()

    console.print("\n[green]Finished![/green] :tada:")
    print()
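The update_task() closure above is the adapter between unsilence's (current_val, total) callbacks and Progress.update. A stripped-down sketch of the same pattern, driving the handler by hand instead of from the library (names are illustrative):

from rich.progress import Progress

def make_handler(progress: Progress, task_id):
    def handler(current_val, total):
        progress.update(task_id, total=total, completed=current_val)
    return handler

with Progress() as progress:
    task = progress.add_task("Working...", total=1)
    handler = make_handler(progress, task)
    for current in range(101):
        handler(current, 100)  # as the library would invoke it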
Example #5
class RichProgressBar(ProgressBarBase):
    def __init__(self):
        super().__init__()
        self.pg = Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.completed}/{task.total}",
            BarDictColumn(),
            transient=True,
            expand=True,
            refresh_per_second=2,
        )
        self._train_id = self._val_id = self._test_id = None

    def disable(self):
        self.pg.disable = True

    def enable(self):
        self.pg.disable = False

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        self.pg.start()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                          dataloader_idx):
        super().on_test_batch_end(trainer, pl_module, outputs, batch,
                                  batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._test_id, completed=self.test_batch_idx, **bardic)
        if self.test_batch_idx >= self.total_test_batches:
            self.pg.stop_task(self._test_id)

    def on_test_epoch_start(self, trainer, pl_module) -> None:
        super().on_test_epoch_start(trainer, pl_module)
        if self._test_id is not None:
            self.pg.remove_task(self._test_id)
        self._test_id = self.pg.add_task("epoch T%03d" % trainer.current_epoch,
                                         total=self.total_test_batches)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                           dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch,
                                   batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._train_id,
                       completed=self.train_batch_idx,
                       **bardic)
        if self.train_batch_idx >= self.total_train_batches:
            self.pg.stop_task(self._train_id)

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        if self._train_id is not None:
            self.pg.remove_task(self._train_id)
        if self._val_id is not None:
            self.pg.remove_task(self._val_id)
        self._train_id = self.pg.add_task("epoch T%03d" %
                                          trainer.current_epoch,
                                          total=self.total_train_batches)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch,
                                        batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._val_id, completed=self.val_batch_idx, **bardic)
        if self.val_batch_idx >= self.total_val_batches:
            self.pg.stop_task(self._val_id)

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        super().on_validation_epoch_start(trainer, pl_module)
        self._val_id = self.pg.add_task("epoch V%03d" % trainer.current_epoch,
                                        total=self.total_val_batches)

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        self.pg.stop()

    def print(self, *args, **kwargs):
        self.pg.console.print(*args, **kwargs)
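BarDictColumn is not part of rich, so it is presumably a custom column that renders the extra **bardic fields passed to update(). A minimal sketch of such a column, assuming only the documented ProgressColumn.render(task) hook and Task.fields:

from rich.progress import ProgressColumn, Task
from rich.text import Text


class BarDictColumn(ProgressColumn):
    """Render a task's extra fields (e.g. Lightning's bar dict) as key=value."""

    def render(self, task: Task) -> Text:
        return Text(" ".join(f"{k}={v}" for k, v in task.fields.items()))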
Example #6
class RichOutput(AbstractChecksecOutput):
    def __init__(self, libc_detected: bool = False):
        """Init Rich Console and Table"""
        super().__init__(libc_detected)
        # init ELF table
        self.table_elf = Table(title="Checksec Results: ELF", expand=True)
        self.table_elf.add_column("File", justify="left", header_style="")
        self.table_elf.add_column("NX", justify="center")
        self.table_elf.add_column("PIE", justify="center")
        self.table_elf.add_column("Canary", justify="center")
        self.table_elf.add_column("Relro", justify="center")
        self.table_elf.add_column("RPATH", justify="center")
        self.table_elf.add_column("RUNPATH", justify="center")
        self.table_elf.add_column("Symbols", justify="center")
        if self._libc_detected:
            self.table_elf.add_column("FORTIFY", justify="center")
            self.table_elf.add_column("Fortified", justify="center")
            self.table_elf.add_column("Fortifiable", justify="center")
            self.table_elf.add_column("Fortify Score", justify="center")

        # init PE table
        self.table_pe = Table(title="Checksec Results: PE", expand=True)
        self.table_pe.add_column("File", justify="left", header_style="")
        self.table_pe.add_column("NX", justify="center")
        self.table_pe.add_column("Canary", justify="center")
        self.table_pe.add_column("ASLR", justify="center")
        self.table_pe.add_column("Dynamic Base", justify="center")
        self.table_pe.add_column("High Entropy VA", justify="center")
        self.table_pe.add_column("SEH", justify="center")
        self.table_pe.add_column("SafeSEH", justify="center")
        self.table_pe.add_column("Force Integrity", justify="center")
        self.table_pe.add_column("Control Flow Guard", justify="center")
        self.table_pe.add_column("Isolation", justify="center")
        self.table_pe.add_column("Authenticode", justify="center")

        # init console
        self.console = Console()

        # build progress bar
        self.process_bar = Progress(
            TextColumn("[bold blue]Processing...", justify="left"),
            BarColumn(bar_width=None),
            "{task.completed}/{task.total}",
            "•",
            "[progress.percentage]{task.percentage:>3.1f}%",
            console=self.console,
        )
        self.display_res_bar = Progress(
            BarColumn(bar_width=None),
            TextColumn("[bold blue]{task.description}", justify="center"),
            BarColumn(bar_width=None),
            console=self.console,
            transient=True,
        )
        self.enumerate_bar = Progress(
            TextColumn("[bold blue]Enumerating...", justify="center"),
            BarColumn(bar_width=None),
            console=self.console,
            transient=True,
        )

        self.process_task_id = None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # cleanup the Rich progress bars
        if self.enumerate_bar is not None:
            self.enumerate_bar.stop()
        if self.process_bar is not None:
            self.process_bar.stop()
        if self.display_res_bar is not None:
            self.display_res_bar.stop()

    def enumerating_tasks_start(self):
        # start progress bar
        self.enumerate_bar.start()
        self.enumerate_bar.add_task("Enumerating", start=False)

    def enumerating_tasks_stop(self, total: int):
        super().enumerating_tasks_stop(total)
        self.enumerate_bar.stop()
        self.enumerate_bar = None

    def processing_tasks_start(self):
        # init progress bar
        self.process_bar.start()
        self.process_task_id = self.process_bar.add_task("Checking",
                                                         total=self.total)

    def add_checksec_result(self, filepath: Path,
                            checksec: Union[ELFChecksecData, PEChecksecData]):
        logging.debug("result for %s: %s", filepath, checksec)
        if isinstance(checksec, ELFChecksecData):
            row_res: List[str] = []
            # display results
            if not checksec.nx:
                nx_res = "[red]No"
            else:
                nx_res = "[green]Yes"
            row_res.append(nx_res)

            pie = checksec.pie
            if pie == PIEType.No:
                pie_res = f"[red]{pie.name}"
            elif pie == PIEType.DSO:
                pie_res = f"[yellow]{pie.name}"
            else:
                pie_res = "[green]Yes"
            row_res.append(pie_res)

            if not checksec.canary:
                canary_res = "[red]No"
            else:
                canary_res = "[green]Yes"
            row_res.append(canary_res)

            relro = checksec.relro
            if relro == RelroType.No:
                relro_res = f"[red]{relro.name}"
            elif relro == RelroType.Partial:
                relro_res = f"[yellow]{relro.name}"
            else:
                relro_res = f"[green]{relro.name}"
            row_res.append(relro_res)

            if checksec.rpath:
                rpath_res = "[red]Yes"
            else:
                rpath_res = "[green]No"
            row_res.append(rpath_res)

            if checksec.runpath:
                runpath_res = "[red]Yes"
            else:
                runpath_res = "[green]No"
            row_res.append(runpath_res)

            if checksec.symbols:
                symbols_res = "[red]Yes"
            else:
                symbols_res = "[green]No"
            row_res.append(symbols_res)

            # fortify results depend on having a Libc available
            if self._libc_detected:
                fortified_count = checksec.fortified
                if checksec.fortify_source:
                    fortify_source_res = "[green]Yes"
                else:
                    fortify_source_res = "[red]No"
                row_res.append(fortify_source_res)

                if fortified_count == 0:
                    fortified_res = "[red]No"
                else:
                    fortified_res = f"[green]{fortified_count}"
                row_res.append(fortified_res)

                fortifiable_count = checksec.fortifiable
                if fortifiable_count == 0:
                    fortifiable_res = "[red]No"
                else:
                    fortifiable_res = f"[green]{fortifiable_count}"
                row_res.append(fortifiable_res)

                if checksec.fortify_score == 0:
                    fortified_score_res = f"[red]{checksec.fortify_score}"
                elif checksec.fortify_score == 100:
                    fortified_score_res = f"[green]{checksec.fortify_score}"
                else:
                    fortified_score_res = f"[yellow]{checksec.fortify_score}"
                row_res.append(fortified_score_res)

            self.table_elf.add_row(str(filepath), *row_res)
        elif isinstance(checksec, PEChecksecData):
            if not checksec.nx:
                nx_res = "[red]No"
            else:
                nx_res = "[green]Yes"

            if not checksec.canary:
                canary_res = "[red]No"
            else:
                canary_res = "[green]Yes"

            if not checksec.aslr:
                aslr_res = "[red]No"
            else:
                aslr_res = "[green]Yes"

            if not checksec.dynamic_base:
                dynamic_base_res = "[red]No"
            else:
                dynamic_base_res = "[green]Yes"

            # this is only relevant if the binary is 64 bits
            if checksec.machine == MACHINE_TYPES.AMD64:
                if not checksec.high_entropy_va:
                    entropy_va_res = "[red]No"
                else:
                    entropy_va_res = "[green]Yes"
            else:
                entropy_va_res = "/"

            if not checksec.seh:
                seh_res = "[red]No"
            else:
                seh_res = "[green]Yes"

            # only relevant if 32 bits
            if checksec.machine == MACHINE_TYPES.I386:
                if not checksec.safe_seh:
                    safe_seh_res = "[red]No"
                else:
                    safe_seh_res = "[green]Yes"
            else:
                safe_seh_res = "/"

            if not checksec.authenticode:
                auth_res = "[red]No"
            else:
                auth_res = "[green]Yes"

            if not checksec.force_integrity:
                force_integrity_res = "[red]No"
            else:
                force_integrity_res = "[green]Yes"

            if not checksec.guard_cf:
                guard_cf_res = "[red]No"
            else:
                guard_cf_res = "[green]Yes"

            if not checksec.isolation:
                isolation_res = "[red]No"
            else:
                isolation_res = "[green]Yes"

            self.table_pe.add_row(
                str(filepath),
                nx_res,
                canary_res,
                aslr_res,
                dynamic_base_res,
                entropy_va_res,
                seh_res,
                safe_seh_res,
                force_integrity_res,
                guard_cf_res,
                isolation_res,
                auth_res,
            )
        else:
            raise NotImplementedError

    def checksec_result_end(self):
        """Update progress bar"""
        self.process_bar.update(self.process_task_id, advance=1)

    def print(self):
        self.process_bar.stop()
        self.process_bar = None

        if self.table_elf.row_count > 0:
            with self.display_res_bar:
                task_id = self.display_res_bar.add_task(
                    "Displaying Results: ELF ...", start=False)
                self.console.print(self.table_elf)
                self.display_res_bar.remove_task(task_id)
        if self.table_pe.row_count > 0:
            with self.display_res_bar:
                task_id = self.display_res_bar.add_task(
                    "Displaying Results: PE ...", start=False)
                self.console.print(self.table_pe)
                self.display_res_bar.remove_task(task_id)

        self.display_res_bar.stop()
        self.display_res_bar = None
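The long runs of if/else blocks in add_checksec_result all follow one pattern: color a Yes/No answer green when it is the safe value and red otherwise. A refactoring sketch (not part of the original project) that would collapse them:

def colorize(flag: bool, good_when: bool = True) -> str:
    """Green when the flag matches the desired value, red otherwise."""
    color = "green" if flag is good_when else "red"
    return f"[{color}]{'Yes' if flag else 'No'}"

# e.g. nx_res = colorize(checksec.nx)                         # Yes is good
#      rpath_res = colorize(checksec.rpath, good_when=False)  # Yes is bad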
Example #7
class BatchExecutor:
    transient_progress: bool = False
    run_builder_job: Callable[...,
                              Awaitable[str]] = staticmethod(  # type: ignore
                                  start_image_build)

    def __init__(
        self,
        console: Console,
        flow: RunningBatchFlow,
        bake: Bake,
        attempt: Attempt,
        client: RetryReadNeuroClient,
        storage: AttemptStorage,
        bake_storage: BakeStorage,
        project_storage: ProjectStorage,
        *,
        polling_timeout: Optional[float] = 1,
        project_role: Optional[str] = None,
    ) -> None:
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TextColumn("[progress.remaining]{task.elapsed:.0f} sec"),
            console=console,
            auto_refresh=False,
            transient=self.transient_progress,
            redirect_stderr=True,
        )
        self._top_flow = flow
        self._bake = bake
        self._attempt = attempt
        self._client = client
        self._storage = storage
        self._project_storage = project_storage
        self._bake_storage = bake_storage
        self._graphs: Graph[RunningBatchBase[BaseBatchContext]] = Graph(
            flow.graph, flow)
        self._tasks_mgr = BakeTasksManager()
        self._is_cancelling = False
        self._project_role = project_role
        self._is_project_role_created = False

        self._run_builder_job = self.run_builder_job

        # A note about the default value:
        # AS: I have no idea what timeout is better;
        # too short a value bombards the servers,
        # too long a timeout makes the waiting longer than expected.
        # The missing events subsystem would be great for this task :)
        self._polling_timeout = polling_timeout
        self._bars: Dict[FullID, TaskID] = {}

    @classmethod
    @asynccontextmanager
    async def create(
        cls,
        console: Console,
        bake_id: str,
        client: Client,
        storage: Storage,
        *,
        polling_timeout: Optional[float] = 1,
        project_role: Optional[str] = None,
    ) -> AsyncIterator["BatchExecutor"]:
        storage = storage.with_retry_read()

        console.log("Fetch bake data")
        bake_storage = storage.bake(id=bake_id)
        bake = await bake_storage.get()
        if bake.last_attempt:
            attempt = bake.last_attempt
            attempt_storage = bake_storage.attempt(id=bake.last_attempt.id)
        else:
            attempt_storage = bake_storage.last_attempt()
            attempt = await attempt_storage.get()

        console.log("Fetch configs metadata")
        flow = await get_running_flow(bake, client, bake_storage,
                                      attempt.configs_meta)

        console.log(f"Execute attempt #{attempt.number}")

        ret = cls(
            console=console,
            flow=flow,
            bake=bake,
            attempt=attempt,
            client=RetryReadNeuroClient(client),
            storage=attempt_storage,
            bake_storage=bake_storage,
            project_storage=storage.project(id=bake.project_id),
            polling_timeout=polling_timeout,
            project_role=project_role,
        )
        ret._start()
        try:
            yield ret
        finally:
            ret._stop()

    def _start(self) -> None:
        pass

    def _stop(self) -> None:
        pass

    async def _refresh_attempt(self) -> None:
        self._attempt = await self._storage.get()

    def _only_completed_needs(
            self, needs: Mapping[str, ast.NeedsLevel]) -> AbstractSet[str]:
        return {
            task_id
            for task_id, level in needs.items()
            if level == ast.NeedsLevel.COMPLETED
        }

    # Exact task/action contexts helpers
    async def _get_meta(self, full_id: FullID) -> TaskMeta:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(
            prefix, self._only_completed_needs(flow.graph[tid]))
        state = self._tasks_mgr.build_state(prefix, await flow.state_from(tid))
        return await flow.get_meta(tid, needs=needs, state=state)

    async def _get_task(self, full_id: FullID) -> Task:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(
            prefix, self._only_completed_needs(flow.graph[tid]))
        state = self._tasks_mgr.build_state(prefix, await flow.state_from(tid))
        return await flow.get_task(prefix, tid, needs=needs, state=state)

    async def _get_local(self, full_id: FullID) -> LocalTask:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(
            prefix, self._only_completed_needs(flow.graph[tid]))
        return await flow.get_local(tid, needs=needs)

    async def _get_action(self, full_id: FullID) -> RunningBatchActionFlow:
        prefix, tid = full_id[:-1], full_id[-1]
        flow = self._graphs.get_meta(full_id)
        needs = self._tasks_mgr.build_needs(
            prefix, self._only_completed_needs(flow.graph[tid]))
        return await flow.get_action(tid, needs=needs)

    async def _get_image(self, bake_image: BakeImage) -> Optional[ImageCtx]:
        actions = []
        full_id_to_image: Dict[FullID, ImageCtx] = {}
        for yaml_def in bake_image.yaml_defs:
            *prefix, image_id = yaml_def
            actions.append(".".join(prefix))
            try:
                flow = self._graphs.get_graph_data(tuple(prefix))
            except KeyError:
                pass
            else:
                full_id_to_image[yaml_def] = flow.images[image_id]
        # Pick any context; they are expected to be identical
        image_ctx = next(iter(full_id_to_image.values()), None)
        if not all(image_ctx == it for it in full_id_to_image.values()):
            locations_str = "\n".join(f" - {'.'.join(full_id)}"
                                      for full_id in full_id_to_image.keys())
            self._progress.log(
                f"[red]Warning:[/red] definitions for image with"
                f" ref [b]{bake_image.ref}[/b] differ:\n"
                f"{locations_str}")
        if not full_id_to_image:
            actions_str = ", ".join(f"'{it}'" for it in actions)
            self._progress.log(
                f"[red]Warning:[/red] image with ref "
                f"[b]{bake_image.ref}[/b] is referred " +
                (f"before action {actions_str} "
                 "that holds its definition was able to run" if len(actions) ==
                 1 else f"before actions {actions_str} "
                 "that hold its definition were able to run"))
        return image_ctx

    # Graph helpers

    async def _embed_action(self, full_id: FullID) -> None:
        action = await self._get_action(full_id)
        self._graphs.embed(full_id, action.graph, action)
        task_id = self._progress.add_task(
            ".".join(full_id),
            completed=0,
            total=self._graphs.get_size(full_id) * 3,
        )
        self._bars[full_id] = task_id

    def _advance_progress_bar(self, old_task: Optional[StorageTask],
                              task: StorageTask) -> None:
        def _state_to_progress(status: TaskStatus) -> int:
            if status.is_pending:
                return 1
            if status.is_running:
                return 2
            return 3

        if old_task:
            advance = _state_to_progress(task.status) - _state_to_progress(
                old_task.status)
        else:
            advance = _state_to_progress(task.status)
        progress_bar_id = self._bars[task.yaml_id[:-1]]
        self._progress.update(progress_bar_id, advance=advance, refresh=True)

    def _update_graph_coloring(self, task: StorageTask) -> None:
        if task.status.is_running or task.status.is_finished:
            self._graphs.mark_running(task.yaml_id)
        if task.status.is_finished:
            self._graphs.mark_done(task.yaml_id)

    async def _log_task_status_change(self, task: StorageTask) -> None:
        flow = self._graphs.get_meta(task.yaml_id)
        in_flow_id = task.yaml_id[-1]
        if await flow.is_action(in_flow_id):
            log_args = ["Action"]
        elif await flow.is_local(in_flow_id):
            log_args = ["Local action"]
        else:
            log_args = ["Task"]
        log_args.append(fmt_id(task.yaml_id))
        if task.raw_id:
            log_args.append(fmt_raw_id(task.raw_id))
        log_args += ["is", fmt_status(task.status)]
        if task.status.is_finished:
            if task.outputs:
                log_args += ["with following outputs:"]
            else:
                log_args += ["(no outputs)"]
        self._progress.log(*log_args)
        if task.outputs:
            for key, value in task.outputs.items():
                self._progress.log(f"  {key}: {value}")

    async def _create_task(
        self,
        yaml_id: FullID,
        raw_id: Optional[str],
        status: Union[TaskStatusItem, TaskStatus],
        outputs: Optional[Mapping[str, str]] = None,
        state: Optional[Mapping[str, str]] = None,
    ) -> StorageTask:
        task = await self._storage.create_task(yaml_id, raw_id, status,
                                               outputs, state)
        self._advance_progress_bar(None, task)
        self._update_graph_coloring(task)
        await self._log_task_status_change(task)
        self._tasks_mgr.update(task)
        return task

    async def _update_task(
        self,
        yaml_id: FullID,
        *,
        outputs: Union[Optional[Mapping[str, str]], Type[_Unset]] = _Unset,
        state: Union[Optional[Mapping[str, str]], Type[_Unset]] = _Unset,
        new_status: Optional[Union[TaskStatusItem, TaskStatus]] = None,
    ) -> StorageTask:
        old_task = self._tasks_mgr.tasks[yaml_id]
        task = await self._storage.task(id=old_task.id).update(
            outputs=outputs,
            state=state,
            new_status=new_status,
        )
        self._advance_progress_bar(old_task, task)
        self._update_graph_coloring(task)
        await self._log_task_status_change(task)
        self._tasks_mgr.update(task)
        return task

    # New task processors

    async def _skip_task(self, full_id: FullID) -> None:
        log.debug(f"BatchExecutor: skipping task {full_id}")
        await self._create_task(
            yaml_id=full_id,
            status=TaskStatus.SKIPPED,
            raw_id=None,
            outputs={},
            state={},
        )

    async def _process_local(self, full_id: FullID) -> None:
        raise ValueError("Processing of local actions is not supported")

    async def _process_action(self, full_id: FullID) -> None:
        log.debug(f"BatchExecutor: processing action {full_id}")
        await self._create_task(
            yaml_id=full_id,
            status=TaskStatus.PENDING,
            raw_id=None,
        )
        await self._embed_action(full_id)

    async def _process_task(self, full_id: FullID) -> None:
        log.debug(f"BatchExecutor: processing task {full_id}")
        task = await self._get_task(full_id)

        # Check whether the task fits max_parallel
        for n in range(1, len(full_id) + 1):
            node = full_id[:n]
            prefix = node[:-1]
            if self._tasks_mgr.count_unfinished(
                    prefix) >= task.strategy.max_parallel:
                return

        cache_strategy = task.cache.strategy
        storage_task: Optional[StorageTask] = None
        if cache_strategy == ast.CacheStrategy.DEFAULT:
            try:
                cache_entry = await self._project_storage.cache_entry(
                    task_id=full_id,
                    batch=self._bake.batch,
                    key=task.caching_key).get()
            except ResourceNotFound:
                pass
            else:
                now = datetime.now(timezone.utc)
                eol = cache_entry.created_at + timedelta(task.cache.life_span)
                if now < eol:
                    storage_task = await self._create_task(
                        yaml_id=full_id,
                        raw_id=cache_entry.raw_id,
                        status=TaskStatus.CACHED,
                        outputs=cache_entry.outputs,
                        state=cache_entry.state,
                    )

        if storage_task is None:
            await self._start_task(full_id, task)

    async def _load_previous_run(self) -> None:
        log.debug(f"BatchExecutor: loading previous run")
        # Loads tasks that previous executor run processed.
        # Do not handles fail fast option.

        root_id = ()
        task_id = self._progress.add_task(
            f"<{self._bake.batch}>",
            completed=0,
            total=self._graphs.get_size(root_id) * 3,
        )
        self._bars[root_id] = task_id

        tasks = {
            task.yaml_id: task
            async for task in self._storage.list_tasks()
        }

        def _set_task_info(task: StorageTask) -> None:
            self._update_graph_coloring(task)
            self._advance_progress_bar(None, task)
            self._tasks_mgr.update(task)

        # Rebuild data in graph and task_mgr
        while tasks.keys() != self._tasks_mgr.tasks.keys():
            log.debug(f"BatchExecutor: rebuilding data...")
            for full_id, ctx in self._graphs.get_ready_with_meta():
                if full_id not in tasks or full_id in self._tasks_mgr.tasks:
                    continue
                task = tasks[full_id]
                if await ctx.is_action(full_id[-1]):
                    await self._embed_action(full_id)
                    if task.status.is_pending:
                        _set_task_info(task)
                else:
                    _set_task_info(task)
            for full_id in self._graphs.get_ready_to_mark_running_embeds():
                task = tasks[full_id]
                if task.status.is_finished:
                    self._graphs.mark_running(full_id)
                elif task.status.is_running:
                    _set_task_info(task)
            for full_id in self._graphs.get_ready_to_mark_done_embeds():
                task = tasks[full_id]
                if task.status.is_finished:
                    _set_task_info(task)

    async def _should_continue(self) -> bool:
        return self._graphs.is_active

    async def run(self) -> TaskStatus:
        with self._progress:
            try:
                result = await self._run()
            except (KeyboardInterrupt, asyncio.CancelledError):
                await self._cancel_unfinished()
                result = TaskStatus.CANCELLED
            except Exception as exp:
                await self.log_error(exp)
                await self._cancel_unfinished()
                result = TaskStatus.FAILED
        await self._finish_run(result)
        return result

    async def log_error(self, exp: Exception) -> None:
        if isinstance(exp, EvalError):
            self._progress.log(f"[red][b]ERROR:[/b] {exp}[/red]")
        else:
            # Looks like this is some bug in our code, so we should print
            # the stacktrace for easier debugging
            self._progress.console.print_exception(show_locals=True)
            self._progress.log(
                "[red][b]ERROR:[/b] Some unknown error happened. Please report "
                "an issue to https://github.com/neuro-inc/neuro-flow/issues/new "
                "with traceback printed above.[/red]")

    async def _run(self) -> TaskStatus:
        await self._load_previous_run()

        job_id = os.environ.get("NEURO_JOB_ID")
        if job_id:
            await self._storage.update(executor_id=job_id, )

        if self._attempt.result.is_finished:
            # If attempt is already terminated, just clean up tasks
            # and exit
            await self._cancel_unfinished()
            return self._attempt.result

        await self._storage.update(result=TaskStatus.RUNNING, )

        while await self._should_continue():
            _mgr = self._tasks_mgr

            def _fmt_debug(task: StorageTask) -> str:
                return ".".join(task.yaml_id) + f" - {task.status}"

            log.debug(
                f"BatchExecutor: main loop start:\n"
                f"tasks={', '.join(_fmt_debug(it) for it in _mgr.tasks.values())}\n"
            )
            for full_id, flow in self._graphs.get_ready_with_meta():
                tid = full_id[-1]
                if full_id in self._tasks_mgr.tasks:
                    continue  # Already started
                meta = await self._get_meta(full_id)
                if not meta.enable or (self._is_cancelling
                                       and meta.enable is not AlwaysT()):
                    # Mark the task started and immediately skipped
                    await self._skip_task(full_id)
                    continue
                if await flow.is_local(tid):
                    await self._process_local(full_id)
                elif await flow.is_action(tid):
                    await self._process_action(full_id)
                else:
                    await self._process_task(full_id)

            ok = await self._process_started()
            await self._process_running_builds()

            # Check for cancellation
            if not self._is_cancelling:
                await self._refresh_attempt()

                if not ok or self._attempt.result == TaskStatus.CANCELLED:
                    await self._cancel_unfinished()

            await asyncio.sleep(self._polling_timeout or 0)

        return self._accumulate_result()

    async def _finish_run(self, attempt_status: TaskStatus) -> None:
        if not self._attempt.result.is_finished:
            await self._storage.update(result=attempt_status)
        self._progress.print(
            Panel(
                f"[b]Attempt #{self._attempt.number}[/b] {fmt_status(attempt_status)}"
            ),
            justify="center",
        )
        if attempt_status != TaskStatus.SUCCEEDED:
            self._progress.log(
                "[blue b]Hint:[/blue b] you can restart bake starting "
                "from first failed task "
                "by the following command:\n"
                f"[b]neuro-flow restart {self._bake.id}[/b]")

    async def _cancel_unfinished(self) -> None:
        self._is_cancelling = True
        for task, raw_id in self._tasks_mgr.list_unfinished_raw_tasks():
            task_ctx = await self._get_task(task.yaml_id)
            if task_ctx.enable is not AlwaysT():
                self._progress.log(
                    f"Task {fmt_id(task.yaml_id)} is being killed")
                await self._client.job_kill(raw_id)
        async for image in self._bake_storage.list_bake_images():
            if image.status == ImageStatus.BUILDING:
                assert image.builder_job_id
                await self._client.job_kill(image.builder_job_id)

    async def _store_to_cache(self, task: StorageTask) -> None:
        log.debug(f"BatchExecutor: storing to cache {task.yaml_id}")
        task_ctx = await self._get_task(task.yaml_id)
        cache_strategy = task_ctx.cache.strategy
        if (cache_strategy == ast.CacheStrategy.NONE
                or task.status != TaskStatus.SUCCEEDED or task.raw_id is None
                or task.outputs is None or task.state is None):
            return

        await self._project_storage.create_cache_entry(
            key=task_ctx.caching_key,
            batch=self._bake.batch,
            task_id=task.yaml_id,
            raw_id=task.raw_id,
            outputs=task.outputs,
            state=task.state,
        )

    async def _process_started(self) -> bool:
        log.debug(f"BatchExecutor: processing started")
        # Process tasks
        for task, raw_id in self._tasks_mgr.list_unfinished_raw_tasks():
            assert task.raw_id, "list_unfinished_raw_tasks should list raw tasks"
            job_descr = await self._client.job_status(task.raw_id)
            log.debug(
                f"BatchExecutor: got description {job_descr} for task {task.yaml_id}"
            )
            if (job_descr.status in {JobStatus.RUNNING, JobStatus.SUCCEEDED}
                    and task.status.is_pending):
                assert (job_descr.history.started_at
                        ), "RUNNING jobs should have started_at"
                task = await self._update_task(
                    task.yaml_id,
                    new_status=TaskStatusItem(
                        when=job_descr.history.started_at,
                        status=TaskStatus.RUNNING,
                    ),
                )
            if job_descr.status in TERMINATED_JOB_STATUSES:
                log.debug(
                    f"BatchExecutor: processing logs for task {task.yaml_id}")
                async with CmdProcessor() as proc:
                    async for chunk in self._client.job_logs(raw_id):
                        async for line in proc.feed_chunk(chunk):
                            pass
                    async for line in proc.feed_eof():
                        pass
                log.debug(
                    f"BatchExecutor: finished processing logs for task {task.yaml_id}"
                )
                assert (job_descr.history.finished_at
                        ), "TERMINATED jobs should have finished_at"
                task = await self._update_task(
                    task.yaml_id,
                    new_status=TaskStatusItem(
                        when=job_descr.history.finished_at,
                        status=TaskStatus(job_descr.status),
                    ),
                    outputs=proc.outputs,
                    state=proc.states,
                )
                await self._store_to_cache(task)
                task_meta = await self._get_meta(task.yaml_id)
                if task.status != TaskStatus.SUCCEEDED and task_meta.strategy.fail_fast:
                    return False
        # Process batch actions
        for full_id in self._graphs.get_ready_to_mark_running_embeds():
            log.debug(f"BatchExecutor: marking action {full_id} as ready")
            await self._update_task(full_id, new_status=TaskStatus.RUNNING)

        for full_id in self._graphs.get_ready_to_mark_done_embeds():
            log.debug(f"BatchExecutor: marking action {full_id} as done")
            # done action, make it finished
            ctx = await self._get_action(full_id)
            results = self._tasks_mgr.build_needs(full_id, ctx.graph.keys())
            res_ctx = await ctx.calc_outputs(results)

            task = await self._update_task(
                full_id,
                new_status=TaskStatus(res_ctx.result),
                outputs=res_ctx.outputs,
                state={},
            )

            task_id = self._bars[full_id]
            self._progress.remove_task(task_id)

            task_meta = await self._get_meta(full_id)
            if task.status != TaskStatus.SUCCEEDED and task_meta.strategy.fail_fast:
                return False
        return True

    def _accumulate_result(self) -> TaskStatus:
        for task in self._tasks_mgr.tasks.values():
            if task.status == TaskStatus.CANCELLED:
                return TaskStatus.CANCELLED
            elif task.status == TaskStatus.FAILED:
                return TaskStatus.FAILED

        return TaskStatus.SUCCEEDED

    async def _is_image_in_registry(self, remote_image: RemoteImage) -> bool:
        try:
            await self._client.image_tag_info(remote_image)
        except ResourceNotFound:
            return False
        return True

    async def _start_task(self, full_id: FullID,
                          task: Task) -> Optional[StorageTask]:
        log.debug(
            f"BatchExecutor: checking should we build image for {full_id}")
        remote_image = self._client.parse_remote_image(task.image)
        log.debug(f"BatchExecutor: image name is {remote_image}")
        if remote_image.cluster_name is None:  # Not a neuro registry image
            return await self._run_task(full_id, task)

        image_storage = self._bake_storage.bake_image(ref=task.image)
        try:
            bake_image = await image_storage.get()
        except ResourceNotFound:
            # Not defined in the bake
            return await self._run_task(full_id, task)

        image_ctx = await self._get_image(bake_image)
        if image_ctx is None:
            # The image is referred to before its definition
            return await self._run_task(full_id, task)

        if not image_ctx.force_rebuild and await self._is_image_in_registry(
                remote_image):
            if bake_image.status == ImageStatus.PENDING:
                await image_storage.update(status=ImageStatus.CACHED, )
            return await self._run_task(full_id, task)

        if bake_image.status == ImageStatus.PENDING:
            await self._start_image_build(bake_image, image_ctx, image_storage)
            return None  # wait for next pulling interval
        elif bake_image.status == ImageStatus.BUILDING:
            return None  # wait for next pulling interval
        else:
            # Image is already built (maybe with an error); just try to run the job
            return await self._run_task(full_id, task)

    async def _run_task(self, full_id: FullID, task: Task) -> StorageTask:
        log.debug(f"BatchExecutor: starting job for {full_id}")
        preset_name = task.preset
        if preset_name is None:
            preset_name = next(iter(self._client.config_presets))

        envs = self._client.parse_envs(
            [f"{k}={v}" for k, v in task.env.items()])

        volumes_parsed = self._client.parse_volumes(task.volumes)
        volumes = list(volumes_parsed.volumes)

        http_auth = task.http_auth
        if http_auth is None:
            http_auth = HTTPPort.requires_auth

        job = await self._client.job_start(
            shm=True,
            tty=False,
            image=self._client.parse_remote_image(task.image),
            preset_name=preset_name,
            entrypoint=task.entrypoint,
            command=task.cmd,
            http=HTTPPort(task.http_port, http_auth)
            if task.http_port else None,
            env=envs.env,
            secret_env=envs.secret_env,
            volumes=volumes,
            working_dir=str(task.workdir) if task.workdir else None,
            secret_files=list(volumes_parsed.secret_files),
            disk_volumes=list(volumes_parsed.disk_volumes),
            name=task.name,
            tags=list(task.tags),
            description=task.title,
            life_span=task.life_span,
            schedule_timeout=task.schedule_timeout,
            pass_config=bool(task.pass_config),
        )
        await self._add_resource(job.uri)
        return await self._create_task(
            yaml_id=full_id,
            raw_id=job.id,
            status=TaskStatusItem(
                when=job.history.created_at or datetime.now(timezone.utc),
                status=TaskStatus.PENDING,
            ),
        )

    async def _start_image_build(
        self,
        bake_image: BakeImage,
        image_ctx: ImageCtx,
        image_storage: BakeImageStorage,
    ) -> None:
        log.debug(f"BatchExecutor: starting image build for {bake_image}")
        context = bake_image.context_on_storage or image_ctx.context
        dockerfile_rel = bake_image.dockerfile_rel or image_ctx.dockerfile_rel

        if context is None or dockerfile_rel is None:
            if context is None and dockerfile_rel is None:
                error = "context and dockerfile not specified"
            elif context is None and dockerfile_rel is not None:
                error = "context not specified"
            else:
                error = "dockerfile not specified"
            raise Exception(
                f"Failed to build image '{bake_image.ref}': {error}")

        cmd = ["neuro-extras", "image", "build"]
        cmd.append(f"--file={dockerfile_rel}")
        for arg in image_ctx.build_args:
            cmd.append(f"--build-arg={arg}")
        for vol in image_ctx.volumes:
            cmd.append(f"--volume={vol}")
        for k, v in image_ctx.env.items():
            cmd.append(f"--env={k}={v}")
        if image_ctx.force_rebuild:
            cmd.append("--force-overwrite")
        if image_ctx.build_preset is not None:
            cmd.append(f"--preset={image_ctx.build_preset}")
        cmd.append(str(context))
        cmd.append(str(bake_image.ref))
        builder_job_id = await self._run_builder_job(*cmd)
        await image_storage.update(
            builder_job_id=builder_job_id,
            status=ImageStatus.BUILDING,
        )
        self._progress.log("Image", fmt_id(bake_image.ref), "is",
                           ImageStatus.BUILDING)
        self._progress.log("  builder job:", fmt_raw_id(builder_job_id))

    async def _process_running_builds(self) -> None:
        log.debug(f"BatchExecutor: checking running build")
        async for image in self._bake_storage.list_bake_images():
            log.debug(f"BatchExecutor: processing image {image}")
            image_storage = self._bake_storage.bake_image(id=image.id)
            if image.status == ImageStatus.BUILDING:
                assert image.builder_job_id
                descr = await self._client.job_status(image.builder_job_id)
                log.debug(f"BatchExecutor: got job description {descr}")
                if descr.status == JobStatus.SUCCEEDED:
                    await image_storage.update(status=ImageStatus.BUILT, )
                    self._progress.log("Image", fmt_id(image.ref), "is",
                                       ImageStatus.BUILT)
                elif descr.status in TERMINATED_JOB_STATUSES:
                    await image_storage.update(
                        status=ImageStatus.BUILD_FAILED, )
                    self._progress.log("Image", fmt_id(image.ref), "is",
                                       ImageStatus.BUILD_FAILED)

    async def _create_project_role(self, project_role: str) -> None:
        if self._is_project_role_created:
            return
        log.debug(f"BatchExecutor: creating project role {project_role}")
        try:
            await self._client.user_add(project_role)
        except AuthorizationError:
            log.debug(f"BatchExecutor: AuthorizationError for create"
                      f" project role {project_role}")
            pass
            # We have no permissions to create role --
            # assume that this is shared project and
            # current user is not the owner
        except IllegalArgumentError as e:
            if "already exists" not in str(e):
                raise
        self._is_project_role_created = True

    async def _add_resource(self, uri: URL) -> None:
        log.debug(f"BatchExecutor: adding resource {uri}")
        project_role = self._project_role
        if project_role is None:
            return
        await self._create_project_role(project_role)
        permission = Permission(uri, Action.WRITE)
        try:
            await self._client.user_share(project_role, permission)
        except ValueError:
            self._progress.log(f"[red]Cannot share [b]{uri!s}[/b]")
Example #8
File: mori.py Project: zxjlm/Mori
def processor(
    site_data: dict,
    timeout: int,
    use_proxy: bool,
    result_printer: ResultPrinter,
    task_id: TaskID,
    progress: Progress,
) -> dict:
    """
    the main processor for mori.
    Args:
        result_printer:
        site_data: a dictionary in site_data_list which read from json file
        timeout: Time in seconds to wait before timing out request
        use_proxy: not use proxy while this value is False. Of course, proxy
            field in the config file should have a value.
            result_printer: when processor finish , result_printer will be
            invoked to output result.
        task_id: it is id of main_progress, when processor finish,
            main_progress while step 1.
        progress: main progress.

    Returns:
        Dictionary containing results from report.
        'name': api name in configuration,
        'url': api url in configuration,
        'base_url': api base_url in configuration,
        'resp_text': raw response.text from url, if length of resp_text > 500,
            it wont`t display on console, and you can add --xls to see detail
            in *.xls file.
        'status_code': response status_code,
        'time(s)': time in seconds spend on request,
        'error_text': error_text,
        'exception_text': exception_text,
        'check_result': 'OK' , 'Damage' or 'Unknown'. Default is 'Unknown'.
        'traceback': Instance of Traceback.
        'check_results': check_results, each result of all regexes.
        'remark': field hold on
    """
    rel_result, result, monitor_id = {}, {}, None
    session = requests.Session()
    max_retries = 5
    if progress:
        monitor_id = progress.add_task(
            f'{site_data["name"]} (retry)', visible=False, total=max_retries
        )
    # progress.update(monitor_id, advance=-max_retries)
    check_result, check_results = "Damage", []
    r, resp_text = None, ""
    try:
        for retries in range(max_retries):
            check_result = "Damage"
            traceback, r, resp_text = None, None, ""
            error_text, exception_text, check_results = "", "", {}
            check_result = "Unknown"

            headers = {
                "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.12; "
                "rv:55.0) Gecko/20100101 Firefox/55.0",
            }

            if site_data.get("headers"):
                if isinstance(site_data.get("headers"), dict):
                    headers.update(site_data.get("headers"))

            Proxy.set_proxy_url(
                site_data.get("proxy"),
                site_data.get("strict_proxy"),
                use_proxy,
                headers,
            )

            if site_data.get("antispider") and site_data.get("data"):
                try:
                    import importlib

                    package = importlib.import_module(
                        "antispider." + site_data["antispider"]
                    )
                    Antispider = getattr(package, "Antispider")
                    site_data["data"], headers = Antispider(
                        site_data["data"], headers
                    ).processor()
                except Exception as _e:
                    site_data["single"] = True
                    site_data["exception_text"] = _e
                    site_data["traceback"] = Traceback()
                    raise Exception("antispider failed")

            try:
                proxies = Proxy.get_proxy()
            except Exception as _e:
                site_data["single"] = True
                site_data["exception_text"] = _e
                site_data["traceback"] = Traceback()
                raise Exception("all of six proxies can`t be used")

            if site_data.get("single"):
                error_text = site_data["error_text"]
                exception_text = site_data["exception_text"]
                traceback = site_data["traceback"]
            else:
                (
                    r,
                    error_text,
                    exception_text,
                    check_results,
                    check_result,
                    traceback,
                    resp_text,
                ) = get_response(session, site_data, headers, timeout, proxies)
                if error_text and retries + 1 < max_retries:
                    if progress:
                        progress.update(
                            monitor_id, advance=1, visible=True, refresh=True
                        )
                    continue

            result = {
                "name": site_data["name"],
                "url": site_data["url"],
                "base_url": site_data.get("base_url", ""),
                "resp_text": resp_text
                if len(resp_text) < 500
                else "too long, and you can add --xls " "to see detail in *.xls file",
                "status_code": getattr(r, "status_code", "failed"),
                "time(s)": float(r.elapsed.total_seconds()) if r else -1.0,
                "error_text": str(error_text),
                "exception_text": exception_text,
                "check_result": check_result,
                "traceback": traceback,
                "check_results": check_results,
                "remark": site_data.get("remark", ""),
            }

            rel_result = dict(result.copy())
            rel_result["resp_text"] = resp_text
            break
    except Exception as error:
        result = {
            "name": site_data["name"],
            "url": site_data["url"],
            "base_url": site_data.get("base_url", ""),
            "resp_text": resp_text
            if len(resp_text) < 500
            else "too long, and you can add --xls " "to see detail in *.xls file",
            "status_code": getattr(r, "status_code", "failed") if r else "none",
            "time(s)": float(r.elapsed.total_seconds()) if r else -1.0,
            "error_text": str(error) or "site handler error",
            "check_result": check_result,
            "traceback": Traceback(),
            "check_results": check_results,
            "remark": site_data.get("remark", ""),
        }
        rel_result = dict(result.copy())

    if result_printer:
        progress.update(task_id, advance=1, refresh=True)
        result_printer.printer(result)
        progress.remove_task(monitor_id)

    return rel_result
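The retry loop above drives a hidden per-site "(retry)" task that only becomes visible once a retry actually happens. A stripped-down sketch of that pattern (with_retries and attempt are illustrative names, not mori's API):

from rich.progress import Progress

def with_retries(progress: Progress, name: str, attempt, max_retries: int = 5):
    """Run attempt() up to max_retries times behind a hidden monitor task."""
    monitor = progress.add_task(f"{name} (retry)", visible=False,
                                total=max_retries)
    try:
        for i in range(max_retries):
            if attempt():
                return True
            if i + 1 < max_retries:
                # reveal the monitor bar only once a retry occurs
                progress.update(monitor, advance=1, visible=True, refresh=True)
        return False
    finally:
        progress.remove_task(monitor)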
Example #9
0
def fit(epochs: int,
        lr: float,
        model: nn.Module,
        loss_fn,
        opt,
        train_dl,
        val_dl,
        debug_run: bool = False,
        debug_num_batches: int = 5):
    model.to(device)
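    # NOTE: the optimizer passed in as `opt` is discarded and rebuilt here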
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # mb = master_bar(range(epochs))
    progress_bar = Progress()

    # mb.write(['epoch', 'train_loss', 'val loss', 'time', 'train time', 'val_time'], table=True)
    progress_bar.start()
    header = [
        'epoch', 'train_loss', 'val loss', 'time', 'train time', 'val_time'
    ]  # , 'val_time', 'train_time']
    progress_bar.print(' '.join([f"{value:>9}" for value in header]))
    main_job = progress_bar.add_task('fit...', total=epochs)
    # for epoch in mb:
    for epoch in range(epochs):
        progress_bar.update(main_job, description=f"ep {epoch + 1} of {epochs}")
        # mb.main_bar.comment = f"ep {epoch + 1} of {epochs}"

        model.train()
        start_time = time.time()
        # for batch_num, batch in enumerate(progress_bar(train_dl, parent=mb)):
        len_train_dl = len(train_dl)
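        # one transient task per epoch for the batch loop; removed at the end of the epoch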
        train_job = progress_bar.add_task('train', total=len_train_dl)
        for batch_num, batch in enumerate(train_dl):
            progress_bar.update(train_job, description=f"batch {batch_num}/{len_train_dl}")
            if debug_run and batch_num == debug_num_batches:
                break
            loss = loss_batch(model, loss_fn, batch)
            # mb.child.comment = f"loss {loss:0.4f}"
            loss.backward()
            opt.step()
            opt.zero_grad()
            progress_bar.update(train_job, advance=1)
        train_time = time.time() - start_time
        model.eval()
        len_val_dl = len(val_dl)
        val_job = progress_bar.add_task('validate...', total=len_val_dl)
        with torch.no_grad():
            valid_loss = []
            # for batch_num, batch in enumerate(progress_bar(val_dl, parent=mb)):
            for batch_num, batch in enumerate(val_dl):
                if debug_run and batch_num == debug_num_batches:
                    break
                valid_loss.append(loss_batch(model, loss_fn, batch).item())
                progress_bar.update(val_job, advance=1)  # advance once per batch
            valid_loss = sum(valid_loss) / len_val_dl  # mean, to match the 'val loss' header
        epoch_time = time.time() - start_time
        # mb.write([str(epoch + 1), f'{loss:0.4f}', f'{valid_loss / len(val_dl):0.4f}',
        #          format_time(epoch_time), format_time(train_time), format_time(epoch_time - train_time)], table=True)
        to_progress_bar = [
            f"{epoch + 1}", f"{loss:0.4f}", f"{valid_loss:0.4f}",
            f"{format_time(epoch_time)}", f"{format_time(train_time)}"
        ]  # , f"{format_time(val_time)}"]
        progress_bar.print(' '.join(
            [f"{value:>9}" for value in to_progress_bar]))
        progress_bar.update(main_job, advance=1)
        progress_bar.remove_task(train_job)
        progress_bar.remove_task(val_job)

    progress_bar.remove_task(main_job)
    progress_bar.stop()
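
The nested-bar pattern above in miniature, using only the public update()/remove_task() API (the loop bodies are placeholders, not the training code):

from rich.progress import Progress

with Progress() as progress:
    epochs, batches = 3, 10
    fit_task = progress.add_task("fit", total=epochs)
    for epoch in range(epochs):
        progress.update(fit_task, description=f"ep {epoch + 1} of {epochs}")
        train_task = progress.add_task("train", total=batches)
        for _ in range(batches):
            progress.update(train_task, advance=1)  # one tick per batch
        progress.update(fit_task, advance=1)
        progress.remove_task(train_task)  # drop the inner bar between epochs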
Example #10
0
def do_one_ping(
    host,
    progress: Progress,
    cc: Optional[str] = None,
    seq_offset=0,
    count=8,
    interval=0.1,
    timeout=2,
    id=PID,
    source=None,
    **kwargs,
) -> PingResult:
    """
    :raises NameLookupError: If you pass a hostname or FQDN in
        parameters and it does not exist or cannot be resolved.
    :raises SocketPermissionError: If the privileges are insufficient
        to create the socket.
    :raises SocketAddressError: If the source address cannot be
        assigned to the socket.
    :raises ICMPSocketError: If another error occurs. See the
        `ICMPv4Socket` or `ICMPv6Socket` class for details.
    """

    address = resolve(host)
    if isinstance(address, list):
        address = address[0]

    log = logging.getLogger("rich")

    task_id = progress.add_task(host, total=count)

    # on linux `privileged` must be True
    if is_ipv6_address(address):
        sock = ICMPv6Socket(address=source, privileged=True)
    else:
        sock = ICMPv4Socket(address=source, privileged=True)

    times = []

    for sequence in range(count):
        progress.update(task_id, advance=1)

        request = ICMPRequest(
            destination=address, id=id, sequence=sequence + seq_offset, **kwargs
        )

        try:
            sock.send(request)

            reply = sock.receive(request, timeout)
            reply.raise_for_status()

            round_trip_time = (reply.time - request.time) * 1000
            times.append(round_trip_time)

            if sequence < count - 1:
                sleep(interval)

        except ICMPLibError as e:
            log.error(f"接收 {host} Ping 返回信息失败: {e}")

    progress.remove_task(task_id=task_id)

    sock.close()

    log.info(f"{host} Ping 检测已经完成")

    return PingResult(host=host, cc=cc, count=count, times=times)
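
do_one_ping is built to share one Progress across hosts; a minimal sketch of that usage with threads (probe() is a hypothetical stand-in for the ICMP round trip; rich's Progress methods are safe to call from multiple threads):

import time
from concurrent.futures import ThreadPoolExecutor
from rich.progress import Progress

def probe():
    time.sleep(0.01)  # stand-in for sock.send()/sock.receive()

def ping_host(progress: Progress, host: str, count: int = 4):
    task_id = progress.add_task(host, total=count)
    for _ in range(count):
        probe()
        progress.update(task_id, advance=1)
    progress.remove_task(task_id)

with Progress() as progress:
    with ThreadPoolExecutor() as pool:
        for host in ("example.com", "example.org"):
            pool.submit(ping_host, progress, host)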
Example #11
0
class DeviantArtDownloader:
    def __init__(self, client_id, client_secret):
        self.api = Api(client_id, client_secret)
        self.progress = Progress(
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            DownloadColumn(),
            TransferSpeedColumn(),
            "[bold blue]{task.fields[filename]}",
        )
        self.all_t = self.progress.add_task('All', filename='All', start=False)
        self.total_length = 0

    def download_worker(self, task_id, url, path):
        with open(path, 'wb') as f, GET(url, stream=True) as rq:
            length = int(rq.headers.get('Content-Length', 0))
            self.progress.start_task(task_id)
            self.progress.update(task_id, total=length)
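            # grow the aggregate "All" bar's total as each file's size becomes known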
            self.total_length += length
            self.progress.update(self.all_t, total=self.total_length)
            for chunk in rq.iter_content(chunk_size=4096):
                f.write(chunk)
                self.progress.update(task_id, advance=len(chunk))
                self.progress.update(self.all_t, advance=len(chunk))
        return task_id
    
    def search_content(self, tag, max_items=-1):
        n_items = 0
        offset = 0
        while True:
            data = self.api.browse('tags', tag=tag, offset=offset)
            for item in data['results']:
                yield item
                n_items += 1
                if max_items > 0 and n_items >= max_items:
                    return
            if not data['has_more']:
                break
            offset = data['next_offset']
            
    @staticmethod
    def _make_filename(item):
        src = item.content['src']
        ext = splitext(urlparse(src).path)[1]
        return splitpath(item.url)[1] + ext

    def download(self,
                 tag,
                 out_dir='.',
                 max_items=-1,
                 max_workers=8,
                 list_path=None):
        if not exists(out_dir):
            mkdir(out_dir)
        with self.progress, ThreadPoolExecutor(max_workers=max_workers) as pool:
            self.progress.start_task(self.all_t)
            futures = []
            for item in self.search_content(tag, max_items):
                if list_path:
                    with open(list_path, 'a') as flist:
                        flist.write(item.url + '\n')
                if not item.content:
                    continue
                filename = join(out_dir, self._make_filename(item))
                task_id = self.progress.add_task(
                        'download',
                        filename=item.title,
                        start=False)

                url = item.content['src']
                f = pool.submit(self.download_worker, task_id, url, filename)
                futures.append(f)
                while len(futures) >= max_workers:
                    # iterate over a copy: removing from the list while iterating skips items
                    for f in list(futures):
                        if f.done():
                            futures.remove(f)
                            self.progress.remove_task(f.result())
                    sleep(0.1)
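
The deferred-total trick above, isolated: add the task unstarted, then start it and set its total once Content-Length is known (URL is illustrative; assumes the requests library):

import requests
from rich.progress import Progress, BarColumn, DownloadColumn, TransferSpeedColumn

with Progress(BarColumn(), DownloadColumn(), TransferSpeedColumn()) as progress:
    task_id = progress.add_task("download", start=False)
    with requests.get("https://example.com/file.bin", stream=True) as resp:
        total = int(resp.headers.get("Content-Length", 0))
        progress.start_task(task_id)  # timing begins once the size is known
        progress.update(task_id, total=total)
        for chunk in resp.iter_content(chunk_size=4096):
            progress.update(task_id, advance=len(chunk))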
Example #12
0
File: run.py Project: efonte/ESRGAN
def video_thread_func(
    device: torch.device,
    num_lock: int,
    multi_gpu: bool,
    input: Path,
    start_frame: int,
    end_frame: int,
    num_frames: int,
    progress: Progress,
    task_upscaled_id: TaskID,
    ai_upscaled_path: Path,
    fps: int,
    quality: float,
    ffmpeg_params: str,
    deinterpaint: DeinterpaintOptions,
    diff_mode: bool,
    ssim: bool,
    min_ssim: float,
    chunk_size: int,
    padding_size: int,
    scale: int,
    upscale: Upscale,
    config: configparser.ConfigParser,
    scenes_ini: Path,
):
    log = logging.getLogger()
    video_reader: FfmpegFormat.Reader = imageio.get_reader(
        str(input.absolute()))
    start_time = time.process_time()
    last_frame = None
    last_frame_ai = None
    current_frame = None
    frames_diff: List[Optional[FrameDiff]] = []
    video_reader.set_image_index(start_frame - 1)
    start_frame_str = str(start_frame).zfill(len(str(num_frames)))
    end_frame_str = str(end_frame).zfill(len(str(num_frames)))
    task_scene_desc = f'Scene [green]"{start_frame_str}_{end_frame_str}"[/]'
    if multi_gpu and len(upscale.devices) > 1:
        if device.type == "cuda":
            device_name = torch.cuda.get_device_name(device.index)
        else:
            device_name = "CPU"
        task_scene_desc += f" ({device_name})"
    task_scene_id = progress.add_task(
        description=task_scene_desc,
        total=end_frame - start_frame + 1,
        completed=0,
        refresh=True,
    )
    video_writer_params = {"quality": quality, "macro_block_size": None}
    if ffmpeg_params:
        if "-crf" in ffmpeg_params:
            del video_writer_params["quality"]
        video_writer_params["output_params"] = ffmpeg_params.split()
    video_writer: FfmpegFormat.Writer = imageio.get_writer(
        str(
            ai_upscaled_path.joinpath(
                f"{start_frame_str}_{end_frame_str}.mp4").absolute()),
        fps=fps,
        **video_writer_params,
    )
    duplicated_frames = 0
    total_duplicated_frames = 0
    for current_frame_idx in range(start_frame, end_frame + 1):
        frame = video_reader.get_next_data()

        if deinterpaint is not None:
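            # overwrite every other row (even or odd) of the frame with solid green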
            for i in range(
                    0 if deinterpaint == DeinterpaintOptions.even else 1,
                    frame.shape[0], 2):
                frame[i:i + 1] = (0, 255, 0)  # (B, G, R)

        if not diff_mode:
            if last_frame is not None and are_same_imgs(
                    last_frame, frame, ssim, min_ssim):
                frame_ai = last_frame_ai
                if duplicated_frames == 0:
                    start_duplicated_frame = current_frame_idx - 1
                duplicated_frames += 1
            else:
                frame_ai = upscale.image(frame,
                                         device,
                                         multi_gpu_release_device=False)
                if duplicated_frames != 0:
                    start_duplicated_frame_str = str(
                        start_duplicated_frame).zfill(len(str(num_frames)))
                    current_frame_idx_str = str(current_frame_idx - 1).zfill(
                        len(str(num_frames)))
                    log.info(
                        f"Detected {duplicated_frames} duplicated frame{'' if duplicated_frames==1 else 's'} ({start_duplicated_frame_str}-{current_frame_idx_str})"
                    )
                    total_duplicated_frames += duplicated_frames
                    duplicated_frames = 0
            video_writer.append_data(frame_ai)
            last_frame = frame
            last_frame_ai = frame_ai
            progress.advance(task_upscaled_id)
            progress.advance(task_scene_id)
        else:
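            # diff mode: buffer cheap per-frame diffs and only upscale a full frame
            # (flushing its queued diffs) once a diff grows too large to be worth keeping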
            if current_frame is None:
                current_frame = frame
            else:
                frame_diff = get_diff_frame(current_frame, frame, chunk_size,
                                            padding_size, ssim, min_ssim)
                if frame_diff is None:
                    # the frame is identical to current_frame: the best scenario
                    frames_diff.append(frame_diff)
                else:
                    h_diff, w_diff, c_diff = frame_diff.frame.shape
                    h, w, c = current_frame.shape
                    if w * h > w_diff * h_diff:  # TODO difference of size > 20%
                        frames_diff.append(frame_diff)
                    else:
                        current_frame_ai = upscale.image(
                            current_frame,
                            device,
                            multi_gpu_release_device=False)
                        video_writer.append_data(current_frame_ai)
                        progress.advance(task_upscaled_id)
                        progress.advance(task_scene_id)
                        current_frame = frame
                        for frame_diff in frames_diff:
                            if frame_diff is None:
                                frame_ai = current_frame_ai
                            else:
                                diff_ai = upscale.image(
                                    frame_diff.frame,
                                    device,
                                    multi_gpu_release_device=False,
                                )
                                frame_diff_ai = frame_diff
                                frame_diff_ai.frame = diff_ai
                                frame_ai = get_frame(
                                    current_frame_ai,
                                    frame_diff_ai,
                                    scale,
                                    chunk_size,
                                    padding_size,
                                )
                            video_writer.append_data(frame_ai)
                            progress.advance(task_upscaled_id)
                            progress.advance(task_scene_id)
                        frames_diff = []
    if diff_mode:
        if len(frames_diff) > 0:
            current_frame_ai = upscale.image(current_frame,
                                             device,
                                             multi_gpu_release_device=False)
            video_writer.append_data(current_frame_ai)
            progress.advance(task_upscaled_id)
            progress.advance(task_scene_id)
            for frame_diff in frames_diff:
                if frame_diff is None:
                    frame_ai = current_frame_ai  # use the upscaled frame, as in the loop above
                else:
                    diff_ai = upscale.image(frame_diff.frame,
                                            device,
                                            multi_gpu_release_device=False)
                    frame_diff_ai = frame_diff
                    frame_diff_ai.frame = diff_ai
                    frame_ai = get_frame(
                        current_frame_ai,
                        frame_diff_ai,
                        scale,
                        chunk_size,
                        padding_size,
                    )
                video_writer.append_data(frame_ai)
                progress.advance(task_upscaled_id)
                progress.advance(task_scene_id)
            current_frame = None
            frames_diff = []
        elif current_frame is not None:
            current_frame_ai = upscale.image(current_frame,
                                             device,
                                             multi_gpu_release_device=False)
            video_writer.append_data(current_frame_ai)
            progress.advance(task_upscaled_id)
            progress.advance(task_scene_id)
    if duplicated_frames != 0:
        start_duplicated_frame_str = str(start_duplicated_frame).zfill(
            len(str(num_frames)))
        current_frame_idx_str = str(current_frame_idx - 1).zfill(
            len(str(num_frames)))
        log.info(
            f"Detected {duplicated_frames} duplicated frame{'' if duplicated_frames==1 else 's'} ({start_duplicated_frame_str}-{current_frame_idx_str})"
        )
        total_duplicated_frames += duplicated_frames
        duplicated_frames = 0

    video_writer.close()
    task_scene = next(task for task in progress.tasks
                      if task.id == task_scene_id)

    config.set(f"{start_frame_str}_{end_frame_str}", "upscaled", "True")
    config.set(
        f"{start_frame_str}_{end_frame_str}",
        "duplicated_frames",
        f"{total_duplicated_frames}",
    )
    finished_speed = task_scene.finished_speed or task_scene.speed or 0.01
    config.set(
        f"{start_frame_str}_{end_frame_str}",
        "average_fps",
        f"{finished_speed:.2f}",
    )
    with open(scenes_ini, "w") as configfile:
        config.write(configfile)
    log.info(
        f"Frames from {str(start_frame).zfill(len(str(num_frames)))} to {str(end_frame).zfill(len(str(num_frames)))} upscaled in {precisedelta(dt.timedelta(seconds=time.process_time() - start_time))}"
    )
    if total_duplicated_frames > 0:
        total_frames = end_frame - (start_frame - 1)
        # assumes writing a duplicated frame costs ~0.04 seconds
        seconds_saved = (
            ((1 / finished_speed * total_frames) - total_duplicated_frames * 0.04)
            / (total_frames - total_duplicated_frames)
            * total_duplicated_frames
        )
        log.info(
            f"Total number of duplicated frames from {str(start_frame).zfill(len(str(num_frames)))} to {str(end_frame).zfill(len(str(num_frames)))}: {total_duplicated_frames} (saved ≈ {precisedelta(dt.timedelta(seconds=seconds_saved))})"
        )
    progress.remove_task(task_scene_id)
    if multi_gpu:
        upscale.devices[device][num_lock].release()
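
A minimal sketch of the speed bookkeeping used above: once a task finishes, rich freezes its rate in Task.finished_speed, which can be read back from progress.tasks.

from rich.progress import Progress

with Progress() as progress:
    task_id = progress.add_task("frames", total=100)
    for _ in range(100):
        progress.update(task_id, advance=1)
    task = next(t for t in progress.tasks if t.id == task_id)
    fps = task.finished_speed or task.speed or 0.0
    progress.console.print(f"average speed: {fps:.2f} steps/s")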