Exemple #1
0
def test_file_logger_named_version(tmpdir):
    """Check that a manually supplied string version is used as-is.

    A ``CSVLogger`` is created with an explicit string version (a
    timestamp-like value such as '2020-02-05-162402'). After logging
    hyperparameters and saving, the test asserts that the logger reports
    that version, that the experiment directory contains exactly one
    entry named after it, and that this version directory is not empty.
    """
    experiment = "exp"
    version_tag = "2020-02-05-162402"
    tmpdir.mkdir(experiment)

    csv_logger = CSVLogger(save_dir=tmpdir, name=experiment, version=version_tag)
    csv_logger.log_hyperparams({"a": 1, "b": 2})
    csv_logger.save()

    experiment_dir = tmpdir / experiment
    assert csv_logger.version == version_tag
    # The version directory must be the only entry under the experiment.
    assert os.listdir(experiment_dir) == [version_tag]
    # Saving must have written at least one file into the version directory.
    assert os.listdir(experiment_dir / version_tag)
def test_file_logger_log_hyperparams(tmpdir):
    """Log hyperparameters of mixed types and check they reach the YAML file.

    The hyperparameter dict mixes scalars, a nested dict, a list and a class
    object. After ``log_hyperparams`` and ``save``, the hparams file under
    the logger's ``log_dir`` is reloaded and every top-level key of the
    input must be present in it.
    """
    # NOTE(review): another function with this exact name is defined later in
    # this file; if both live in the same module, the later definition
    # replaces this one and this test is never collected.
    mixed_hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {"a": {"b": "c"}},
        "list": [1, 2, 3],
        "layer": torch.nn.BatchNorm1d,
    }
    csv_logger = CSVLogger(tmpdir)
    csv_logger.log_hyperparams(mixed_hparams)
    csv_logger.save()

    hparams_file = os.path.join(
        csv_logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE
    )
    reloaded = load_hparams_from_yaml(hparams_file)

    # Only top-level key presence is checked, not the stored values.
    absent = [key for key in mixed_hparams if key not in reloaded]
    assert not absent
def test_file_logger_log_hyperparams(tmpdir):
    """Log hyperparameters of mixed types and check they reach the YAML file.

    The hyperparameter dict mixes scalars, a nested dict, a list, a nested
    ``Namespace`` and a class object. After ``log_hyperparams`` and ``save``,
    the hparams file under the logger's ``log_dir`` is reloaded and every
    top-level key of the input must be present in it.
    """
    # NOTE(review): an earlier function with this exact name is defined in
    # this file; if both live in the same module, this definition replaces
    # the earlier one, which is then never collected.
    logger = CSVLogger(tmpdir)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {
            "a": {
                "b": "c",
            },
        },
        "list": [1, 2, 3],
        "namespace": Namespace(foo=Namespace(bar="buzz")),
        "layer": torch.nn.BatchNorm1d,
    }
    logger.log_hyperparams(hparams)
    logger.save()

    path_yaml = os.path.join(logger.log_dir,
                             ExperimentWriter.NAME_HPARAMS_FILE)
    params = load_hparams_from_yaml(path_yaml)

    # Only top-level key presence is checked, not the stored values.
    # Collecting the missing keys lets a failure name exactly what was lost
    # and avoids wrapping a list comprehension in all().
    missing = [name for name in hparams if name not in params]
    assert not missing, f"hyperparameters missing from {path_yaml}: {missing}"