예제 #1
0
파일: test_zoo.py 프로젝트: PIlotcnc/new
def test_load_model_from_recipe(recipe_args, other_args):
    """Check that the model loaded from a recipe matches the recipe's metadata.

    Every field of ``recipe.model_metadata`` must equal the same-named field
    on the model returned by ``Zoo.load_model_from_recipe``, except for the
    fields listed in ``skipped`` below.
    """
    recipe = Zoo.load_recipe(**recipe_args, **other_args)
    loaded = Zoo.load_model_from_recipe(recipe, **other_args).dict()
    # TODO temporary fix while model apis need to be updated: these fields
    # are not compared until the model APIs agree on them.
    skipped = ("created", "modified", "release_version")
    expected = recipe.model_metadata.dict()
    for name in expected:
        if name not in skipped:
            assert loaded[name] == expected[name]
예제 #2
0
파일: test_zoo.py 프로젝트: PIlotcnc/new
def test_load_base_model_from_recipe(recipe_args, other_args):
    """Check that the base model loaded from a recipe mirrors the recipe args.

    ``recipe_type`` is not compared. The ``sparse_name``, ``sparse_category``
    and ``sparse_target`` fields must hold fixed values ("base", "none" and
    None respectively) rather than the recipe's own values. Every other
    recipe arg must appear unchanged on the loaded base model.
    """
    recipe = Zoo.load_recipe(**recipe_args, **other_args)
    base_fields = Zoo.load_base_model_from_recipe(recipe, **other_args).dict()
    # Fields whose expected value on the base model is a fixed constant,
    # independent of what recipe_args supplied.
    overrides = {"sparse_name": "base", "sparse_category": "none"}
    for name, arg_value in recipe_args.items():
        if name == "recipe_type":
            continue
        actual = base_fields[name]
        if name == "sparse_target":
            # Identity check, not equality: the field must literally be None.
            assert actual is None
        elif name in overrides:
            assert actual == overrides[name]
        else:
            assert actual == arg_value
예제 #3
0
파일: test_zoo.py 프로젝트: PIlotcnc/new
def test_load_recipe(recipe_args, other_args):
    """Check that a recipe loaded from the Zoo can be downloaded to disk.

    Downloads the recipe (overwriting any existing copy) and asserts that
    ``recipe.path`` exists afterwards. The directory ``recipe.dir_path`` is
    removed when the test finishes.
    """
    recipe = Zoo.load_recipe(**recipe_args, **other_args)
    try:
        recipe.download(overwrite=True)
        assert os.path.exists(recipe.path)
    finally:
        # Always clean up, even if the download or the assertion fails.
        # Otherwise downloaded files leak into later runs.
        # ignore_errors=True stops a missing or partially written directory
        # from raising here and masking the original failure.
        shutil.rmtree(recipe.dir_path, ignore_errors=True)