def test_labeler(tests_root):
    """Check that the labeler's labels path is loaded from sticker.conf and is writable."""
    conf_path = os.path.join(tests_root, "sticker.conf")
    config = Config(conf_path)

    # The loaded value is compared by suffix only, because the prefix depends on tests_root.
    assert config.labeler.labels.endswith("sticker.labels")

    replacement = "another.labels"
    config.labeler.labels = replacement
    assert config.labeler.labels == replacement
def test_model(tests_root):
    """Check that model settings load from sticker.conf and that each one can be overwritten."""
    config = Config(os.path.join(tests_root, "sticker.conf"))

    # Values as read from the fixture configuration.
    assert config.model.graph.endswith("sticker.graph")
    assert config.model.intra_op_parallelism_threads == 2
    assert config.model.inter_op_parallelism_threads == 4
    assert config.model.parameters.endswith("epoch-37")

    # Overwrite every field, then read each one back through config.model.
    # Each access goes through config.model again, exactly as the original did.
    overrides = {
        "graph": "another.graph",
        "intra_op_parallelism_threads": 1,
        "inter_op_parallelism_threads": 3,
        "parameters": "epoch-42",
    }
    for field, value in overrides.items():
        setattr(config.model, field, value)
    for field, value in overrides.items():
        assert getattr(config.model, field) == value
def tagger_model(tagger_model_file):
    """Yield a Tagger built from the configuration file at *tagger_model_file*."""
    tagger = Tagger(Config(tagger_model_file))
    yield tagger
def test_config(tests_root):
    """Check that sticker.conf loads without raising."""
    conf_path = os.path.join(tests_root, "sticker.conf")
    Config(conf_path)
def test_missing_config(tests_root):
    """Check that loading a configuration file that does not exist raises an I/O error."""
    # IOError is an alias of OSError in Python 3, so this is the same check as before.
    with pytest.raises(OSError):
        missing_path = os.path.join(tests_root, "nonexistant.conf")
        Config(missing_path)
def test_bogus_config(tests_root):
    """Check that a malformed configuration file is rejected with ValueError."""
    with pytest.raises(ValueError):
        bogus_path = os.path.join(tests_root, "bogus.conf")
        Config(bogus_path)
def pipeline_model(tagger_model_file, topo_model_file):
    """Yield a Pipeline made of the tagger model followed by the topological-field model."""
    # Configs are built in the same order as before: tagger first, then topo.
    stage_configs = [Config(tagger_model_file), Config(topo_model_file)]
    yield Pipeline(stage_configs)