def save_external_file(
    self,
    artifact_name: str,
    save_pattern: str,
    cls: Optional[Type] = None,
    **save_params,
) -> None:
    """
    Save an artifact through a save-pattern class and record the result.

    The class used is `cls` when one is passed; otherwise it is whatever
    `SAVE_METHOD_REGISTRY.get(save_pattern)` returns. That class's `save`
    method is called with `artifact_name` plus any extra keyword arguments.
    Its return value is stored at `self.filepaths[artifact_name][save_pattern]`.

    :param artifact_name: name of the artifact; also the top-level key in
        `self.filepaths`
    :param save_pattern: save pattern name; used for the registry lookup and
        as the nested key the save result is recorded under
    :param cls: optional class to use directly instead of the registry lookup
    :param save_params: extra keyword arguments forwarded to `save_cls.save`
    :raises SimpleMLError: if no class is passed and the registry lookup
        yields None
    """
    if cls is not None:
        LOGGER.info("Custom save class passed, skipping registry lookup")
        save_cls = cls
    else:
        # No explicit class: resolve one from the save registry
        save_cls = SAVE_METHOD_REGISTRY.get(save_pattern)

    if save_cls is None:
        raise SimpleMLError(f"No registered save pattern for {save_pattern}")

    filepath_data = save_cls.save(artifact_name=artifact_name, **save_params)

    # Lazily build the nested mapping; an existing None entry is treated
    # the same as a missing one and replaced with a fresh dict
    if self.filepaths is None:
        self.filepaths = {}
    if self.filepaths.get(artifact_name) is None:
        self.filepaths[artifact_name] = {}
    self.filepaths[artifact_name][save_pattern] = filepath_data
def test_registering_with_decorator_without_parameters(self):
    """
    The bare decorator (no call parentheses) registers the class for both
    saving and loading under its `SAVE_PATTERN` attribute
    """
    pattern_name = "fake_decorated_without_parameters"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    # Clear leftovers from earlier runs, then confirm the clean slate
    for registry in registries:
        registry.drop(pattern_name)
    for registry in registries:
        self.assertNotIn(pattern_name, registry.registry)

    @SavePatternDecorators.register_save_pattern
    class FakeSavePattern(object):
        SAVE_PATTERN = pattern_name

    for registry in registries:
        self.assertIn(pattern_name, registry.registry)
        self.assertEqual(registry.get(pattern_name), FakeSavePattern)
def test_registering_new_load_pattern_with_decorator_implicitly(self):
    """
    Load-only decorator call, with the pattern name taken from the class
    attribute: only the load registry gains an entry
    """
    pattern_name = "fake_implicit_decorated_load_pattern"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    # Clear leftovers from earlier runs, then confirm the clean slate
    for registry in registries:
        registry.drop(pattern_name)
    for registry in registries:
        self.assertNotIn(pattern_name, registry.registry)

    @SavePatternDecorators.register_save_pattern(save=False, load=True)
    class FakeSavePattern(object):
        SAVE_PATTERN = pattern_name

    populated, untouched = LOAD_METHOD_REGISTRY, SAVE_METHOD_REGISTRY
    self.assertIn(pattern_name, populated.registry)
    self.assertEqual(populated.get(pattern_name), FakeSavePattern)
    self.assertNotIn(pattern_name, untouched.registry)
def test_registering_new_save_pattern_implicitly(self):
    """
    Save-only functional registration, with the pattern name taken from the
    class attribute: only the save registry gains an entry
    """
    pattern_name = "fake_implicit_save_pattern"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    # Clear leftovers from earlier runs, then confirm the clean slate
    for registry in registries:
        registry.drop(pattern_name)
    for registry in registries:
        self.assertNotIn(pattern_name, registry.registry)

    class FakeSavePattern(object):
        SAVE_PATTERN = pattern_name

    register_save_pattern(FakeSavePattern, save=True, load=False)

    populated, untouched = SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY
    self.assertIn(pattern_name, populated.registry)
    self.assertEqual(populated.get(pattern_name), FakeSavePattern)
    self.assertNotIn(pattern_name, untouched.registry)
def test_registering_new_save_pattern_with_decorator_explicitly(self):
    """
    Save-only decorator call with the pattern name passed as an argument
    (the class defines no `SAVE_PATTERN`): only the save registry gains an entry
    """
    pattern_name = "fake_explicit_decorated_save_pattern"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    # Clear leftovers from earlier runs, then confirm the clean slate
    for registry in registries:
        registry.drop(pattern_name)
    for registry in registries:
        self.assertNotIn(pattern_name, registry.registry)

    @SavePatternDecorators.register_save_pattern(pattern_name, save=True, load=False)
    class FakeSavePattern(object):
        pass

    populated, untouched = SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY
    self.assertIn(pattern_name, populated.registry)
    self.assertEqual(populated.get(pattern_name), FakeSavePattern)
    self.assertNotIn(pattern_name, untouched.registry)
def test_registering_both_patterns_with_decorator_implicitly(self):
    """
    Decorator call enabling both flags, with the pattern name taken from the
    class attribute: both registries gain the same class
    """
    pattern_name = "fake_implicit_decorated_pattern"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    # Clear leftovers from earlier runs, then confirm the clean slate
    for registry in registries:
        registry.drop(pattern_name)
    for registry in registries:
        self.assertNotIn(pattern_name, registry.registry)

    @SavePatternDecorators.register_save_pattern(save=True, load=True)
    class FakeSavePattern(object):
        SAVE_PATTERN = pattern_name

    for registry in registries:
        self.assertIn(pattern_name, registry.registry)
        self.assertEqual(registry.get(pattern_name), FakeSavePattern)
def test_registering_new_load_pattern_explicitly(self):
    """
    Load-only functional registration with the pattern name passed as a
    keyword (the class defines no `SAVE_PATTERN`): only the load registry
    gains an entry
    """
    pattern_name = "fake_explicit_load_pattern"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    # Clear leftovers from earlier runs, then confirm the clean slate
    for registry in registries:
        registry.drop(pattern_name)
    for registry in registries:
        self.assertNotIn(pattern_name, registry.registry)

    class FakeSavePattern(object):
        pass

    register_save_pattern(FakeSavePattern, save_pattern=pattern_name, save=False, load=True)

    populated, untouched = LOAD_METHOD_REGISTRY, SAVE_METHOD_REGISTRY
    self.assertIn(pattern_name, populated.registry)
    self.assertEqual(populated.get(pattern_name), FakeSavePattern)
    self.assertNotIn(pattern_name, untouched.registry)
def deregister_save_pattern(
    cls: Optional[Type] = None,
    save_pattern: Optional[str] = None,
    save: Optional[bool] = True,
    load: Optional[bool] = True,
) -> None:
    """
    Remove the save and/or load class registered for a save pattern.

    The save and load registries are handled independently. A registry with
    no entry for the pattern is skipped without error. When `cls` is passed
    and is not the class currently registered, a warning is logged, but the
    entry is still dropped.

    :param cls: optional class whose pattern is being removed; used as the
        `SAVE_PATTERN` fallback and for the mismatch warning
    :param save_pattern: optional name of the pattern (e.g. `disk_pickled`);
        falls back to `cls.SAVE_PATTERN` when omitted
    :param save: default true; whether to drop the pattern from the save registry
    :param load: default true; whether to drop the pattern from the load registry
    :raises SimpleMLError: if no pattern name is passed and `cls` has no
        `SAVE_PATTERN` attribute
    """
    if save_pattern is None and not hasattr(cls, "SAVE_PATTERN"):
        raise SimpleMLError(
            "Cannot deregister save pattern without passing the `save_pattern` parameter or setting the class attribute `cls.SAVE_PATTERN`"
        )
    if save_pattern is None:
        save_pattern = cls.SAVE_PATTERN

    def _mismatched(registry) -> bool:
        # Only meaningful when a class was supplied to compare against
        return cls is not None and registry.get(save_pattern) != cls

    # Save side
    if save and save_pattern in SAVE_METHOD_REGISTRY.registry:
        if _mismatched(SAVE_METHOD_REGISTRY):
            LOGGER.warning(
                f"Deregistering {save_pattern} as save pattern but passed class does not match registered class"
            )
        SAVE_METHOD_REGISTRY.drop(save_pattern)

    # Load side
    if load and save_pattern in LOAD_METHOD_REGISTRY.registry:
        if _mismatched(LOAD_METHOD_REGISTRY):
            LOGGER.warning(
                f"Deregistering {save_pattern} as load pattern but passed class does not match registered class"
            )
        LOAD_METHOD_REGISTRY.drop(save_pattern)
def register_save_pattern(
    cls: Type,
    save_pattern: Optional[str] = None,
    save: Optional[bool] = True,
    load: Optional[bool] = True,
    overwrite: Optional[bool] = False,
) -> None:
    """
    Register `cls` as the save and/or load class for a save pattern.

    Saving and loading are registered independently, so it is allowable for
    different classes to handle saving and loading for the same pattern.

    :param cls: the class to register
    :param save_pattern: optional name of the pattern this class implements
        (e.g. `disk_pickled`); falls back to `cls.SAVE_PATTERN` when omitted
    :param save: default true; whether to register `cls` as the save class
        for the pattern
    :param load: default true; whether to register `cls` as the load class
        for the pattern
    :param overwrite: default false; passed to each registry as
        `allow_duplicates`, i.e. whether an existing registration for the
        pattern may be replaced (the registry is expected to raise otherwise)
    :raises SimpleMLError: if no pattern name is passed and `cls` has no
        `SAVE_PATTERN` attribute
    """
    if save_pattern is None and not hasattr(cls, "SAVE_PATTERN"):
        raise SimpleMLError(
            "Cannot register save pattern without passing the `save_pattern` parameter or setting the class attribute `cls.SAVE_PATTERN`"
        )
    pattern_name = cls.SAVE_PATTERN if save_pattern is None else save_pattern

    # Save registry first, then load registry; each only when its flag is set
    for enabled, registry in (
        (save, SAVE_METHOD_REGISTRY),
        (load, LOAD_METHOD_REGISTRY),
    ):
        if enabled:
            registry.register(pattern_name, cls, allow_duplicates=overwrite)