Пример #1
0
class DvcLiveCallback(Callback):
    """Callback that sends per-epoch metrics to a dvclive ``Live`` logger.

    Args:
        model_file: Optional path. When set, the model (or only its weights)
            is written there at the end of every epoch and, if the logger is
            resuming and the file exists, read back when training begins.
        save_weights_only: Selects ``save_weights``/``load_weights`` instead
            of ``save``/``load_model``.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self,
                 model_file=None,
                 save_weights_only: bool = False,
                 **kwargs):
        super().__init__()
        self.model_file = model_file
        self.save_weights_only = save_weights_only
        self.dvclive = Live(**kwargs)

    def on_train_begin(self, logs=None):
        """Reload the saved model or weights when the logger is resuming."""
        if not self.dvclive._resume:
            return
        if self.model_file is None or not os.path.exists(self.model_file):
            return
        if self.save_weights_only:
            self.model.load_weights(self.model_file)
            return
        self.model = load_model(self.model_file)

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Log every entry of ``logs``, optionally save, then advance a step."""
        for name, val in (logs or {}).items():
            self.dvclive.log(name, val)
        if self.model_file:
            # Pick the persistence method once, then call it with the path.
            persist = (self.model.save_weights
                       if self.save_weights_only else self.model.save)
            persist(self.model_file)
        self.dvclive.next_step()
Пример #2
0
def test_invalid_metric_type(tmp_dir, invalid_type):
    """Logging a value of an unsupported type raises InvalidDataTypeError."""
    live = Live()
    expected_message = (
        f"Data 'm' has not supported type {type(invalid_type)}")

    with pytest.raises(InvalidDataTypeError, match=expected_message):
        live.log("m", invalid_type)
Пример #3
0
class DvcliveLoggerHook(LoggerHook):
    """Logger hook that reports metrics through dvclive.

    `dvclive`_ must be installed for this hook to work.

    Args:
        model_file (str): Path used to save the model after every training
            epoch; nothing is saved when it is None. Default: None.
        interval (int): Number of iterations between two logs. Default: 10.
        ignore_last (bool): Whether to skip logging the trailing iterations
            of an epoch when there are fewer than `interval` of them.
            Default: True.
        reset_flag (bool): Whether the output buffer is cleared after
            logging. Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
        kwargs: Keyword arguments forwarded to `Live`_.

    .. _dvclive:
        https://dvc.org/doc/dvclive

    .. _Live:
        https://dvc.org/doc/dvclive/api-reference/live#parameters
    """

    def __init__(self,
                 model_file=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True,
                 **kwargs):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.model_file = model_file
        self.import_dvclive(**kwargs)

    def import_dvclive(self, **kwargs):
        """Import dvclive on demand and create the ``Live`` instance."""
        try:
            from dvclive import Live
        except ImportError:
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive')
        else:
            self.dvclive = Live(**kwargs)

    @master_only
    def log(self, runner):
        """Log every loggable tag of ``runner`` at the current iteration."""
        loggable = self.get_loggable_tags(runner)
        if not loggable:
            return
        self.dvclive.set_step(self.get_iter(runner))
        for tag, value in loggable.items():
            self.dvclive.log(tag, value)

    @master_only
    def after_train_epoch(self, runner):
        """Run the parent hook, then checkpoint to ``model_file`` if set."""
        super().after_train_epoch(runner)
        if self.model_file is None:
            return
        # Split the configured path into target directory and file name.
        target = Path(self.model_file)
        runner.save_checkpoint(
            target.parent,
            filename_tmpl=target.name,
            create_symlink=False,
        )
Пример #4
0
def test_PIL(tmp_dir):
    """A logged PIL image is written to disk and referenced in the summary."""
    live = Live()
    picture = Image.new("RGB", (500, 500), (250, 250, 250))
    live.log("image.png", picture)

    image_file = tmp_dir / live.dir / "image.png"
    assert image_file.exists()

    summary = _parse_json("dvclive.json")
    expected_entry = os.path.join(live.dir, "image.png")
    assert summary["image.png"] == expected_entry
Пример #5
0
def test_get_step_custom_steps(tmp_dir):
    """``get_step`` returns whatever step was last passed to ``set_step``."""
    live = Live()

    for wanted_step, value in ((0, 0.9), (62, 0.8), (1000, 0.7)):
        live.set_step(wanted_step)
        live.log("x", value)
        assert live.get_step() == wanted_step
Пример #6
0
def test_get_step_control_flow(tmp_dir):
    """``get_step`` can drive a loop; each step is logged exactly once."""
    live = Live()

    while live.get_step() < 10:
        live.log("i", live.get_step())
        live.next_step()

    history = read_history("dvclive", "i")
    logged_steps, logged_values = history
    assert logged_steps == list(range(10))
    assert logged_values == list(map(float, range(10)))
Пример #7
0
def test_require_step_update(tmp_dir, metric):
    """Logging the same metric twice without a step update must fail."""
    live = Live("logs")
    live.log(metric, 1.0)

    # The pattern mirrors the library's message verbatim, typos included.
    error_pattern = "has already being logged whith step 'None'"
    with pytest.raises(DataAlreadyLoggedError, match=error_pattern):
        live.log(metric, 2.0)
Пример #8
0
def test_logging_step(tmp_dir, path):
    """After one step, the metric TSV and the summary file both exist."""
    live = Live(path)
    live.log("m1", 1)
    live.next_step()

    log_dir = tmp_dir / path
    assert log_dir.is_dir()
    assert (log_dir / "m1.tsv").is_file()
    assert (tmp_dir / live.summary_path).is_file()

    summary = _parse_json(live.summary_path)
    assert summary["m1"] == 1
    assert summary["step"] == 0
Пример #9
0
def test_logging_no_step(tmp_dir):
    """Without a step update only the summary is written, not the TSV."""
    live = Live("logs")
    live.log("m1", 1)

    history_file = tmp_dir / "logs" / "m1.tsv"
    assert not history_file.is_file()
    assert (tmp_dir / live.summary_path).is_file()

    summary = _parse_json(live.summary_path)
    assert summary["m1"] == 1
    assert "step" not in summary
Пример #10
0
def test_custom_steps(tmp_dir, mocker):
    """Metrics logged at explicit steps are stored under those steps."""
    live = Live("logs")

    steps = [0, 62, 1000]
    metrics = [0.9, 0.8, 0.7]

    for position, step in enumerate(steps):
        live.set_step(step)
        live.log("m", metrics[position])

    history = read_history("logs", "m")
    assert history == (steps, metrics)
    latest = read_latest("logs", "m")
    assert latest == (last(steps), last(metrics))
Пример #11
0
def test_get_step_control_flow(tmp_dir):
    """``get_step`` can drive a loop; history is read from the scalar dir."""
    live = Live()

    scalar_dir = tmp_dir / live.dir / Scalar.subfolder

    while live.get_step() < 10:
        live.log("i", live.get_step())
        live.next_step()

    history = read_history(scalar_dir, "i")
    logged_steps, logged_values = history
    assert logged_steps == list(range(10))
    assert logged_values == list(map(float, range(10)))
Пример #12
0
def test_html(tmp_dir, dvc_repo, html, signal_exists, monkeypatch):
    """Check for the signal file under ``.dvc/tmp`` after one logged step."""
    if dvc_repo:
        from dvc.repo import Repo

        Repo.init(no_scm=True)

    # Configure dvclive purely through environment variables.
    monkeypatch.setenv(env.DVCLIVE_PATH, "logs")
    html_flag = str(int(html))
    monkeypatch.setenv(env.DVCLIVE_HTML, html_flag)

    live = Live()
    live.log("m1", 1)
    live.next_step()

    signal_file = tmp_dir / ".dvc" / "tmp" / SIGNAL_FILE
    assert signal_file.is_file() == signal_exists
Пример #13
0
def test_custom_steps(tmp_dir):
    """Metrics logged at explicit steps land in the scalar subfolder."""
    live = Live("logs")

    scalar_dir = tmp_dir / live.dir / Scalar.subfolder

    steps = [0, 62, 1000]
    metrics = [0.9, 0.8, 0.7]

    for position, step in enumerate(steps):
        live.set_step(step)
        live.log("m", metrics[position])

    history = read_history(scalar_dir, "m")
    assert history == (steps, metrics)
    latest = read_latest(scalar_dir, "m")
    assert latest == (last(steps), last(metrics))
Пример #14
0
def test_get_step_resume(tmp_dir):
    """``resume=True`` keeps the step counter; ``resume=False`` resets it."""
    live = Live()

    for value in (0.9, 0.8):
        live.log("metric", value)
        live.next_step()

    assert live.get_step() == 2

    # A resumed logger picks up where the previous one stopped.
    live = Live(resume=True)
    assert live.get_step() == 2

    # A non-resumed logger starts over from step zero.
    live = Live(resume=False)
    assert live.get_step() == 0
Пример #15
0
def test_log_reset_with_set_step(tmp_dir):
    """Rewinding with ``set_step`` allows logging a second metric per step."""
    live = Live()

    # First all "train_m" steps, then the same steps again for "val_m".
    for name in ("train_m", "val_m"):
        for step in range(3):
            live.set_step(step)
            live.log(name, 1)

    expected_history = ([0, 1, 2], [1, 1, 1])
    expected_latest = (2, 1)
    assert read_history("dvclive", "train_m") == expected_history
    assert read_history("dvclive", "val_m") == expected_history
    assert read_latest("dvclive", "train_m") == expected_latest
    assert read_latest("dvclive", "val_m") == expected_latest
Пример #16
0
def test_step_formatting(tmp_dir):
    """Images are stored per step; the summary points at the last step."""
    live = Live()
    pixels = np.ones((500, 500, 3), np.uint8)
    n_steps = 3

    for _ in range(n_steps):
        live.log("image.png", pixels)
        live.next_step()

    for step in range(n_steps):
        assert (tmp_dir / live.dir / str(step) / "image.png").exists()

    summary = _parse_json("dvclive.json")

    # The summary entry refers to the final step that was logged.
    final_step = n_steps - 1
    expected_entry = os.path.join(live.dir, str(final_step), "image.png")
    assert summary["image.png"] == expected_entry
Пример #17
0
class DvcLiveCallback(Callback):
    """Callback that logs the learner's recorder metrics after each epoch.

    Args:
        model_file: Optional name passed to ``self.learn.save`` after every
            epoch; nothing is saved when it is falsy.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def after_epoch(self):
        """Log each recorder metric, optionally save, then advance a step."""
        recorder = self.learn.recorder
        for name, raw_value in zip(recorder.metric_names, recorder.log):
            # Underscores become "/" in the logged metric name.
            tag = name.replace("_", "/")
            self.dvclive.log(f"{tag}", float(raw_value))

        if self.model_file:
            self.learn.save(self.model_file)
        self.dvclive.next_step()
Пример #18
0
class DvcLiveCallback(TrainingCallback):
    """Training callback that logs the latest evaluation metrics to dvclive.

    Args:
        metric_data: Key of ``evals_log`` whose metrics are logged.
        model_file: Optional path; when set, ``model.save_model`` is called
            with it after every iteration.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, metric_data, model_file=None, **kwargs):
        super().__init__()
        self._metric_data = metric_data
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def after_iteration(self, model, epoch, evals_log):
        """Log the newest value of every metric, save, and advance a step.

        Metrics whose history is still empty are skipped. Previously the
        ``log`` call sat outside the emptiness check, so an empty history
        either raised ``NameError`` (first metric) or re-logged the value
        of the preceding metric under the wrong key.
        """
        for key, values in evals_log[self._metric_data].items():
            if values:
                self.dvclive.log(key, values[-1])
        if self.model_file:
            model.save_model(self.model_file)
        self.dvclive.next_step()
Пример #19
0
def test_log_reset_with_set_step(tmp_dir):
    """Rewinding with ``set_step`` allows logging a second metric per step."""
    live = Live()
    scalar_dir = tmp_dir / live.dir / Scalar.subfolder

    # First all "train_m" steps, then the same steps again for "val_m".
    for name in ("train_m", "val_m"):
        for step in range(3):
            live.set_step(step)
            live.log(name, 1)

    expected_history = ([0, 1, 2], [1, 1, 1])
    expected_latest = (2, 1)
    assert read_history(scalar_dir, "train_m") == expected_history
    assert read_history(scalar_dir, "val_m") == expected_history
    assert read_latest(scalar_dir, "train_m") == expected_latest
    assert read_latest(scalar_dir, "val_m") == expected_latest
Пример #20
0
def test_step_rename(tmp_dir, mocker):
    """``next_step`` moves a step-less image into the step's subfolder."""
    from pathlib import Path

    rename_spy = mocker.spy(Path, "rename")
    live = Live()
    pixels = np.ones((500, 500, 3), np.uint8)
    live.log("image.png", pixels)

    unstepped_image = tmp_dir / live.dir / "image.png"
    assert unstepped_image.exists()

    live.next_step()

    assert not unstepped_image.exists()
    assert (tmp_dir / live.dir / "0" / "image.png").exists()

    # Exactly one rename, from the flat location into the "0" folder.
    source = Path(live.dir) / "image.png"
    destination = Path(live.dir) / "0" / "image.png"
    rename_spy.assert_called_once_with(source, destination)
Пример #21
0
class DvcLiveCallback:
    """Callable callback that logs evaluation results to dvclive.

    Args:
        model_file: Optional path; when set, ``env.model.save_model`` is
            called with it on every invocation.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def __call__(self, env):
        """Log each evaluation result, optionally save, then advance a step.

        Each entry of ``env.evaluation_result_list`` is indexed positionally:
        item 1 is used as the metric name and item 2 as its value.
        """
        for entry in env.evaluation_result_list:
            self.dvclive.log(entry[1], entry[2])

        if self.model_file:
            env.model.save_model(self.model_file)

        self.dvclive.next_step()
Пример #22
0
def test_continue(tmp_dir, resume, steps, metrics):
    """A second ``Live`` on the same dir continues or restarts per ``resume``."""
    live = Live("logs")

    for value in (0.9, 0.8):
        live.log("metric", value)
        live.next_step()

    first_history = read_history("logs", "metric")
    assert first_history == ([0, 1], [0.9, 0.8])
    first_latest = read_latest("logs", "metric")
    assert first_latest == (1, 0.8)

    # Second session against the same directory.
    live = Live("logs", resume=resume)

    for value in (0.7, 0.6):
        live.log("metric", value)
        live.next_step()

    assert read_history("logs", "metric") == (steps, metrics)
    expected_latest = (last(steps), last(metrics))
    assert read_latest("logs", "metric") == expected_latest
Пример #23
0
def test_nested_logging(tmp_dir):
    """Slash-separated metric names become nested dirs and summary keys."""
    live = Live("logs", summary=True)

    for name in ("train/m1", "val/val_1/m1", "val/val_1/m2"):
        live.log(name, 1)

    live.next_step()

    logs_dir = tmp_dir / "logs"
    nested_dir = logs_dir / "val" / "val_1"
    assert nested_dir.is_dir()
    assert (logs_dir / "train" / "m1.tsv").is_file()
    assert (nested_dir / "m1.tsv").is_file()
    assert (nested_dir / "m2.tsv").is_file()

    summary = _parse_json(live.summary_path)

    assert summary["train"]["m1"] == 1
    nested_summary = summary["val"]["val_1"]
    assert nested_summary["m1"] == 1
    assert nested_summary["m2"] == 1
Пример #24
0
class DvcLiveCallback(TrainerCallback):
    """Trainer callback that logs metrics to dvclive and saves the model.

    Args:
        model_file: Optional path passed to ``save_pretrained`` for the
            model (and the tokenizer, when one is provided) at the end of
            every epoch.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def on_log(self, args: TrainingArguments, state: TrainerState,
               control: TrainerControl, **kwargs):
        """Log every entry of ``kwargs["logs"]`` and advance one step."""
        logs = kwargs["logs"]
        for key, value in logs.items():
            self.dvclive.log(key, value)
        self.dvclive.next_step()

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """Save the model, plus the tokenizer if available, to ``model_file``.

        The tokenizer is optional: when the ``tokenizer`` keyword is missing
        or ``None`` it is skipped rather than failing after the model has
        already been written.
        """
        if self.model_file:
            model = kwargs["model"]
            model.save_pretrained(self.model_file)
            tokenizer = kwargs.get("tokenizer")
            if tokenizer is not None:
                tokenizer.save_pretrained(self.model_file)
Пример #25
0
def test_nested_logging(tmp_dir):
    """Slash-separated metric names nest inside the scalar subfolder."""
    live = Live("logs")

    scalar_dir = tmp_dir / live.dir / Scalar.subfolder

    for name in ("train/m1", "val/val_1/m1", "val/val_1/m2"):
        live.log(name, 1)

    live.next_step()

    nested_dir = scalar_dir / "val" / "val_1"
    assert nested_dir.is_dir()
    assert (scalar_dir / "train" / "m1.tsv").is_file()
    assert (nested_dir / "m1.tsv").is_file()
    assert (nested_dir / "m2.tsv").is_file()

    summary = _parse_json(live.summary_path)

    assert summary["train"]["m1"] == 1
    nested_summary = summary["val"]["val_1"]
    assert nested_summary["m1"] == 1
    assert nested_summary["m2"] == 1
Пример #26
0
class DvcLiveCallback(Callback):
    """Callback that logs per-loader epoch metrics to dvclive.

    Args:
        model_file: Optional path; when set, a checkpoint packed by
            ``runner.engine`` is saved there after every epoch.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__(order=CallbackOrder.external)
        self.dvclive = Live(**kwargs)
        self.model_file = model_file

    def on_epoch_end(self, runner) -> None:
        """Log all epoch metrics, optionally checkpoint, then advance a step."""
        for loader_name, loader_metrics in runner.epoch_metrics.items():
            for metric_name, raw_value in loader_metrics.items():
                # "/" is reserved as the loader/metric separator in the tag.
                flat_name = metric_name.replace("/", "_")
                self.dvclive.log(f"{loader_name}/{flat_name}",
                                 float(raw_value))

        if self.model_file:
            engine = runner.engine
            packed = engine.pack_checkpoint(
                model=runner.model,
                criterion=runner.criterion,
                optimizer=runner.optimizer,
                scheduler=runner.scheduler,
            )
            engine.save_checkpoint(packed, self.model_file)
        self.dvclive.next_step()
Пример #27
0
def test_cleanup(tmp_dir, summary, html):
    """A new ``Live`` removes old dvclive outputs but keeps user files."""
    live = Live("logs", summary=summary)
    live.log("m1", 1)
    live.next_step()

    html_path = tmp_dir / live.html_path
    if html:
        html_path.touch()

    logs_dir = tmp_dir / "logs"
    user_file = logs_dir / "some_user_file.txt"
    user_file.touch()
    metric_file = logs_dir / "m1.tsv"

    assert metric_file.is_file()
    assert (tmp_dir / live.summary_path).is_file() == summary
    assert html_path.is_file() == html

    # Re-instantiating on the same directory triggers the cleanup.
    live = Live("logs", summary=summary)

    assert user_file.is_file()
    assert not metric_file.is_file()
    assert (tmp_dir / live.summary_path).is_file() == summary
    assert not html_path.is_file()
Пример #28
0
def test_cleanup(tmp_dir, html):
    """A new ``Live`` removes old dvclive outputs but keeps user files."""
    report_kind = "html" if html else None
    live = Live("logs", report=report_kind)
    live.log("m1", 1)
    live.next_step()

    html_path = tmp_dir / live.html_path
    if html:
        html_path.touch()

    user_file = tmp_dir / "logs" / "some_user_file.txt"
    user_file.touch()

    assert (tmp_dir / live.dir / Scalar.subfolder / "m1.tsv").is_file()
    assert (tmp_dir / live.summary_path).is_file()
    assert html_path.is_file() == html

    # Re-instantiating on the same directory triggers the cleanup.
    live = Live("logs")

    assert user_file.is_file()
    assert not (tmp_dir / live.dir / Scalar.subfolder).exists()
    assert not (tmp_dir / live.summary_path).is_file()
    assert not html_path.is_file()
Пример #29
0
def test_make_report_open(tmp_dir, mocker):
    """The browser opens once by default and never when disabled."""
    # Default settings: two reports, but only one browser open.
    browser_open = mocker.patch("webbrowser.open")
    live = Live()
    live.log_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1])
    for _ in range(2):
        live.make_report()

    browser_open.assert_called_once()

    # auto_open=False: the report is made without opening the browser.
    browser_open = mocker.patch("webbrowser.open")
    live = Live(auto_open=False)
    live.log_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1])
    live.make_report()

    assert browser_open.called is False

    # report=None: stepping does not open the browser either.
    browser_open = mocker.patch("webbrowser.open")
    live = Live(report=None)
    live.log("foo", 1)
    live.next_step()

    assert browser_open.called is False
Пример #30
0
def test_get_renderers(tmp_dir, mocker):
    """Image, scalar and plot renderers expose the logged datapoints."""
    live = Live()

    for i in range(2):
        live.log("foo", i)
        picture = Image.new("RGB", (10, 10), (i, i, i))
        live.log_image("image.png", picture)
        live.next_step()

    live.set_step(None)
    live.log_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1])

    live_dir = tmp_dir / live.dir

    # Images: one renderer per step, identified by "<step>/image.png".
    image_renderers = get_image_renderers(live_dir / LiveImage.subfolder)
    assert len(image_renderers) == 2
    image_renderers = sorted(image_renderers,
                             key=lambda r: r.datapoints[0]["rev"])
    for n, renderer in enumerate(image_renderers):
        expected_rev = os.path.join(str(n), "image.png")
        assert renderer.datapoints == [{
            "src": mocker.ANY,
            "rev": expected_rev,
        }]

    # Scalars: a single renderer holding one row per logged step.
    scalar_renderers = get_scalar_renderers(live_dir / Scalar.subfolder)
    assert len(scalar_renderers) == 1
    assert scalar_renderers[0].datapoints == [
        {
            "foo": str(i),
            "rev": "workspace",
            "step": str(i),
            "timestamp": mocker.ANY,
        }
        for i in range(2)
    ]

    # Plots: one row per (actual, predicted) pair, in logging order.
    plot_renderers = get_plot_renderers(live_dir / Plot.subfolder)
    assert len(plot_renderers) == 1
    actual_labels = ["0", "0", "1", "1"]
    predicted_labels = ["1", "0", "0", "1"]
    assert plot_renderers[0].datapoints == [
        {
            "actual": actual,
            "rev": "workspace",
            "predicted": predicted,
        }
        for actual, predicted in zip(actual_labels, predicted_labels)
    ]
    assert plot_renderers[0].properties == ConfusionMatrix.get_properties()