Esempio n. 1
0
def test_mmcv_hook(tmp_dir, mocker):
    """An mmcv runner with ``DvcliveLoggerHook`` logs lr/momentum via ``dvclive.Live``.

    The hook is registered through ``log_config`` alongside a
    ``TextLoggerHook``. ``dvclive.Live.set_step`` and ``dvclive.Live.log``
    are spied on, and one train epoch plus one val epoch are run over a
    5-sample loader. The test then checks the spy call counts and the scalar
    logs written under ``tmp_dir/dvclive``.
    """
    runner = _build_demo_runner(str(tmp_dir / "work_dir"))

    logger_hooks = [
        dict(type="TextLoggerHook"),
        dict(type="DvcliveLoggerHook", model_file=tmp_dir / "model.pth"),
    ]
    runner.register_logger_hooks(dict(interval=1, hooks=logger_hooks))

    # Install the spies before running so that every call made during the
    # run is counted.
    spy_set_step = mocker.spy(dvclive.Live, "set_step")
    spy_log = mocker.spy(dvclive.Live, "log")

    data_loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
    workflow = [("train", 1), ("val", 1)]
    runner.run([data_loader, data_loader], workflow)

    assert spy_set_step.call_count == 6
    assert spy_log.call_count == 12

    scalar_dir = tmp_dir / "dvclive" / Scalar.subfolder
    logs, _ = read_logs(scalar_dir)
    for metric_name in ("learning_rate", "momentum"):
        assert metric_name in logs
Esempio n. 2
0
def test_xgb_integration(tmp_dir, train_params, iris_data):
    """Training xgboost with ``DvcLiveCallback`` produces one five-entry metric.

    Runs five boosting rounds with a single eval set named ``eval_data``. It
    then checks that a ``dvclive`` directory exists and that exactly one
    scalar metric was written, holding five entries.
    """
    dvclive_callback = DvcLiveCallback("eval_data")
    xgb.train(
        train_params,
        iris_data,
        callbacks=[dvclive_callback],
        num_boost_round=5,
        evals=[(iris_data, "eval_data")],
    )

    assert os.path.exists("dvclive")

    logs, _ = read_logs(tmp_dir / "dvclive" / Scalar.subfolder)
    assert len(logs) == 1
    history = first(logs.values())
    assert len(history) == 5
Esempio n. 3
0
def test_keras_callback(tmp_dir, xor_model, capture_wrap):
    """Fitting a Keras model with ``DvcLiveCallback`` records an ``accuracy`` metric.

    ``capture_wrap`` is not referenced in the body; it is requested as a
    fixture, presumably for its setup side effects — confirm in conftest.
    """
    keras_model, inputs, targets = xor_model()

    keras_model.fit(
        inputs,
        targets,
        epochs=1,
        batch_size=1,
        callbacks=[DvcLiveCallback()],
    )

    assert os.path.exists("dvclive")

    scalar_dir = tmp_dir / "dvclive" / Scalar.subfolder
    logs, _ = read_logs(scalar_dir)
    assert "accuracy" in logs
Esempio n. 4
0
def test_lgbm_integration(tmp_dir, model_params, iris_data):
    """Fitting an ``LGBMClassifier`` with ``DvcLiveCallback`` logs one metric series.

    ``iris_data`` is indexed as ``iris_data[0] == (X, y)`` for fitting and
    ``iris_data[1] == (X, y)`` for evaluation (presumably a train/validation
    split — confirm against the fixture). The test expects a ``dvclive``
    directory containing exactly one metric with five entries.
    """
    model = lgbm.LGBMClassifier()
    model.set_params(**model_params)

    model.fit(
        iris_data[0][0],
        iris_data[0][1],
        # LightGBM documents ``eval_set`` as a list of (X, y) pairs; pass a
        # list instead of relying on the implicit tuple-to-list coercion.
        eval_set=[(iris_data[1][0], iris_data[1][1])],
        eval_metric=["multi_logloss"],
        callbacks=[DvcLiveCallback()],
    )

    assert os.path.exists("dvclive")

    logs, _ = read_logs("dvclive")
    assert len(logs) == 1
    assert len(first(logs.values())) == 5
Esempio n. 5
0
def test_lightning_integration(tmp_dir):
    """Training ``LitMNIST`` with ``DvcLiveLogger`` writes scalars under ``logs``.

    The logger is created with ``path="logs"``. Its output must therefore
    appear there, and no ``DvcLiveLogger`` directory may be created. Exactly
    three scalar series are expected: ``train_loss_step``,
    ``train_loss_epoch`` and ``epoch``.
    """
    lit_model = LitMNIST()
    logger = DvcLiveLogger("test_run", path="logs")
    # NOTE(review): `checkpoint_callback` looks like the pre-1.5 Lightning
    # spelling of `enable_checkpointing`; confirm against the pinned version.
    trainer = Trainer(
        logger=logger,
        max_epochs=1,
        checkpoint_callback=False,
    )
    trainer.fit(lit_model)

    assert os.path.exists("logs")
    assert not os.path.exists("DvcLiveLogger")

    logs, _ = read_logs(tmp_dir / "logs" / Scalar.subfolder)
    assert len(logs) == 3
    for series_name in ("train_loss_step", "train_loss_epoch", "epoch"):
        assert series_name in logs
Esempio n. 6
0
def test_mmcv_hook(tmp_dir, mocker):
    """``DVCLiveLoggerHook`` on an mmcv runner logs lr/momentum via ``MetricLogger``.

    The hook is registered at ``VERY_LOW`` priority, and
    ``MetricLogger.next_step`` and ``MetricLogger.log`` are spied on.
    ``dvclive.init("logs")`` is called before one train epoch and one val
    epoch are run. The test then checks the spy call counts and the logs
    read back from ``logs``.
    """
    runner = _build_demo_runner(str(tmp_dir / "work_dir"))
    runner.register_hook(DVCLiveLoggerHook(), priority="VERY_LOW")

    metric_logger_cls = dvclive.metrics.MetricLogger
    spy_next_step = mocker.spy(metric_logger_cls, "next_step")
    spy_log = mocker.spy(metric_logger_cls, "log")
    data_loader = torch.utils.data.DataLoader(torch.ones((5, 2)))

    # Initialise dvclive after the spies are installed, immediately before
    # the run.
    dvclive.init("logs")
    runner.run([data_loader, data_loader], [("train", 1), ("val", 1)])

    assert spy_next_step.call_count == 5
    assert spy_log.call_count == 12

    logs, _ = read_logs("logs")
    for metric_name in ("learning_rate", "momentum"):
        assert metric_name in logs
Esempio n. 7
0
def test_huggingface_integration(tmp_dir, model, args, data, tokenizer):
    """Training a ``Trainer`` with ``DvcLiveCallback`` logs train and eval metrics.

    ``data`` is indexed as ``data[0]`` (train dataset) and ``data[1]``
    (eval dataset). After training, ten scalar series are expected,
    including ``eval_matthews_correlation`` and ``eval_loss``. The
    ``epoch`` series must have three entries and ``eval_loss`` two.
    """
    train_dataset = data[0]
    eval_dataset = data[1]
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.add_callback(DvcLiveCallback())
    trainer.train()

    assert os.path.exists("dvclive")

    scalar_dir = tmp_dir / "dvclive" / Scalar.subfolder
    logs, _ = read_logs(scalar_dir)

    assert len(logs) == 10
    for metric_name in ("eval_matthews_correlation", "eval_loss"):
        assert metric_name in logs
    assert len(logs["epoch"]) == 3
    assert len(logs["eval_loss"]) == 2