def test_TrainConfig_build_loss(): config = TrainConfig(".", ".", ".", loss=CustomLoss(loss_variables=["x"])) # needs to be random or the normalized loss will have nan data = {"x": tf.random.uniform(shape=(4, 10))} loss = config.build_loss(data) loss_value, _ = loss(data, data) assert 0 == pytest.approx(loss_value.numpy())
def test_TrainConfig_build_model(): field = Field("out", "in") config = TrainConfig( ".", ".", ".", transformed_model=TransformedModelConfig(ArchitectureConfig("dense"), [field], 900), ) data = { field.input_name: tf.ones((1, 10)), field.output_name: tf.ones((1, 10)) } model = config.build_model(data) assert field.output_name in model(data)
def test_TrainConfig_from_args_default(): default = get_default_config() args = ["--config-path", "default"] config = TrainConfig.from_args(args=args) assert config == default
def test_TrainConfig_from_flat_dict(): d = { "train_url": "train_path", "test_url": "test_path", "out_url": "out_path", "model.architecture.name": "rnn", } config = TrainConfig.from_flat_dict(d) assert config.train_url == "train_path" assert config.model.architecture.name == "rnn" expected = get_default_config() flat_dict = _to_flat_dict(asdict(expected)) result = TrainConfig.from_flat_dict(flat_dict) assert result == expected
def test_rnn_v1_cache_disable(arch_key, expected_cache): default = get_default_config() d = asdict(default) d["cache"] = True d["model"]["architecture"]["name"] = arch_key config = TrainConfig.from_dict(d) assert config.cache == expected_cache
def test_TrainConfig_from_yaml(tmp_path): default = get_default_config() yaml_path = str(tmp_path / "train_config.yaml") with open(yaml_path, "w") as f: yaml.safe_dump(asdict(default), f) loaded = TrainConfig.from_yaml_path(yaml_path) assert loaded == default
def test_TrainConfig_defaults(): config = TrainConfig( train_url="train_path", test_url="test_path", out_url="save_path", transform=TransformConfig(), model=MicrophysicsConfig(), ) assert config # for linter
def test_TrainConfig_from_dict(): d = dict( train_url="train_path", test_url="test_path", out_url="save_path", model=dict(architecture={"name": "rnn"}), ) config = TrainConfig.from_dict(d) assert config.train_url == "train_path" assert config.model.architecture.name == "rnn"
def test_TrainConfig_asdict(): config = TrainConfig( train_url="train_path", test_url="test_path", out_url="save_path", model=MicrophysicsConfig(), ) d = asdict(config) assert d["train_url"] == "train_path" assert d["model"]["architecture"]["name"] == "linear"
def test_training_entry_integration(tmp_path): config_dict = asdict(get_default_config()) config_dict["out_url"] = str(tmp_path) config_dict["use_wandb"] = False config_dict["nfiles"] = 4 config_dict["nfiles_valid"] = 4 config_dict["epochs"] = 1 config = TrainConfig.from_dict(config_dict) main(config)
def test_TrainConfig_from_args_sysargv(monkeypatch): args = [ "unused_sysv_arg", "--config-path", "default", "--epochs", "4", "--model.architecture.name", "rnn", ] monkeypatch.setattr(sys, "argv", args) config = TrainConfig.from_args() assert config.epochs == 4 assert config.model.architecture.name == "rnn"
def test_TrainConfig_from_dict_full(): expected = get_default_config() result = TrainConfig.from_dict(asdict(expected)) assert result == expected
# add level for dataframe index, assumes equivalent feature dims sample_profile = next(iter(train_profiles.values())) train_profiles["level"] = np.arange(len(sample_profile)) test_profiles["level"] = np.arange(len(sample_profile)) log_to_table("score/train", train_scores, index=[config.wandb.job.name]) log_to_table("score/test", test_scores, index=[config.wandb.job.name]) log_to_table("profiles/train", train_profiles) log_to_table("profiles/test", test_profiles) with put_dir(config.out_url) as tmpdir: with open(os.path.join(tmpdir, "scores.json"), "w") as f: json.dump({"train": train_scores, "test": test_scores}, f) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--model_url", help=("Specify model path to run scoring for. Overrides use of models " "at the config.out_url"), default=None, ) known, unknown = parser.parse_known_args() config = TrainConfig.from_args(unknown) main(config, model_url=known.model_url)