示例#1
0
    def save_external_file(
        self,
        artifact_name: str,
        save_pattern: str,
        cls: Optional[Type] = None,
        **save_params,
    ) -> None:
        """
        Save an artifact through a save-pattern class and record the result.

        The class doing the work is either ``cls`` (when supplied) or the
        entry registered under ``save_pattern`` in ``SAVE_METHOD_REGISTRY``.
        Its ``save`` method is called with ``artifact_name`` plus any extra
        ``save_params``. The return value is stored at
        ``self.filepaths[artifact_name][save_pattern]``.

        :param artifact_name: name the artifact is saved and recorded under
        :param save_pattern: registry key for the save class; also the inner
            key used in ``self.filepaths[artifact_name]``
        :param cls: optional class to use instead of the registry lookup
        :param save_params: extra keyword arguments forwarded to ``save``
        :raises SimpleMLError: if the resolved save class is None (e.g. no
            class passed and nothing registered for ``save_pattern``)
        """
        if cls is not None:
            LOGGER.info("Custom save class passed, skipping registry lookup")
            saver = cls
        else:
            # No explicit class supplied: resolve it from the registry
            saver = SAVE_METHOD_REGISTRY.get(save_pattern)

        if saver is None:
            raise SimpleMLError(f"No registered save pattern for {save_pattern}")

        saved_location = saver.save(artifact_name=artifact_name, **save_params)

        # Lazily build the nested {artifact_name: {save_pattern: data}} mapping.
        # A missing or None entry at either level gets a fresh dict.
        if self.filepaths is None:
            self.filepaths = {}
        if self.filepaths.get(artifact_name) is None:
            self.filepaths[artifact_name] = {}
        self.filepaths[artifact_name][save_pattern] = saved_location
示例#2
0
    def test_registering_with_decorator_without_parameters(self):
        """
        Bare decorator (no call parentheses) registers the class in both the
        save and load registries under its ``SAVE_PATTERN`` attribute
        """
        pattern = "fake_decorated_without_parameters"
        registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

        # Start from a clean slate in both registries
        for registry in registries:
            registry.drop(pattern)
        for registry in registries:
            self.assertNotIn(pattern, registry.registry)

        @SavePatternDecorators.register_save_pattern
        class FakeSavePattern(object):
            SAVE_PATTERN = pattern

        for registry in registries:
            self.assertIn(pattern, registry.registry)
            self.assertEqual(registry.get(pattern), FakeSavePattern)
示例#3
0
    def test_registering_new_load_pattern_with_decorator_implicitly(self):
        """
        Decorator called with save=False, load=True and no pattern name:
        the class becomes the loader for its ``SAVE_PATTERN`` only
        """
        pattern = "fake_implicit_decorated_load_pattern"
        registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

        # Start from a clean slate in both registries
        for registry in registries:
            registry.drop(pattern)
        for registry in registries:
            self.assertNotIn(pattern, registry.registry)

        @SavePatternDecorators.register_save_pattern(save=False, load=True)
        class FakeSavePattern(object):
            SAVE_PATTERN = pattern

        self.assertIn(pattern, LOAD_METHOD_REGISTRY.registry)
        self.assertEqual(LOAD_METHOD_REGISTRY.get(pattern), FakeSavePattern)
        # save=False must leave the save registry untouched
        self.assertNotIn(pattern, SAVE_METHOD_REGISTRY.registry)
示例#4
0
    def test_registering_new_save_pattern_implicitly(self):
        """
        Direct call with save=True, load=False and no pattern name: the
        pattern comes from the class's ``SAVE_PATTERN`` attribute and only
        the save registry is populated
        """
        pattern = "fake_implicit_save_pattern"
        registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

        # Start from a clean slate in both registries
        for registry in registries:
            registry.drop(pattern)
        for registry in registries:
            self.assertNotIn(pattern, registry.registry)

        class FakeSavePattern(object):
            SAVE_PATTERN = pattern

        register_save_pattern(FakeSavePattern, save=True, load=False)

        self.assertIn(pattern, SAVE_METHOD_REGISTRY.registry)
        self.assertEqual(SAVE_METHOD_REGISTRY.get(pattern), FakeSavePattern)
        # load=False must leave the load registry untouched
        self.assertNotIn(pattern, LOAD_METHOD_REGISTRY.registry)
示例#5
0
    def test_registering_new_save_pattern_with_decorator_explicitly(self):
        """
        Decorator given the pattern name positionally (class has no
        ``SAVE_PATTERN``) with save=True, load=False: only the save registry
        is populated
        """
        pattern = "fake_explicit_decorated_save_pattern"
        registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

        # Start from a clean slate in both registries
        for registry in registries:
            registry.drop(pattern)
        for registry in registries:
            self.assertNotIn(pattern, registry.registry)

        @SavePatternDecorators.register_save_pattern(
            pattern, save=True, load=False
        )
        class FakeSavePattern(object):
            pass

        self.assertIn(pattern, SAVE_METHOD_REGISTRY.registry)
        self.assertEqual(SAVE_METHOD_REGISTRY.get(pattern), FakeSavePattern)
        # load=False must leave the load registry untouched
        self.assertNotIn(pattern, LOAD_METHOD_REGISTRY.registry)
示例#6
0
    def test_registering_both_patterns_with_decorator_implicitly(self):
        """
        Decorator with save=True, load=True and no pattern name registers the
        class in both registries under its ``SAVE_PATTERN`` attribute
        """
        pattern = "fake_implicit_decorated_pattern"
        registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

        # Start from a clean slate in both registries
        for registry in registries:
            registry.drop(pattern)
        for registry in registries:
            self.assertNotIn(pattern, registry.registry)

        @SavePatternDecorators.register_save_pattern(save=True, load=True)
        class FakeSavePattern(object):
            SAVE_PATTERN = pattern

        for registry in registries:
            self.assertIn(pattern, registry.registry)
            self.assertEqual(registry.get(pattern), FakeSavePattern)
示例#7
0
    def test_registering_new_load_pattern_explicitly(self):
        """
        Direct call with an explicit ``save_pattern`` keyword (class has no
        ``SAVE_PATTERN``) and save=False, load=True: only the load registry
        is populated
        """
        pattern = "fake_explicit_load_pattern"
        registries = (SAVE_METHOD_REGISTRY, LOAD_METHOD_REGISTRY)

        # Start from a clean slate in both registries
        for registry in registries:
            registry.drop(pattern)
        for registry in registries:
            self.assertNotIn(pattern, registry.registry)

        class FakeSavePattern(object):
            pass

        register_save_pattern(
            FakeSavePattern, save_pattern=pattern, save=False, load=True
        )

        self.assertIn(pattern, LOAD_METHOD_REGISTRY.registry)
        self.assertEqual(LOAD_METHOD_REGISTRY.get(pattern), FakeSavePattern)
        # save=False must leave the save registry untouched
        self.assertNotIn(pattern, SAVE_METHOD_REGISTRY.registry)
示例#8
0
def deregister_save_pattern(
    cls: Optional[Type] = None,
    save_pattern: Optional[str] = None,
    save: Optional[bool] = True,
    load: Optional[bool] = True,
) -> None:
    """
    Remove a save pattern from the save and/or load registries.

    Patterns that are not currently registered are silently skipped.

    :param cls: optional class the pattern is expected to map to. It is used
        to infer ``save_pattern`` from ``cls.SAVE_PATTERN`` when that is not
        given, and to warn when the registered class differs. A mismatch only
        logs a warning; the entry is still dropped.
    :param save_pattern: the optional string denoting the pattern
        (e.g. `disk_pickled`). Falls back to ``cls.SAVE_PATTERN`` if null.
    :param save: optional bool; default true; whether to drop the pattern
        from the save registry
    :param load: optional bool; default true; whether to drop the pattern
        from the load registry
    :raises SimpleMLError: if ``save_pattern`` is null and ``cls`` has no
        ``SAVE_PATTERN`` attribute (this includes ``cls=None``)
    """
    if save_pattern is None and not hasattr(cls, "SAVE_PATTERN"):
        raise SimpleMLError(
            "Cannot deregister save pattern without passing the `save_pattern` parameter or setting the class attribute `cls.SAVE_PATTERN`"
        )
    if save_pattern is None:
        save_pattern = cls.SAVE_PATTERN

    # Saving and loading are deregistered independently of each other
    drop_saver = save and save_pattern in SAVE_METHOD_REGISTRY.registry
    if drop_saver:
        if cls is not None:
            if SAVE_METHOD_REGISTRY.get(save_pattern) != cls:
                LOGGER.warning(
                    f"Deregistering {save_pattern} as save pattern but passed class does not match registered class"
                )
        SAVE_METHOD_REGISTRY.drop(save_pattern)

    drop_loader = load and save_pattern in LOAD_METHOD_REGISTRY.registry
    if drop_loader:
        if cls is not None:
            if LOAD_METHOD_REGISTRY.get(save_pattern) != cls:
                LOGGER.warning(
                    f"Deregistering {save_pattern} as load pattern but passed class does not match registered class"
                )
        LOAD_METHOD_REGISTRY.drop(save_pattern)
示例#9
0
def register_save_pattern(
    cls: Type,
    save_pattern: Optional[str] = None,
    save: Optional[bool] = True,
    load: Optional[bool] = True,
    overwrite: Optional[bool] = False,
) -> None:
    """
    Register ``cls`` as the save and/or load handler for a save pattern.

    Save and load registrations are independent, so it is allowable for
    different classes to handle saving and loading of the same pattern.

    :param cls: the class to register
    :param save_pattern: the optional string denoting the pattern this
        class implements (e.g. `disk_pickled`). Falls back to the class
        attribute ``cls.SAVE_PATTERN`` if null.
    :param save: optional bool; default true; whether to register the class
        as the save method for the pattern
    :param load: optional bool; default true; whether to register the class
        as the load method for the pattern
    :param overwrite: optional bool; default false; whether to overwrite the
        registered class for the pattern if one exists (forwarded to each
        registry's ``register`` as ``allow_duplicates``). Otherwise an error
        is expected from the registry.
    :raises SimpleMLError: if ``save_pattern`` is null and ``cls`` has no
        ``SAVE_PATTERN`` attribute
    """
    if save_pattern is None and not hasattr(cls, "SAVE_PATTERN"):
        raise SimpleMLError(
            "Cannot register save pattern without passing the `save_pattern` parameter or setting the class attribute `cls.SAVE_PATTERN`"
        )
    pattern = save_pattern if save_pattern is not None else cls.SAVE_PATTERN

    # Save first, then load; each registry is only touched when its flag is set
    for enabled, registry in (
        (save, SAVE_METHOD_REGISTRY),
        (load, LOAD_METHOD_REGISTRY),
    ):
        if enabled:
            registry.register(pattern, cls, allow_duplicates=overwrite)