Beispiel #1
0
class DvcLiveCallback(Callback):
    """Callback that forwards per-epoch metrics to a dvclive ``Live`` logger.

    After every epoch the model (or only its weights) is written to
    ``model_file`` when one is configured. At the start of training the
    saved file is loaded back if the logger reports that it is resuming.

    Args:
        model_file: Path used to save and restore the model. Nothing is
            saved when it is falsy.
        save_weights_only: Use ``save_weights``/``load_weights`` instead
            of ``save``/``load_model``.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self,
                 model_file=None,
                 save_weights_only: bool = False,
                 **kwargs):
        super().__init__()
        self.model_file = model_file
        self.save_weights_only = save_weights_only
        self.dvclive = Live(**kwargs)

    def on_train_begin(self, logs=None):
        """Reload the saved model when resuming and the file exists."""
        if not self.dvclive._resume:
            return
        if self.model_file is None or not os.path.exists(self.model_file):
            return
        if self.save_weights_only:
            self.model.load_weights(self.model_file)
        else:
            # NOTE(review): this rebinds the callback's own ``model``
            # attribute; confirm the training loop picks up the new object.
            self.model = load_model(self.model_file)

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Log every entry of ``logs``, save the model, advance the step."""
        for name, val in (logs or {}).items():
            self.dvclive.log(name, val)
        if self.model_file:
            if self.save_weights_only:
                persist = self.model.save_weights
            else:
                persist = self.model.save
            persist(self.model_file)
        self.dvclive.next_step()
Beispiel #2
0
 def import_dvclive(self, **kwargs):
     """Create the ``Live`` logger and store it on ``self.dvclive``.

     ``dvclive`` is imported inside the method, so it is only needed
     when this method is actually called.

     Args:
         **kwargs: Forwarded unchanged to ``Live``.

     Raises:
         ImportError: If ``dvclive`` cannot be imported. The original
             import failure is attached as ``__cause__``.
     """
     try:
         from dvclive import Live
     except ImportError as err:
         # Chain the original error so the underlying import failure
         # stays visible next to the installation hint.
         raise ImportError(
             'Please run "pip install dvclive" to install dvclive') from err
     self.dvclive = Live(**kwargs)
Beispiel #3
0
 def __init__(self, model_file=None, save_weights_only: bool = False,
              **kwargs):
     """Store the checkpoint settings and create the ``Live`` logger.

     Args:
         model_file: Stored as-is on ``self.model_file``.
         save_weights_only: Stored as-is on ``self.save_weights_only``.
         **kwargs: Forwarded unchanged to ``Live``.
     """
     super().__init__()
     self.model_file, self.save_weights_only = model_file, save_weights_only
     self.dvclive = Live(**kwargs)
Beispiel #4
0
def test_invalid_metric_type(tmp_dir, invalid_type):
    """Logging a value of an unsupported type raises InvalidDataTypeError."""
    live = Live()
    expected_message = f"Data 'm' has not supported type {type(invalid_type)}"

    with pytest.raises(InvalidDataTypeError, match=expected_message):
        live.log("m", invalid_type)
Beispiel #5
0
class DvcliveLoggerHook(LoggerHook):
    """Class to log metrics with dvclive.

    It requires `dvclive`_ to be installed.

    Args:
        model_file (str): Default None. If not None, after each epoch the
            model will be saved to {model_file}.
        interval (int): Logging interval (every k iterations). Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
        kwargs: Arguments for instantiating `Live`_.

    .. _dvclive:
        https://dvc.org/doc/dvclive

    .. _Live:
        https://dvc.org/doc/dvclive/api-reference/live#parameters
    """

    def __init__(self,
                 model_file=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True,
                 **kwargs):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.model_file = model_file
        self.import_dvclive(**kwargs)

    def import_dvclive(self, **kwargs):
        """Create the ``Live`` logger and store it on ``self.dvclive``.

        Args:
            **kwargs: Forwarded unchanged to ``Live``.

        Raises:
            ImportError: If ``dvclive`` cannot be imported. The original
                import failure is attached as ``__cause__``.
        """
        try:
            from dvclive import Live
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # stays visible next to the installation hint.
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive'
            ) from err
        self.dvclive = Live(**kwargs)

    @master_only
    def log(self, runner):
        """Log the runner's loggable tags at the runner's current iteration.

        Nothing is logged, and the step is left untouched, when
        ``get_loggable_tags`` returns an empty result.
        """
        tags = self.get_loggable_tags(runner)
        if tags:
            self.dvclive.set_step(self.get_iter(runner))
            for k, v in tags.items():
                self.dvclive.log(k, v)

    @master_only
    def after_train_epoch(self, runner):
        """Run the parent hook, then save a checkpoint to ``model_file``.

        The checkpoint goes to the parent directory of ``model_file``,
        using its file name as the filename template and without a symlink.
        Nothing is saved when ``model_file`` is None.
        """
        super().after_train_epoch(runner)
        if self.model_file is not None:
            runner.save_checkpoint(
                Path(self.model_file).parent,
                filename_tmpl=Path(self.model_file).name,
                create_symlink=False,
            )
Beispiel #6
0
def test_PIL(tmp_dir):
    """A logged PIL image is written to disk and listed in the summary."""
    live = Live()
    live.log("image.png", Image.new("RGB", (500, 500), (250, 250, 250)))

    image_file = tmp_dir / live.dir / "image.png"
    assert image_file.exists()

    recorded = _parse_json("dvclive.json")
    expected_entry = os.path.join(live.dir, "image.png")
    assert recorded["image.png"] == expected_entry
def test_dump_kwargs(tmp_dir, y_true_y_pred_y_score, mocker):
    """Extra keyword arguments of ``log_plot`` reach ``metrics.roc_curve``."""
    live = Live()
    labels, _, scores = y_true_y_pred_y_score
    roc_spy = mocker.spy(metrics, "roc_curve")
    extra = {"drop_intermediate": True}

    live.log_plot("roc", labels, scores, **extra)

    roc_spy.assert_called_once_with(labels, scores, **extra)
Beispiel #8
0
def test_cleanup(tmp_dir):
    """Creating a second ``Live`` is expected to remove the image folder."""
    live = Live()
    live.log_image("image.png", np.ones((500, 500, 3), np.uint8))

    logged_image = tmp_dir / live.dir / LiveImage.subfolder / "image.png"
    assert logged_image.exists()

    # A fresh instance is created only for its side effect on the directory.
    Live()

    images_dir = tmp_dir / live.dir / LiveImage.subfolder
    assert not images_dir.exists()
def test_logging_no_step(tmp_dir):
    """Without a step update only the summary is written, with no ``step``."""
    live = Live("logs")
    live.log("m1", 1)

    history_file = tmp_dir / "logs" / "m1.tsv"
    assert not history_file.is_file()
    summary_file = tmp_dir / live.summary_path
    assert summary_file.is_file()

    summary = _parse_json(live.summary_path)
    assert summary["m1"] == 1
    assert "step" not in summary
Beispiel #10
0
def test_log_prc_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    """``log_plot("precision_recall", ...)`` calls sklearn and writes JSON."""
    live = Live()
    plots_dir = tmp_dir / live.dir / Plot.subfolder
    labels, _, scores = y_true_y_pred_y_score
    curve_spy = mocker.spy(metrics, "precision_recall_curve")

    live.log_plot("precision_recall", labels, scores)

    curve_spy.assert_called_once_with(labels, scores)
    plot_file = plots_dir / "precision_recall.json"
    assert plot_file.exists()
Beispiel #11
0
def test_cleanup(tmp_dir, y_true_y_pred_y_score):
    """Creating a second ``Live`` is expected to remove the plots folder."""
    live = Live()
    plots_dir = tmp_dir / live.dir / Plot.subfolder
    actual, predicted, _ = y_true_y_pred_y_score

    live.log_plot("confusion_matrix", actual, predicted)
    matrix_file = plots_dir / "confusion_matrix.json"
    assert matrix_file.exists()

    # A fresh instance is created only for its side effect on the directory.
    Live()

    plots_dir_after = tmp_dir / live.dir / Plot.subfolder
    assert not plots_dir_after.exists()
Beispiel #12
0
def test_log_calibration_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    """``log_plot("calibration", ...)`` calls sklearn and writes JSON."""
    live = Live()
    plots_dir = tmp_dir / live.dir / Plot.subfolder
    labels, _, scores = y_true_y_pred_y_score
    curve_spy = mocker.spy(calibration, "calibration_curve")

    live.log_plot("calibration", labels, scores)

    curve_spy.assert_called_once_with(labels, scores)
    plot_file = plots_dir / "calibration.json"
    assert plot_file.exists()
Beispiel #13
0
def test_log_reset_with_set_step(tmp_dir):
    """Steps can be replayed with ``set_step`` for a second metric."""
    live = Live()
    metric_names = ("train_m", "val_m")

    # Log each metric over the same three steps, one metric after the other.
    for name in metric_names:
        for step in range(3):
            live.set_step(step)
            live.log(name, 1)

    for name in metric_names:
        assert read_history("dvclive", name) == ([0, 1, 2], [1, 1, 1])
    for name in metric_names:
        assert read_latest("dvclive", name) == (2, 1)
Beispiel #14
0
def test_log_confusion_matrix(tmp_dir, y_true_y_pred_y_score, mocker):
    """The confusion matrix file is a list of actual/predicted records."""
    live = Live()
    plots_dir = tmp_dir / live.dir / Plot.subfolder
    actual, predicted, _ = y_true_y_pred_y_score

    live.log_plot("confusion_matrix", actual, predicted)

    matrix_file = plots_dir / "confusion_matrix.json"
    rows = json.loads(matrix_file.read_text())

    assert isinstance(rows, list)
    first_row = rows[0]
    assert isinstance(first_row, dict)
    # Values are compared against the stringified first label/prediction.
    assert first_row["actual"] == str(actual[0])
    assert first_row["predicted"] == str(predicted[0])
Beispiel #15
0
class DvcLiveCallback(Callback):
    """Callback that logs the learner's recorded metrics with dvclive.

    Args:
        model_file: Passed to ``self.learn.save`` after each epoch when
            truthy.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def after_epoch(self):
        """Log each recorded metric as a float, save, and advance the step."""
        names = self.learn.recorder.metric_names
        values = self.learn.recorder.log
        for name, raw_value in zip(names, values):
            # Underscores in metric names are replaced by "/" before logging.
            self.dvclive.log(name.replace("_", "/"), float(raw_value))

        if self.model_file:
            self.learn.save(self.model_file)
        self.dvclive.next_step()
Beispiel #16
0
class DvcLiveCallback(TrainingCallback):
    """Training callback that logs evaluation metrics with dvclive.

    Args:
        metric_data: Key into ``evals_log`` selecting which dataset's
            metrics are logged.
        model_file: Passed to ``model.save_model`` after each iteration
            when truthy.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, metric_data, model_file=None, **kwargs):
        super().__init__()
        self._metric_data = metric_data
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def after_iteration(self, model, epoch, evals_log):
        """Log the latest value of each metric, save, and advance the step.

        Metrics whose history is empty are skipped for this iteration.
        """
        for key, values in evals_log[self._metric_data].items():
            # An empty history has no latest value. Logging anyway would
            # either fail on an unbound name or reuse the value of the
            # preceding metric under this key.
            if not values:
                continue
            self.dvclive.log(key, values[-1])
        if self.model_file:
            model.save_model(self.model_file)
        self.dvclive.next_step()
Beispiel #17
0
def test_log_reset_with_set_step(tmp_dir):
    """Steps can be replayed with ``set_step`` for a second metric."""
    live = Live()
    scalars_dir = tmp_dir / live.dir / Scalar.subfolder
    metric_names = ("train_m", "val_m")

    # Log each metric over the same three steps, one metric after the other.
    for name in metric_names:
        for step in range(3):
            live.set_step(step)
            live.log(name, 1)

    for name in metric_names:
        assert read_history(scalars_dir, name) == ([0, 1, 2], [1, 1, 1])
    for name in metric_names:
        assert read_latest(scalars_dir, name) == (2, 1)
Beispiel #18
0
class DvcLiveCallback:
    """Callable callback that logs evaluation results with dvclive.

    Args:
        model_file: Passed to ``env.model.save_model`` on every call when
            truthy.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def __call__(self, env):
        """Log each evaluation result, save the model, advance the step."""
        for entry in env.evaluation_result_list:
            # Items 1 and 2 of each entry are used as metric name and value.
            self.dvclive.log(entry[1], entry[2])

        if self.model_file:
            env.model.save_model(self.model_file)

        self.dvclive.next_step()
Beispiel #19
0
def test_init_from_env(tmp_dir, html, monkeypatch):
    """``Live`` takes its path and report setting from the environment."""
    monkeypatch.setenv(env.DVCLIVE_PATH, "logs")
    html_flag = str(int(html))
    monkeypatch.setenv(env.DVCLIVE_HTML, html_flag)

    live = Live()

    assert live._path == "logs"
    if html:
        expected_report = "html"
    else:
        expected_report = None
    assert live._report == expected_report
Beispiel #20
0
def test_get_step_control_flow(tmp_dir):
    """``get_step`` can drive a loop that logs one value per step."""
    live = Live()

    while live.get_step() < 10:
        live.log("i", live.get_step())
        live.next_step()

    logged_steps, logged_values = read_history("dvclive", "i")
    assert logged_steps == [*range(10)]
    assert logged_values == list(map(float, range(10)))
Beispiel #21
0
def test_init_from_env(tmp_dir, summary, html, monkeypatch):
    """``Live`` takes path, summary and html settings from the environment."""
    monkeypatch.setenv(env.DVCLIVE_PATH, "logs")
    summary_flag = str(int(summary))
    monkeypatch.setenv(env.DVCLIVE_SUMMARY, summary_flag)
    html_flag = str(int(html))
    monkeypatch.setenv(env.DVCLIVE_HTML, html_flag)

    live = Live()

    assert live._path == "logs"
    assert live._summary == summary
    assert live._html == html
Beispiel #22
0
def test_get_step_control_flow(tmp_dir):
    """``get_step`` can drive a loop that logs one value per step."""
    live = Live()
    scalars_dir = tmp_dir / live.dir / Scalar.subfolder

    while live.get_step() < 10:
        live.log("i", live.get_step())
        live.next_step()

    logged_steps, logged_values = read_history(scalars_dir, "i")
    assert logged_steps == [*range(10)]
    assert logged_values == list(map(float, range(10)))
Beispiel #23
0
class DvcLiveCallback(TrainerCallback):
    """Trainer callback that logs the trainer's ``logs`` with dvclive.

    Args:
        model_file: Target passed to ``save_pretrained`` for both the
            model and the tokenizer at the end of each epoch when truthy.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def on_log(self, args: TrainingArguments, state: TrainerState,
               control: TrainerControl, **kwargs):
        """Log every entry of ``kwargs["logs"]`` and advance the step."""
        for name, val in kwargs["logs"].items():
            self.dvclive.log(name, val)
        self.dvclive.next_step()

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState,
                     control: TrainerControl, **kwargs):
        """Save the model, then the tokenizer, to ``model_file`` if set."""
        if not self.model_file:
            return
        kwargs["model"].save_pretrained(self.model_file)
        kwargs["tokenizer"].save_pretrained(self.model_file)
Beispiel #24
0
class DvcLiveCallback(Callback):
    """Callback that logs the runner's per-loader epoch metrics with dvclive.

    Args:
        model_file: Path passed to ``runner.engine.save_checkpoint`` after
            each epoch when truthy.
        **kwargs: Forwarded unchanged to ``Live``.
    """

    def __init__(self, model_file=None, **kwargs):
        super().__init__(order=CallbackOrder.external)
        self.dvclive = Live(**kwargs)
        self.model_file = model_file

    def on_epoch_end(self, runner) -> None:
        """Log all epoch metrics, optionally checkpoint, advance the step."""
        for loader_key, loader_metrics in runner.epoch_metrics.items():
            for raw_key, raw_value in loader_metrics.items():
                # "/" inside a metric key is replaced by "_" so that the
                # only slash separates the loader name from the metric.
                flat_key = raw_key.replace("/", "_")
                self.dvclive.log(f"{loader_key}/{flat_key}", float(raw_value))

        if self.model_file:
            packed = runner.engine.pack_checkpoint(
                model=runner.model,
                criterion=runner.criterion,
                optimizer=runner.optimizer,
                scheduler=runner.scheduler,
            )
            runner.engine.save_checkpoint(packed, self.model_file)
        self.dvclive.next_step()
Beispiel #25
0
def test_get_step_custom_steps(tmp_dir):
    """``get_step`` returns the value most recently given to ``set_step``."""
    live = Live()

    # Non-consecutive steps mapped to the value logged at each of them.
    value_by_step = {0: 0.9, 62: 0.8, 1000: 0.7}

    for step, value in value_by_step.items():
        live.set_step(step)
        live.log("x", value)
        assert live.get_step() == step
Beispiel #26
0
    def experiment(self):
        r"""
        Return the underlying DVCLive object, creating it on first access.

        The instance is built from ``self._dvclive_init`` and cached in
        ``self._experiment``; later calls return the cached object.
        Creation asserts that ``rank_zero_only.rank`` is 0.

        Example::
            self.logger.experiment.some_dvclive_function()
        """
        if self._experiment is None:
            assert (rank_zero_only.rank == 0
                    ), "tried to init log dirs in non global_rank=0"
            self._experiment = Live(**self._dvclive_init)

        return self._experiment
Beispiel #27
0
def test_step_formatting(tmp_dir):
    """Images logged over several steps land in one folder per step."""
    live = Live()
    pixels = np.ones((500, 500, 3), np.uint8)
    for _ in range(3):
        live.log_image("image.png", pixels)
        live.next_step()

    for step in range(3):
        step_image = (tmp_dir / live.dir / LiveImage.subfolder / str(step)
                      / "image.png")
        assert step_image.exists()
Beispiel #28
0
def test_require_step_update(tmp_dir, metric):
    """Logging the same metric twice without a step update is rejected."""
    live = Live("logs")
    live.log(metric, 1.0)

    expected_error = pytest.raises(
        DataAlreadyLoggedError,
        match="has already being logged whith step 'None'",
    )
    with expected_error:
        live.log(metric, 2.0)
Beispiel #29
0
def test_step_exception(tmp_dir, y_true_y_pred_y_score):
    """``next_step`` raises NotImplementedError after a plot was logged."""
    live = Live()
    plots_dir = tmp_dir / live.dir / Plot.subfolder
    actual, predicted, _ = y_true_y_pred_y_score

    live.log_plot("confusion_matrix", actual, predicted)
    matrix_file = plots_dir / "confusion_matrix.json"
    assert matrix_file.exists()

    with pytest.raises(NotImplementedError):
        live.next_step()
Beispiel #30
0
def test_logging_step(tmp_dir, path):
    """After ``next_step`` the history file and a summary with step 0 exist."""
    live = Live(path)
    live.log("m1", 1)
    live.next_step()

    log_dir = tmp_dir / path
    assert log_dir.is_dir()
    assert (log_dir / "m1.tsv").is_file()
    assert (tmp_dir / live.summary_path).is_file()

    summary = _parse_json(live.summary_path)
    assert summary["m1"] == 1
    assert summary["step"] == 0