def load_external_file(
    self, artifact_name: str, save_pattern: str, cls: Optional[Type] = None
) -> Any:
    """
    Load a previously saved external artifact and return the resulting object

    This is the inverse of the save step: the load class receives, as keyword
    arguments, the filepath data recorded for `artifact_name` under
    `save_pattern`. Registered load classes should therefore accept the same
    data (in the same form) that was stored in the filepath record.

    :param artifact_name: key into `self.filepaths` identifying the artifact
    :param save_pattern: the save pattern the artifact was persisted with; also
        the registry key used to find the load class when `cls` is not given
    :param cls: optional explicit load class; when passed, the registry lookup
        is skipped
    :returns: whatever `load(**filepath_data)` on the resolved load class returns
    :raises SimpleMLError: if no load class can be resolved, if nothing was
        saved for `artifact_name`, or if it was not saved with `save_pattern`
    """
    # An explicitly passed class takes precedence over the registry
    if cls is not None:
        LOGGER.info("Custom load class passed, skipping registry lookup")
        loader = cls
    else:
        loader = LOAD_METHOD_REGISTRY.get(save_pattern)

    if loader is None:
        raise SimpleMLError(f"No registered load class for {save_pattern}")

    # Guard against attempting to load an artifact that was never saved
    saved_patterns = self.filepaths.get(artifact_name, None)
    if saved_patterns is None:
        raise SimpleMLError(f"No artifact saved for {artifact_name}")
    if save_pattern not in saved_patterns:
        raise SimpleMLError(
            f"No artifact saved using save pattern {save_pattern} for {artifact_name}"
        )

    location = saved_patterns[save_pattern]
    # Legacy records stored a bare value instead of a dict of load kwargs
    load_kwargs = location if isinstance(location, dict) else {"legacy": location}
    return loader.load(**load_kwargs)
def test_registering_with_decorator_without_parameters(self):
    """
    The bare decorator (no call parentheses) registers the class for both
    saving and loading, keyed on the class attribute `SAVE_PATTERN`
    """
    pattern = "fake_decorated_without_parameters"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    # Start from a clean slate in both registries
    for registry in registries:
        registry.drop(pattern)
    for registry in registries:
        self.assertNotIn(pattern, registry.registry)

    @SavePatternDecorators.register_save_pattern
    class FakeSavePattern(object):
        SAVE_PATTERN = pattern

    # Registered in both, pointing at the decorated class
    for registry in registries:
        self.assertIn(pattern, registry.registry)
        self.assertEqual(registry.get(pattern), FakeSavePattern)
def test_registering_new_load_pattern_with_decorator_implicitly(self):
    """
    Decorator called with `save=False, load=True` and no explicit pattern:
    the pattern is taken from `SAVE_PATTERN` and only the load registry changes
    """
    pattern = "fake_implicit_decorated_load_pattern"
    SAVE_METHOD_REGISTRY.drop(pattern)
    LOAD_METHOD_REGISTRY.drop(pattern)
    for registry in (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY):
        self.assertNotIn(pattern, registry.registry)

    load_only = SavePatternDecorators.register_save_pattern(save=False, load=True)

    @load_only
    class FakeSavePattern(object):
        SAVE_PATTERN = pattern

    self.assertIn(pattern, LOAD_METHOD_REGISTRY.registry)
    self.assertEqual(LOAD_METHOD_REGISTRY.get(pattern), FakeSavePattern)
    # The save registry must remain untouched
    self.assertNotIn(pattern, SAVE_METHOD_REGISTRY.registry)
def test_registering_new_save_pattern_implicitly(self):
    """
    Calling `register_save_pattern` directly with `save=True, load=False` and
    no explicit pattern registers the class for saving only, keyed on
    `SAVE_PATTERN`
    """
    pattern = "fake_implicit_save_pattern"
    for registry in (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY):
        registry.drop(pattern)
    self.assertNotIn(pattern, SAVE_METHOD_REGISTRY.registry)
    self.assertNotIn(pattern, LOAD_METHOD_REGISTRY.registry)

    class FakeSavePattern(object):
        SAVE_PATTERN = pattern

    register_save_pattern(FakeSavePattern, save=True, load=False)

    self.assertIn(pattern, SAVE_METHOD_REGISTRY.registry)
    self.assertEqual(SAVE_METHOD_REGISTRY.get(pattern), FakeSavePattern)
    # Loading was not requested, so nothing lands in the load registry
    self.assertNotIn(pattern, LOAD_METHOD_REGISTRY.registry)
def test_registering_new_save_pattern_with_decorator_explicitly(self):
    """
    Decorator given the pattern name positionally (class has no
    `SAVE_PATTERN`) with `save=True, load=False` registers for saving only
    """
    pattern = "fake_explicit_decorated_save_pattern"
    for registry in (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY):
        registry.drop(pattern)
    for registry in (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY):
        self.assertNotIn(pattern, registry.registry)

    @SavePatternDecorators.register_save_pattern(pattern, save=True, load=False)
    class FakeSavePattern(object):
        pass

    self.assertIn(pattern, SAVE_METHOD_REGISTRY.registry)
    self.assertEqual(SAVE_METHOD_REGISTRY.get(pattern), FakeSavePattern)
    self.assertNotIn(pattern, LOAD_METHOD_REGISTRY.registry)
def test_registering_both_patterns_with_decorator_implicitly(self):
    """
    Decorator called with `save=True, load=True` but no pattern argument:
    the class is registered in both registries under its `SAVE_PATTERN`
    """
    pattern = "fake_implicit_decorated_pattern"
    registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

    for registry in registries:
        registry.drop(pattern)
    for registry in registries:
        self.assertNotIn(pattern, registry.registry)

    @SavePatternDecorators.register_save_pattern(save=True, load=True)
    class FakeSavePattern(object):
        SAVE_PATTERN = pattern

    for registry in registries:
        self.assertIn(pattern, registry.registry)
        self.assertEqual(registry.get(pattern), FakeSavePattern)
def test_registering_new_load_pattern_explicitly(self):
    """
    Calling `register_save_pattern` with an explicit `save_pattern` keyword
    and `save=False, load=True` registers the class for loading only
    """
    pattern = "fake_explicit_load_pattern"
    SAVE_METHOD_REGISTRY.drop(pattern)
    LOAD_METHOD_REGISTRY.drop(pattern)
    for registry in (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY):
        self.assertNotIn(pattern, registry.registry)

    class FakeSavePattern(object):
        pass

    register_save_pattern(
        FakeSavePattern, save_pattern=pattern, save=False, load=True
    )

    self.assertIn(pattern, LOAD_METHOD_REGISTRY.registry)
    self.assertEqual(LOAD_METHOD_REGISTRY.get(pattern), FakeSavePattern)
    # Saving was not requested, so the save registry stays empty for it
    self.assertNotIn(pattern, SAVE_METHOD_REGISTRY.registry)
def deregister_save_pattern(
    cls: Optional[Type] = None,
    save_pattern: Optional[str] = None,
    save: Optional[bool] = True,
    load: Optional[bool] = True,
) -> None:
    """
    Deregister the class used for saving and/or loading a particular pattern

    :param cls: optional class expected to be registered for the pattern. Used
        to infer `save_pattern` from `cls.SAVE_PATTERN` when that is not passed,
        and to warn when the currently registered class is a different one.
        A mismatch is only logged as a warning; the registered entry is still
        dropped
    :param save_pattern: the optional string denoting the pattern to deregister
        (e.g. `disk_pickled`). Checks class attribute `cls.SAVE_PATTERN` if null
    :param save: optional bool; default true; whether to remove the pattern from
        the save registry (no-op if it is not registered there)
    :param load: optional bool; default true; whether to remove the pattern from
        the load registry (no-op if it is not registered there)
    :raises SimpleMLError: if `save_pattern` is not passed and `cls` has no
        `SAVE_PATTERN` attribute (including when `cls` is None)
    """
    if save_pattern is None:
        if not hasattr(cls, "SAVE_PATTERN"):
            raise SimpleMLError(
                "Cannot deregister save pattern without passing the `save_pattern` parameter or setting the class attribute `cls.SAVE_PATTERN`"
            )
        save_pattern = cls.SAVE_PATTERN

    # Independent deregistration for saving and loading
    if save and save_pattern in SAVE_METHOD_REGISTRY.registry:
        # Classes are singletons: compare by identity, not equality
        if cls is not None and SAVE_METHOD_REGISTRY.get(save_pattern) is not cls:
            LOGGER.warning(
                f"Deregistering {save_pattern} as save pattern but passed class does not match registered class"
            )
        SAVE_METHOD_REGISTRY.drop(save_pattern)

    if load and save_pattern in LOAD_METHOD_REGISTRY.registry:
        if cls is not None and LOAD_METHOD_REGISTRY.get(save_pattern) is not cls:
            LOGGER.warning(
                f"Deregistering {save_pattern} as load pattern but passed class does not match registered class"
            )
        LOAD_METHOD_REGISTRY.drop(save_pattern)
def register_save_pattern(
    cls: Type,
    save_pattern: Optional[str] = None,
    save: Optional[bool] = True,
    load: Optional[bool] = True,
    overwrite: Optional[bool] = False,
) -> None:
    """
    Register a class as the save and/or load handler for a save pattern

    Saving and loading are registered independently, so IT IS ALLOWABLE TO
    HAVE DIFFERENT CLASSES HANDLE SAVING AND LOADING FOR THE SAME REGISTERED
    PATTERN.

    :param cls: the class to register
    :param save_pattern: the optional string denoting the pattern this class
        implements (e.g. `disk_pickled`). Falls back to the class attribute
        `cls.SAVE_PATTERN` when not passed
    :param save: optional bool; default true; whether to register `cls` as the
        save method for the pattern
    :param load: optional bool; default true; whether to register `cls` as the
        load method for the pattern
    :param overwrite: optional bool; default false; forwarded to the registries
        as `allow_duplicates` - whether an existing registration for the
        pattern may be overwritten instead of throwing an error
    :raises SimpleMLError: if no `save_pattern` is passed and `cls` has no
        `SAVE_PATTERN` attribute
    """
    if save_pattern is None and not hasattr(cls, "SAVE_PATTERN"):
        raise SimpleMLError(
            "Cannot register save pattern without passing the `save_pattern` parameter or setting the class attribute `cls.SAVE_PATTERN`"
        )
    pattern = cls.SAVE_PATTERN if save_pattern is None else save_pattern

    # Each registry is updated on its own flag
    if save:
        SAVE_METHOD_REGISTRY.register(pattern, cls, allow_duplicates=overwrite)
    if load:
        LOAD_METHOD_REGISTRY.register(pattern, cls, allow_duplicates=overwrite)