コード例 #1
0
ファイル: dvclive.py プロジェクト: www516717402/mmcv
class DvcliveLoggerHook(LoggerHook):
    """Class to log metrics with dvclive.

    It requires `dvclive`_ to be installed.

    Args:
        model_file (str): Default None. If not None, after each epoch the
            model will be saved to {model_file}.
        interval (int): Logging interval (every k iterations). Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
        kwargs: Arguments for instantiating `Live`_.

    .. _dvclive:
        https://dvc.org/doc/dvclive

    .. _Live:
        https://dvc.org/doc/dvclive/api-reference/live#parameters
    """

    def __init__(self,
                 model_file=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True,
                 **kwargs):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.model_file = model_file
        self.import_dvclive(**kwargs)

    def import_dvclive(self, **kwargs):
        """Import ``dvclive`` and store a ``Live`` instance on the hook.

        Args:
            kwargs: Forwarded unchanged to the ``Live`` constructor.

        Raises:
            ImportError: If ``dvclive`` cannot be imported. The original
                import failure is attached as the cause.
        """
        try:
            from dvclive import Live
        except ImportError as err:
            # Chain explicitly so the underlying reason (e.g. a failing
            # import inside dvclive itself) stays visible in the traceback.
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive'
            ) from err
        self.dvclive = Live(**kwargs)

    @master_only
    def log(self, runner):
        """Send every loggable tag of ``runner`` to dvclive.

        Nothing is written, and the dvclive step is left untouched, when
        ``get_loggable_tags`` returns an empty result.
        """
        tags = self.get_loggable_tags(runner)
        if tags:
            self.dvclive.set_step(self.get_iter(runner))
            for k, v in tags.items():
                self.dvclive.log(k, v)

    @master_only
    def after_train_epoch(self, runner):
        """Run the parent hook, then save a checkpoint if configured.

        When ``model_file`` is set, its parent directory is passed as the
        checkpoint output directory and its file name as the filename
        template, with symlink creation disabled.
        """
        super().after_train_epoch(runner)
        if self.model_file is not None:
            # Parse the configured path once instead of once per argument.
            model_path = Path(self.model_file)
            runner.save_checkpoint(
                model_path.parent,
                filename_tmpl=model_path.name,
                create_symlink=False,
            )
コード例 #2
0
def test_get_step_custom_steps(tmp_dir):
    """Check ``get_step`` echoes the step last given to ``set_step``."""
    live = Live()

    # (step, metric) pairs with deliberately non-contiguous step values.
    samples = [(0, 0.9), (62, 0.8), (1000, 0.7)]

    for step, value in samples:
        live.set_step(step)
        live.log("x", value)
        assert live.get_step() == step
コード例 #3
0
def test_custom_steps(tmp_dir, mocker):
    """Check history and latest value of a metric logged at custom steps.

    The metric ``m`` is logged under the ``logs`` directory at explicitly
    set steps; the history read back must equal the full step and metric
    lists, and the latest entry must equal their last elements.
    """
    live = Live("logs")

    steps = [0, 62, 1000]
    metrics = [0.9, 0.8, 0.7]

    for position, step in enumerate(steps):
        live.set_step(step)
        live.log("m", metrics[position])

    expected_history = (steps, metrics)
    expected_latest = (last(steps), last(metrics))
    assert read_history("logs", "m") == expected_history
    assert read_latest("logs", "m") == expected_latest
コード例 #4
0
def test_log_reset_with_set_step(tmp_dir):
    """Check two metrics can each be logged over the same step range.

    ``train_m`` is logged for steps 0..2, then the step is set back to 0 and
    ``val_m`` is logged for steps 0..2. Both metrics must show the complete
    history and a latest entry at step 2.
    """
    live = Live()
    metric_names = ("train_m", "val_m")

    # Log each metric over steps 0, 1, 2, one metric after the other.
    for name in metric_names:
        for step in range(3):
            live.set_step(step)
            live.log(name, 1)

    for name in metric_names:
        assert read_history("dvclive", name) == ([0, 1, 2], [1, 1, 1])
    for name in metric_names:
        assert read_latest("dvclive", name) == (2, 1)
コード例 #5
0
def test_custom_steps(tmp_dir):
    """Check history and latest value of a metric logged at custom steps.

    Results are read from the ``Scalar.subfolder`` directory inside the
    ``Live`` output directory, resolved relative to ``tmp_dir``.
    """
    live = Live("logs")

    # Resolve the scalar output directory before anything is logged.
    scalar_dir = tmp_dir / live.dir / Scalar.subfolder

    steps = [0, 62, 1000]
    metrics = [0.9, 0.8, 0.7]

    for position, step in enumerate(steps):
        live.set_step(step)
        live.log("m", metrics[position])

    expected_history = (steps, metrics)
    expected_latest = (last(steps), last(metrics))
    assert read_history(scalar_dir, "m") == expected_history
    assert read_latest(scalar_dir, "m") == expected_latest
コード例 #6
0
def test_log_reset_with_set_step(tmp_dir):
    """Check two metrics can each be logged over the same step range.

    ``train_m`` is logged for steps 0..2, then the step is set back to 0 and
    ``val_m`` is logged for steps 0..2. Both are read from the
    ``Scalar.subfolder`` directory and must show the complete history and a
    latest entry at step 2.
    """
    live = Live()
    # Resolve the scalar output directory before anything is logged.
    scalar_dir = tmp_dir / live.dir / Scalar.subfolder
    metric_names = ("train_m", "val_m")

    # Log each metric over steps 0, 1, 2, one metric after the other.
    for name in metric_names:
        for step in range(3):
            live.set_step(step)
            live.log(name, 1)

    for name in metric_names:
        assert read_history(scalar_dir, name) == ([0, 1, 2], [1, 1, 1])
    for name in metric_names:
        assert read_latest(scalar_dir, name) == (2, 1)
コード例 #7
0
def test_get_renderers(tmp_dir, mocker):
    """Check the image, scalar and plot renderers built from logged data.

    Two steps are logged, each with a scalar ``foo`` and an image
    ``image.png``; afterwards the step is set to ``None`` and a confusion
    matrix plot is logged. The test then asserts the number of renderers of
    each kind and the exact datapoints they expose.
    """
    live = Live()

    for step in range(2):
        live.log("foo", step)
        shade = (step, step, step)
        live.log_image("image.png", Image.new("RGB", (10, 10), shade))
        live.next_step()

    live.set_step(None)
    live.log_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1])

    live_root = tmp_dir / live.dir

    # Images: one renderer per step, compared in order of their "rev" path.
    image_renderers = get_image_renderers(live_root / LiveImage.subfolder)
    assert len(image_renderers) == 2

    def by_rev(renderer):
        return renderer.datapoints[0]["rev"]

    for n, renderer in enumerate(sorted(image_renderers, key=by_rev)):
        expected_image = [{
            "src": mocker.ANY,
            "rev": os.path.join(str(n), "image.png")
        }]
        assert renderer.datapoints == expected_image

    # Scalars: a single renderer holding one datapoint per logged step.
    scalar_renderers = get_scalar_renderers(live_root / Scalar.subfolder)
    assert len(scalar_renderers) == 1
    expected_scalars = [{
        "foo": index,
        "rev": "workspace",
        "step": index,
        "timestamp": mocker.ANY
    } for index in ("0", "1")]
    assert scalar_renderers[0].datapoints == expected_scalars

    # Plot: a single renderer with one datapoint per (actual, predicted).
    plot_renderers = get_plot_renderers(live_root / Plot.subfolder)
    assert len(plot_renderers) == 1
    label_pairs = (("0", "1"), ("0", "0"), ("1", "0"), ("1", "1"))
    expected_plot = [{
        "actual": actual,
        "rev": "workspace",
        "predicted": predicted
    } for actual, predicted in label_pairs]
    assert plot_renderers[0].datapoints == expected_plot
    assert plot_renderers[0].properties == ConfusionMatrix.get_properties()