class DvcLiveCallback(Callback):
    def __init__(self, model_file=None, save_weights_only: bool = False, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.save_weights_only = save_weights_only
        self.dvclive = Live(**kwargs)

    def on_train_begin(self, logs=None):
        if (
            self.dvclive._resume
            and self.model_file is not None
            and os.path.exists(self.model_file)
        ):
            if self.save_weights_only:
                self.model.load_weights(self.model_file)
            else:
                self.model = load_model(self.model_file)

    def on_epoch_end(self, epoch: int, logs: dict = None):
        logs = logs or {}
        for metric, value in logs.items():
            self.dvclive.log(metric, value)
        if self.model_file:
            if self.save_weights_only:
                self.model.save_weights(self.model_file)
            else:
                self.model.save(self.model_file)
        self.dvclive.next_step()
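# Usage sketch (not from the source): attaching the Keras callback above to a
# model.fit() call. The toy model, random data, and the "model.h5" path are
# made up for illustration; each epoch's metrics are logged and the model is
# saved to model_file.
import numpy as np
from tensorflow import keras

model = keras.Sequential([keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(32, 8)
y = np.random.rand(32, 1)

model.fit(
    x,
    y,
    epochs=3,
    callbacks=[DvcLiveCallback(model_file="model.h5")],
)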
def test_invalid_metric_type(tmp_dir, invalid_type):
    dvclive = Live()

    with pytest.raises(
        InvalidDataTypeError,
        match=f"Data 'm' has not supported type {type(invalid_type)}",
    ):
        dvclive.log("m", invalid_type)
class DvcliveLoggerHook(LoggerHook):
    """Class to log metrics with dvclive.

    It requires `dvclive`_ to be installed.

    Args:
        model_file (str): Default None. If not None, after each epoch the
            model will be saved to {model_file}.
        interval (int): Logging interval (every k iterations). Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
        kwargs: Arguments for instantiating `Live`_.

    .. _dvclive:
        https://dvc.org/doc/dvclive

    .. _Live:
        https://dvc.org/doc/dvclive/api-reference/live#parameters
    """

    def __init__(self,
                 model_file=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True,
                 **kwargs):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.model_file = model_file
        self.import_dvclive(**kwargs)

    def import_dvclive(self, **kwargs):
        try:
            from dvclive import Live
        except ImportError:
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive')
        self.dvclive = Live(**kwargs)

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner)
        if tags:
            self.dvclive.set_step(self.get_iter(runner))
            for k, v in tags.items():
                self.dvclive.log(k, v)

    @master_only
    def after_train_epoch(self, runner):
        super().after_train_epoch(runner)
        if self.model_file is not None:
            runner.save_checkpoint(
                Path(self.model_file).parent,
                filename_tmpl=Path(self.model_file).name,
                create_symlink=False,
            )
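# Usage sketch (not from the source): mmcv picks up logger hooks from the
# `log_config` section of a runner config, so the hook above would typically be
# enabled with a dict entry like the one below (assuming the hook is registered
# under the name "DvcliveLoggerHook"); the checkpoint path is illustrative.
log_config = dict(
    interval=10,
    hooks=[
        dict(type="DvcliveLoggerHook", model_file="checkpoints/latest.pth"),
    ],
)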
def test_PIL(tmp_dir):
    dvclive = Live()
    img = Image.new("RGB", (500, 500), (250, 250, 250))
    dvclive.log("image.png", img)

    assert (tmp_dir / dvclive.dir / "image.png").exists()

    summary = _parse_json("dvclive.json")
    assert summary["image.png"] == os.path.join(dvclive.dir, "image.png")
def test_dump_kwargs(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    y_true, _, y_score = y_true_y_pred_y_score
    spy = mocker.spy(metrics, "roc_curve")

    live.log_plot("roc", y_true, y_score, drop_intermediate=True)

    spy.assert_called_once_with(y_true, y_score, drop_intermediate=True)
def test_cleanup(tmp_dir):
    dvclive = Live()
    img = np.ones((500, 500, 3), np.uint8)
    dvclive.log_image("image.png", img)

    assert (tmp_dir / dvclive.dir / LiveImage.subfolder / "image.png").exists()

    Live()

    assert not (tmp_dir / dvclive.dir / LiveImage.subfolder).exists()
def test_logging_no_step(tmp_dir):
    dvclive = Live("logs")

    dvclive.log("m1", 1)

    assert not (tmp_dir / "logs" / "m1.tsv").is_file()
    assert (tmp_dir / dvclive.summary_path).is_file()

    s = _parse_json(dvclive.summary_path)
    assert s["m1"] == 1
    assert "step" not in s
def test_log_prc_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    out = tmp_dir / live.dir / Plot.subfolder

    y_true, _, y_score = y_true_y_pred_y_score
    spy = mocker.spy(metrics, "precision_recall_curve")

    live.log_plot("precision_recall", y_true, y_score)

    spy.assert_called_once_with(y_true, y_score)
    assert (out / "precision_recall.json").exists()
def test_cleanup(tmp_dir, y_true_y_pred_y_score):
    live = Live()
    out = tmp_dir / live.dir / Plot.subfolder

    y_true, y_pred, _ = y_true_y_pred_y_score
    live.log_plot("confusion_matrix", y_true, y_pred)

    assert (out / "confusion_matrix.json").exists()

    Live()

    assert not (tmp_dir / live.dir / Plot.subfolder).exists()
def test_log_calibration_curve(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    out = tmp_dir / live.dir / Plot.subfolder

    y_true, _, y_score = y_true_y_pred_y_score
    spy = mocker.spy(calibration, "calibration_curve")

    live.log_plot("calibration", y_true, y_score)

    spy.assert_called_once_with(y_true, y_score)
    assert (out / "calibration.json").exists()
def test_log_reset_with_set_step(tmp_dir):
    dvclive = Live()

    for i in range(3):
        dvclive.set_step(i)
        dvclive.log("train_m", 1)

    for i in range(3):
        dvclive.set_step(i)
        dvclive.log("val_m", 1)

    assert read_history("dvclive", "train_m") == ([0, 1, 2], [1, 1, 1])
    assert read_history("dvclive", "val_m") == ([0, 1, 2], [1, 1, 1])
    assert read_latest("dvclive", "train_m") == (2, 1)
    assert read_latest("dvclive", "val_m") == (2, 1)
def test_log_confusion_matrix(tmp_dir, y_true_y_pred_y_score, mocker):
    live = Live()
    out = tmp_dir / live.dir / Plot.subfolder

    y_true, y_pred, _ = y_true_y_pred_y_score
    live.log_plot("confusion_matrix", y_true, y_pred)

    cm = json.loads((out / "confusion_matrix.json").read_text())

    assert isinstance(cm, list)
    assert isinstance(cm[0], dict)
    assert cm[0]["actual"] == str(y_true[0])
    assert cm[0]["predicted"] == str(y_pred[0])
class DvcLiveCallback(Callback):
    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def after_epoch(self):
        for key, value in zip(
            self.learn.recorder.metric_names, self.learn.recorder.log
        ):
            key = key.replace("_", "/")
            self.dvclive.log(f"{key}", float(value))

        if self.model_file:
            self.learn.save(self.model_file)
        self.dvclive.next_step()
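# Usage sketch (not from the source): passing the fastai callback above to a
# Learner via `cbs`. The MNIST_TINY sample dataset, cnn_learner setup, and the
# "model" save name are illustrative choices (untar_data downloads a small
# sample dataset on first run).
from fastai.vision.all import *

path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.fit_one_cycle(1, cbs=[DvcLiveCallback(model_file="model")])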
class DvcLiveCallback(TrainingCallback):
    def __init__(self, metric_data, model_file=None, **kwargs):
        super().__init__()
        self._metric_data = metric_data
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def after_iteration(self, model, epoch, evals_log):
        for key, values in evals_log[self._metric_data].items():
            if values:
                latest_metric = values[-1]
            self.dvclive.log(key, latest_metric)
        if self.model_file:
            model.save_model(self.model_file)
        self.dvclive.next_step()
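# Usage sketch (not from the source): the `metric_data` argument selects which
# entry of `evals_log` to read, so it should match one of the names given in
# `evals`. The random data, params, and "model.json" path are illustrative.
import numpy as np
import xgboost as xgb

x = np.random.rand(100, 4)
y = np.random.randint(2, size=100)
dtrain = xgb.DMatrix(x, label=y)

xgb.train(
    {"objective": "binary:logistic"},
    dtrain,
    num_boost_round=5,
    evals=[(dtrain, "train")],
    callbacks=[DvcLiveCallback("train", model_file="model.json")],
)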
def test_log_reset_with_set_step(tmp_dir):
    dvclive = Live()
    out = tmp_dir / dvclive.dir / Scalar.subfolder

    for i in range(3):
        dvclive.set_step(i)
        dvclive.log("train_m", 1)

    for i in range(3):
        dvclive.set_step(i)
        dvclive.log("val_m", 1)

    assert read_history(out, "train_m") == ([0, 1, 2], [1, 1, 1])
    assert read_history(out, "val_m") == ([0, 1, 2], [1, 1, 1])
    assert read_latest(out, "train_m") == (2, 1)
    assert read_latest(out, "val_m") == (2, 1)
class DvcLiveCallback:
    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def __call__(self, env):
        for eval_result in env.evaluation_result_list:
            metric = eval_result[1]
            value = eval_result[2]
            self.dvclive.log(metric, value)

        if self.model_file:
            env.model.save_model(self.model_file)
        self.dvclive.next_step()
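# Usage sketch (not from the source): LightGBM invokes the callable above after
# every boosting round with an env namedtuple that carries
# `evaluation_result_list` (populated when valid_sets is given) and `model`.
# The random data, params, and "model.txt" path are illustrative.
import numpy as np
import lightgbm as lgb

x = np.random.rand(100, 4)
y = np.random.randint(2, size=100)
train_data = lgb.Dataset(x, label=y)

lgb.train(
    {"objective": "binary", "metric": "auc"},
    train_data,
    num_boost_round=5,
    valid_sets=[train_data],
    callbacks=[DvcLiveCallback(model_file="model.txt")],
)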
def test_init_from_env(tmp_dir, html, monkeypatch):
    monkeypatch.setenv(env.DVCLIVE_PATH, "logs")
    monkeypatch.setenv(env.DVCLIVE_HTML, str(int(html)))

    dvclive = Live()

    assert dvclive._path == "logs"
    assert dvclive._report == ("html" if html else None)
def test_get_step_control_flow(tmp_dir):
    dvclive = Live()

    while dvclive.get_step() < 10:
        dvclive.log("i", dvclive.get_step())
        dvclive.next_step()

    steps, values = read_history("dvclive", "i")
    assert steps == list(range(10))
    assert values == [float(x) for x in range(10)]
def test_init_from_env(tmp_dir, summary, html, monkeypatch):
    monkeypatch.setenv(env.DVCLIVE_PATH, "logs")
    monkeypatch.setenv(env.DVCLIVE_SUMMARY, str(int(summary)))
    monkeypatch.setenv(env.DVCLIVE_HTML, str(int(html)))

    dvclive = Live()

    assert dvclive._path == "logs"
    assert dvclive._summary == summary
    assert dvclive._html == html
def test_get_step_control_flow(tmp_dir):
    dvclive = Live()
    out = tmp_dir / dvclive.dir / Scalar.subfolder

    while dvclive.get_step() < 10:
        dvclive.log("i", dvclive.get_step())
        dvclive.next_step()

    steps, values = read_history(out, "i")
    assert steps == list(range(10))
    assert values == [float(x) for x in range(10)]
class DvcLiveCallback(TrainerCallback):
    def __init__(self, model_file=None, **kwargs):
        super().__init__()
        self.model_file = model_file
        self.dvclive = Live(**kwargs)

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        logs = kwargs["logs"]
        for key, value in logs.items():
            self.dvclive.log(key, value)
        self.dvclive.next_step()

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if self.model_file:
            model = kwargs["model"]
            model.save_pretrained(self.model_file)
            tokenizer = kwargs["tokenizer"]
            tokenizer.save_pretrained(self.model_file)
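# Usage sketch (not from the source): wiring the TrainerCallback above into a
# transformers Trainer. The tiny checkpoint, dummy two-sentence dataset, and
# "output/best" path are made-up examples; the tokenizer is passed to the
# Trainer so on_epoch_end can save it alongside the model.
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

checkpoint = "prajjwal1/bert-tiny"  # small model, illustrative choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

data = Dataset.from_dict({"text": ["good", "bad"] * 8, "label": [1, 0] * 8})
data = data.map(
    lambda e: tokenizer(e["text"], padding="max_length", max_length=16),
    batched=True,
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="output", num_train_epochs=1, report_to=[]),
    train_dataset=data,
    tokenizer=tokenizer,
    callbacks=[DvcLiveCallback(model_file="output/best")],
)
trainer.train()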
class DvcLiveCallback(Callback):
    def __init__(self, model_file=None, **kwargs):
        super().__init__(order=CallbackOrder.external)
        self.dvclive = Live(**kwargs)
        self.model_file = model_file

    def on_epoch_end(self, runner) -> None:
        for loader_key, per_loader_metrics in runner.epoch_metrics.items():
            for key, value in per_loader_metrics.items():
                key = key.replace("/", "_")
                self.dvclive.log(f"{loader_key}/{key}", float(value))

        if self.model_file:
            checkpoint = runner.engine.pack_checkpoint(
                model=runner.model,
                criterion=runner.criterion,
                optimizer=runner.optimizer,
                scheduler=runner.scheduler,
            )
            runner.engine.save_checkpoint(checkpoint, self.model_file)
        self.dvclive.next_step()
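# Usage sketch (not from the source): passing the Catalyst callback above to a
# SupervisedRunner. The toy linear model, random tensors, and "model.pth" path
# are illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

loader = DataLoader(
    TensorDataset(torch.rand(64, 4), torch.randint(0, 2, (64,))), batch_size=8
)

runner = dl.SupervisedRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    loaders={"train": loader, "valid": loader},
    num_epochs=2,
    callbacks=[DvcLiveCallback(model_file="model.pth")],
)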
def test_get_step_custom_steps(tmp_dir):
    dvclive = Live()

    steps = [0, 62, 1000]
    metrics = [0.9, 0.8, 0.7]

    for step, metric in zip(steps, metrics):
        dvclive.set_step(step)
        dvclive.log("x", metric)
        assert dvclive.get_step() == step
def experiment(self):
    r"""
    Actual DVCLive object. To use DVCLive features in your
    :class:`~LightningModule` do the following.

    Example::

        self.logger.experiment.some_dvclive_function()

    """
    if self._experiment is not None:
        return self._experiment
    else:
        assert (
            rank_zero_only.rank == 0
        ), "tried to init log dirs in non global_rank=0"
        self._experiment = Live(**self._dvclive_init)

    return self._experiment
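# Usage sketch (not from the source): the property above lazily creates the
# Live instance on global rank 0 and exposes it as `logger.experiment`. The
# import path and class name DVCLiveLogger are assumptions about the logger
# this property belongs to; the metric name is illustrative.
from pytorch_lightning import Trainer
from dvclive.lightning import DVCLiveLogger  # assumed import path

logger = DVCLiveLogger()
trainer = Trainer(logger=logger, max_epochs=2)

# The raw Live object created by the property above:
logger.experiment.log("custom_metric", 0.5)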
def test_step_formatting(tmp_dir):
    dvclive = Live()
    img = np.ones((500, 500, 3), np.uint8)

    for _ in range(3):
        dvclive.log_image("image.png", img)
        dvclive.next_step()

    for step in range(3):
        assert (
            tmp_dir / dvclive.dir / LiveImage.subfolder / str(step) / "image.png"
        ).exists()
def test_require_step_update(tmp_dir, metric):
    dvclive = Live("logs")

    dvclive.log(metric, 1.0)

    with pytest.raises(
        DataAlreadyLoggedError,
        match="has already being logged whith step 'None'",
    ):
        dvclive.log(metric, 2.0)
def test_step_exception(tmp_dir, y_true_y_pred_y_score):
    live = Live()
    out = tmp_dir / live.dir / Plot.subfolder

    y_true, y_pred, _ = y_true_y_pred_y_score
    live.log_plot("confusion_matrix", y_true, y_pred)
    assert (out / "confusion_matrix.json").exists()

    with pytest.raises(NotImplementedError):
        live.next_step()
def test_logging_step(tmp_dir, path):
    dvclive = Live(path)
    dvclive.log("m1", 1)
    dvclive.next_step()

    assert (tmp_dir / path).is_dir()
    assert (tmp_dir / path / "m1.tsv").is_file()
    assert (tmp_dir / dvclive.summary_path).is_file()

    s = _parse_json(dvclive.summary_path)
    assert s["m1"] == 1
    assert s["step"] == 0