# Example #1
# 0
def run(
    dataset,
    pl_model: "type[pl.LightningModule]",
    name: str,
    path: Union[Path, str],
    test_path: Union[Path, str],
    seed: int,
    args=args,
) -> None:
    """Train one model for one seed, then optionally test and publish it.

    Args:
        dataset: Forwarded unchanged to ``DataModule``.
        pl_model: A LightningModule *class* (or zero-argument factory); it is
            instantiated here with no arguments.
        name: Run name. Used as the CSV logger ``version`` and in the name of
            the test-results CSV file.
        path: Training data location, forwarded to ``DataModule``.
        test_path: Test data location, forwarded to ``DataModule``.
        seed: Passed to ``seed_everything(..., workers=True)`` and to
            ``DataModule``.
        args: Parsed command-line namespace. Its default is bound to the
            module-level ``args`` at function-definition time. The attributes
            ``batch_size``, ``fast_dev_run`` and ``save_to_hub`` are read here,
            and the whole namespace is passed to
            ``pl.Trainer.from_argparse_args``.

    Side effects:
        - Logs through a ``CSVLogger`` (``save_dir="csv_logs"``,
          ``name="seed_<seed>"``, ``version=<name>``).
        - Uses ``default_root_dir="ckpts"`` for the trainer.
        - Unless ``fast_dev_run`` is set, writes the test metrics to
          ``csv_logs/seed_<seed>_<name>_test.csv``.
        - When ``save_to_hub`` is set, calls
          ``model.model.push_to_hub("cjber/<save_to_hub>")``.
    """
    seed_everything(seed, workers=True)

    datamodule: pl.LightningDataModule = DataModule(
        dataset=dataset,
        path=path,
        test_path=test_path,
        num_workers=8,
        batch_size=args.batch_size,
        seed=seed,
    )
    model: pl.LightningModule = pl_model()
    callbacks: list[Callback] = build_callbacks()
    csv_logger = CSVLogger(
        save_dir="csv_logs",
        name=f"seed_{seed}",
        version=name,
    )

    # Quick dev runs request no GPUs and keep default precision; real runs use
    # every GPU (-1) with 16-bit precision.
    if args.fast_dev_run:
        trainer_kwargs = {"gpus": None, "auto_select_gpus": False}
    else:
        trainer_kwargs = {
            "gpus": -1,
            "auto_select_gpus": True,
            "precision": 16,
        }

    trainer: pl.Trainer = pl.Trainer.from_argparse_args(
        args,
        **trainer_kwargs,
        deterministic=True,  # ensure reproducible results
        default_root_dir="ckpts",
        logger=[csv_logger],
        log_every_n_steps=10,
        callbacks=callbacks,
        max_epochs=35,
    )

    trainer.tune(model=model, datamodule=datamodule)
    trainer.fit(model=model, datamodule=datamodule)

    if not args.fast_dev_run:
        # ckpt_path="best" evaluates the best checkpoint rather than the final
        # weights (presumably tracked by a checkpoint callback from
        # build_callbacks() -- confirm).
        test = trainer.test(model=model, ckpt_path="best", datamodule=datamodule)
        pd.DataFrame(test).to_csv(f"csv_logs/seed_{seed}_{name}_test.csv")
        csv_logger.save()

    # NOTE(review): this is not guarded by fast_dev_run, so a smoke-test run
    # with save_to_hub set would still publish the model -- confirm intended.
    if args.save_to_hub:
        model.model.push_to_hub(f"cjber/{args.save_to_hub}")  # type: ignore
def test_file_logger_no_name(tmpdir, name):
    """Check that CSVLogger accepts a None or empty experiment name.

    NOTE(review): a second ``test_file_logger_no_name`` later in this file
    rebinds this name, so only that later definition is collected by pytest.
    """
    csv_logger = CSVLogger(save_dir=tmpdir, name=name)
    csv_logger.save()
    # normpath drops a possible trailing "/" from root_dir before comparing.
    normalized_root = os.path.normpath(csv_logger.root_dir)
    assert normalized_root == tmpdir
    first_version_dir = tmpdir / "version_0"
    assert os.listdir(first_version_dir)
# Example #3
# 0
def test_file_logger_named_version(tmpdir):
    """Check that CSVLogger honors a string version such as '2020-02-05-162402'."""
    experiment = "exp"
    tmpdir.mkdir(experiment)
    version = "2020-02-05-162402"

    csv_logger = CSVLogger(save_dir=tmpdir, name=experiment, version=version)
    csv_logger.log_hyperparams({"a": 1, "b": 2})
    csv_logger.save()

    experiment_dir = tmpdir / experiment
    assert csv_logger.version == version
    # The experiment directory must hold exactly the requested version, and
    # that version directory must not be empty.
    assert os.listdir(experiment_dir) == [version]
    assert os.listdir(experiment_dir / version)
# Example #4
# 0
def test_file_logger_log_metrics(tmpdir, step_idx):
    """Log one set of Python and tensor metrics and inspect the metrics CSV."""
    csv_logger = CSVLogger(tmpdir)
    sample = dict(
        float=0.3,
        int=1,
        FloatTensor=torch.tensor(0.1),
        IntTensor=torch.tensor(1),
    )
    csv_logger.log_metrics(sample, step_idx)
    csv_logger.save()

    metrics_file = os.path.join(csv_logger.log_dir,
                                ExperimentWriter.NAME_METRICS_FILE)
    with open(metrics_file) as handle:
        rows = handle.readlines()
    # Two lines expected; the first must mention every metric key.
    assert len(rows) == 2
    header = rows[0]
    assert all(key in header for key in sample)
def test_file_logger_log_hyperparams(tmpdir):
    """Write assorted hyperparameter types and read them back from the YAML file.

    NOTE(review): another ``test_file_logger_log_hyperparams`` defined later in
    this file rebinds this name, so pytest only collects the later one.
    """
    csv_logger = CSVLogger(tmpdir)
    nested = {"a": {"b": "c"}}
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": nested,
        "list": [1, 2, 3],
        "layer": torch.nn.BatchNorm1d,
    }
    csv_logger.log_hyperparams(hparams)
    csv_logger.save()

    yaml_file = os.path.join(csv_logger.log_dir,
                             ExperimentWriter.NAME_HPARAMS_FILE)
    loaded = load_hparams_from_yaml(yaml_file)
    # Every top-level key that was logged must be present after reloading.
    assert all(key in loaded for key in hparams)
def test_flush_n_steps(tmpdir):
    """With flush_logs_every_n_steps=2, save() must not run after the first
    log_metrics call and must run exactly once after the second."""
    csv_logger = CSVLogger(tmpdir, flush_logs_every_n_steps=2)
    sample = dict(
        float=0.3,
        int=1,
        FloatTensor=torch.tensor(0.1),
        IntTensor=torch.tensor(1),
    )
    # Replace save() with a mock so calls can be counted without touching disk.
    csv_logger.save = MagicMock()

    csv_logger.log_metrics(sample, step=0)
    csv_logger.save.assert_not_called()

    csv_logger.log_metrics(sample, step=1)
    csv_logger.save.assert_called_once()
def test_file_logger_log_hyperparams(tmpdir):
    """Check that hyperparameters of assorted types, including nested
    ``Namespace`` objects and a class object, can be logged and that every
    top-level key is present in the reloaded hparams YAML file."""
    logger = CSVLogger(tmpdir)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {
            "a": {
                "b": "c"
            }
        },
        "list": [1, 2, 3],
        "namespace": Namespace(foo=Namespace(bar="buzz")),
        "layer": torch.nn.BatchNorm1d,
    }
    logger.log_hyperparams(hparams)
    logger.save()

    path_yaml = os.path.join(logger.log_dir,
                             ExperimentWriter.NAME_HPARAMS_FILE)
    params = load_hparams_from_yaml(path_yaml)
    # Generator rather than a list so all() can stop at the first missing key.
    assert all(n in params for n in hparams)
# Example #8
# 0
def test_file_logger_no_name(tmpdir, name):
    """Verify that None or empty name works."""
    logger = CSVLogger(save_dir=tmpdir, name=name)
    logger.save()
    # root_dir may end with a trailing "/" (presumably when name is empty), so
    # normalize it before comparing with tmpdir.
    assert os.path.normpath(logger.root_dir) == tmpdir
    assert os.listdir(tmpdir / "version_0")