def test_mulitple_serializations_third(keras_model):
    """Store a third model ('nn3') and check three artifacts are present.

    Presumably expects the 'nn1' and 'nn2' artifacts stored by the
    ``..._first`` / ``..._second`` tests to still be in ``SAVED_MODELS``.
    Afterwards ``SAVED_MODELS`` is emptied: files are unlinked and
    sub-directories removed with ``shutil.rmtree``.

    NOTE(review): this module defines ``test_mulitple_serializations_third``
    twice (the other takes ``project_manager``); whichever definition comes
    later shadows this one — confirm which should remain.
    """
    try:
        skl3 = KerasModel(artifact=keras_model)
        skl3.store(name='nn3')
        assert len(os.listdir(SAVED_MODELS)) == 3
    finally:
        # Clean up even when the assertion fails, so later tests do not
        # inherit stale artifacts and miscount the directory contents.
        for root, dirs, files in os.walk(SAVED_MODELS):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
def test_loader(keras_model):
    """Store the Keras fixture as 'nn', reload it, and check the reloaded type.

    ``K.clear_session()`` is called between storing and loading, presumably
    so the reload does not reuse in-memory Keras state. Afterwards
    ``SAVED_MODELS`` is emptied (files unlinked, sub-directories removed).

    NOTE(review): this module defines ``test_loader`` twice (the other takes
    ``project_manager``); whichever definition comes later shadows this
    one — confirm which should remain.
    """
    try:
        skl = KerasModel(artifact=keras_model)
        skl.store(name='nn')
        K.clear_session()
        reloaded = skl.load(name='nn')
        assert isinstance(reloaded, KerasBaseModel)
    finally:
        # Always clean up so a failed assertion does not leave artifacts
        # behind for subsequent tests.
        for root, dirs, files in os.walk(SAVED_MODELS):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
def test_mulitple_serializations_third(keras_model, project_manager):
    """Store a third model ('nn3') and check the saved-models directory count.

    The configured saved-models directory is expected to hold three stored
    models (presumably 'nn1' and 'nn2' from the ``..._first`` / ``..._second``
    tests plus 'nn3') and the tracked ``.gitkeep`` placeholder. Afterwards
    the directory is emptied and ``.gitkeep`` is recreated.
    """
    saved_models = project_manager.CONFIG['saved-models']
    try:
        skl3 = KerasModel(artifact=keras_model)
        skl3.store(name='nn3')
        assert len(os.listdir(saved_models)) == 3 + 1  # .gitkeep
    finally:
        # Clean up and restore the placeholder even if the assertion fails,
        # so the directory is left in its checked-in state.
        for root, dirs, files in os.walk(saved_models):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
        with open(os.path.join(saved_models, '.gitkeep'), 'w') as gitkeep:
            gitkeep.write('empty')
def test_serialization(keras_model):
    """Store the Keras fixture and check every expected artifact is written.

    Checks for ``nn.h5`` and ``nn.json`` in ``skl.model_path``. Also checks
    for a ``tf`` sub-directory containing ``saved_model.pb`` and a
    ``variables`` directory. Afterwards ``SAVED_MODELS`` is emptied.

    NOTE(review): this module defines ``test_serialization`` twice (the other
    takes ``project_manager``); whichever definition comes later shadows
    this one — confirm which should remain.
    """
    try:
        skl = KerasModel(artifact=keras_model)
        skl.store(name='nn')
        assert os.path.exists(os.path.join(skl.model_path, 'nn.h5'))
        assert os.path.exists(os.path.join(skl.model_path, 'nn.json'))
        assert os.path.exists(
            os.path.join(skl.model_path, 'tf', 'saved_model.pb'))
        assert os.path.isdir(os.path.join(skl.model_path, 'tf', 'variables'))
    finally:
        # Always clean up so a failed assertion does not leave artifacts
        # behind for subsequent tests.
        for root, dirs, files in os.walk(SAVED_MODELS):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
def test_trainable_model_from_file(keras_model):
    """Store the Keras fixture, rebuild it via ``TrainableModel.from_file``.

    After storing as 'nn' and clearing the Keras session,
    ``TrainableModel.from_file(run_number=1, name='nn', model_type='keras')``
    must yield an object whose ``.model`` is a ``KerasBaseModel``.
    Afterwards ``SAVED_MODELS`` is emptied.

    NOTE(review): this module defines ``test_trainable_model_from_file``
    twice (the other takes ``project_manager``); whichever definition comes
    later shadows this one — confirm which should remain.
    """
    try:
        skl = KerasModel(artifact=keras_model)
        skl.store(name='nn')
        K.clear_session()
        trainable = TrainableModel.from_file(run_number=1, name='nn',
                                             model_type='keras')
        assert isinstance(trainable.model, KerasBaseModel)
    finally:
        # Always clean up so a failed assertion does not leave artifacts
        # behind for subsequent tests.
        for root, dirs, files in os.walk(SAVED_MODELS):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
def test_trainable_model_from_file(keras_model, project_manager):
    """Store the Keras fixture, rebuild it via ``TrainableModel.from_file``.

    After storing as 'nn' and clearing the Keras session,
    ``TrainableModel.from_file(run_number=1, name='nn', model_type='keras')``
    must yield an object whose ``.model`` is a ``KerasBaseModel``.
    Afterwards the configured saved-models directory is emptied and its
    ``.gitkeep`` placeholder is recreated.
    """
    saved_models = project_manager.CONFIG['saved-models']
    try:
        skl = KerasModel(artifact=keras_model)
        skl.store(name='nn')
        K.clear_session()
        trainable = TrainableModel.from_file(run_number=1, name='nn',
                                             model_type='keras')
        assert isinstance(trainable.model, KerasBaseModel)
    finally:
        # Clean up and restore the placeholder even if the test fails.
        for root, dirs, files in os.walk(saved_models):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
        with open(os.path.join(saved_models, '.gitkeep'), 'w') as gitkeep:
            gitkeep.write('empty')
def test_loader(keras_model, project_manager):
    """Store the Keras fixture as 'nn', reload it, and check the reloaded type.

    ``K.clear_session()`` is called between storing and loading, presumably
    so the reload does not reuse in-memory Keras state. Afterwards the
    configured saved-models directory is emptied and its ``.gitkeep``
    placeholder is recreated.
    """
    saved_models = project_manager.CONFIG['saved-models']
    try:
        skl = KerasModel(artifact=keras_model)
        skl.store(name='nn')
        K.clear_session()
        reloaded = skl.load(name='nn')
        assert isinstance(reloaded, KerasBaseModel)
    finally:
        # Clean up and restore the placeholder even if the test fails.
        for root, dirs, files in os.walk(saved_models):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
        with open(os.path.join(saved_models, '.gitkeep'), 'w') as gitkeep:
            gitkeep.write('empty')
def test_serialization(keras_model, project_manager):
    """Store the Keras fixture and check every expected artifact is written.

    Checks for ``nn.h5`` and ``nn.json`` in ``skl.model_path``. Also checks
    for a ``tf`` sub-directory containing ``saved_model.pb`` and a
    ``variables`` directory. Afterwards the configured saved-models
    directory is emptied and its ``.gitkeep`` placeholder is recreated.
    """
    saved_models = project_manager.CONFIG['saved-models']
    try:
        skl = KerasModel(artifact=keras_model)
        skl.store(name='nn')
        assert os.path.exists(os.path.join(skl.model_path, 'nn.h5'))
        assert os.path.exists(os.path.join(skl.model_path, 'nn.json'))
        # NOTE(review): this component was previously 'tf.txt', which looks
        # like an accidental rename; the keras_model-only variant of this
        # test uses 'tf'. Confirm against KerasModel.store.
        assert os.path.exists(
            os.path.join(skl.model_path, 'tf', 'saved_model.pb'))
        assert os.path.isdir(os.path.join(skl.model_path, 'tf', 'variables'))
    finally:
        # Clean up and restore the placeholder even if an assertion fails.
        for root, dirs, files in os.walk(saved_models):
            for f in files:
                os.unlink(os.path.join(root, f))
            for d in dirs:
                shutil.rmtree(os.path.join(root, d))
        with open(os.path.join(saved_models, '.gitkeep'), 'w') as gitkeep:
            gitkeep.write('empty')
def test_mulitple_serializations_second(keras_model):
    """Persist the Keras fixture under 'nn2' and verify its path exists.

    The stored artifact is deliberately not removed here; the
    ``..._third`` test in this module counts stored models and then clears
    the directory.
    """
    second = KerasModel(artifact=keras_model)
    second.store(name='nn2')
    assert os.path.exists(second.model_path)
def test_mulitple_serializations_first(keras_model):
    """Persist the Keras fixture under 'nn1' and verify its path exists.

    The stored artifact is deliberately not removed here; the
    ``..._third`` test in this module counts stored models and then clears
    the directory.
    """
    first = KerasModel(artifact=keras_model)
    first.store(name='nn1')
    assert os.path.exists(first.model_path)
def serializer():
    """Wrap ``keras_model`` in a ``KerasModel`` and store it under 'nn'.

    Returns nothing.

    NOTE(review): ``keras_model`` is not a parameter of this function, so it
    resolves to whatever module-level name ``keras_model`` is bound to (for
    example, a fixture function object) or raises NameError. If this is
    meant to be a pytest fixture that depends on the ``keras_model``
    fixture, it should declare ``keras_model`` as a (required) parameter.
    Confirm and fix.
    """
    skl = KerasModel(artifact=keras_model)
    skl.store(name='nn')