Example #1
    def __init__(self, config, module, dir):
        try:
            super().__init__(module, dir)
        except PermissionError as e:
            raise InvalidAnalyserConfigError(
                "You must provide a valid directory path") from e

        if not "elements_in" in config:
            raise InvalidAnalyserConfigError(
                "The config must contain an 'elements_in' indicating the analyser's input."
            )
        elif type(config["elements_in"]) is not list or len(
                config["elements_in"]) is 0:
            raise InvalidAnalyserConfigError(
                "The 'elements_in' must be a list containing at least one string"
            )

        if type(module) is not str or module == "":
            raise InvalidAnalyserConfigError(
                "You must provide a name for your analyser")

        if not isinstance(dir, str):
            raise InvalidAnalyserConfigError(
                "You must provide a valid directory path")

        self.CONFIG = config
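
The checks above turn a malformed config into an InvalidAnalyserConfigError. A minimal, self-contained sketch of how this validation could be exercised on its own, assuming InvalidAnalyserConfigError is a plain Exception subclass (the stand-in class and the validate_config helper are illustrations, not mtriage's actual code):

class InvalidAnalyserConfigError(Exception):
    """Stand-in for mtriage's config error (assumption for this sketch)."""

def validate_config(config):
    # Mirrors the 'elements_in' checks in the constructor above.
    if "elements_in" not in config:
        raise InvalidAnalyserConfigError(
            "The config must contain an 'elements_in' indicating the analyser's input.")
    if not isinstance(config["elements_in"], list) or len(config["elements_in"]) == 0:
        raise InvalidAnalyserConfigError(
            "The 'elements_in' must be a list containing at least one string")

validate_config({"elements_in": ["youtube"]})  # passes silently
# validate_config({})  # would raise InvalidAnalyserConfigError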
Example #2
    def __init__(self, config, module, storage=None):
        super().__init__(config, module, storage)

        if not isinstance(module, str) or module == "":
            raise InvalidAnalyserConfigError(
                "You must provide a name for your analyser")
        if not isinstance(storage, Storage):
            raise InvalidAnalyserConfigError(
                "You must provide a valid storage object")
        if not "elements_in" in config:
            raise InvalidAnalyserConfigError(
                "The config must contain an 'elements_in' indicating the analyser's input."
            )
        if not config["elements_in"] or not isinstance(config["elements_in"],
                                                       list):
            raise InvalidAnalyserConfigError(
                "The 'elements_in' must be a list containing at least one string"
            )
Example #3
File: core.py Project: Smoltbob/mtriage
    def pre_analyse(self, config):
        self.logger(config["model"])
        self.logger(f"Storing models in {KERAS_HOME}")
        MOD = SUPPORTED_MODELS.get(config["model"])
        if MOD is None:
            raise InvalidAnalyserConfigError(
                f"The module '{config['model']}' either does not exist, or is not yet supported."
            )

        rLabels = config["labels"]

        # TODO: make it so that this doesn't redownload every run.
        # i.e. refactor it into partial.Dockerfile
        self.model_module = import_module(
            f"keras.applications.{MOD['module']}")
        impmodel = getattr(self.model_module, config["model"])
        # NB: this downloads the weights if they don't exist
        self.model = impmodel(weights="imagenet")
        self.THRESH = 0.1

        # revert to serial if CPU (TODO: debug why parallel CPU doesn't work)
        if not tf.test.is_gpu_available():
            self.in_parallel = False

        def get_preds(img_path):
            img = load_img(img_path, target_size=(224, 224))
            x = img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = self.model_module.preprocess_input(x)
            preds = self.model.predict(x)

            # `top` defaults to 5; request more predictions so that additional
            # labels survive the whitelist filtering below
            decoded = self.model_module.decode_predictions(preds, top=10)

            # filter by labels provided in whitelist
            filteredPreds = [p for p in decoded[0] if p[1] in rLabels]

            # return map(lambda x: (x[1], float(x[2])), filteredPreds)
            return [(x[1], float(x[2])) for x in filteredPreds
                    if float(x[2]) >= self.THRESH]

        self.get_preds = get_preds
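
Outside of the analyser class, the same Keras flow can be sketched standalone. The snippet below assumes ResNet50 as the model and uses only documented keras.applications calls; the whitelist and threshold mirror the rLabels / THRESH logic above, while the analyser plumbing and SUPPORTED_MODELS lookup are omitted:

import numpy as np
from keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
from keras.preprocessing.image import load_img, img_to_array

model = ResNet50(weights="imagenet")  # downloads the weights on first use

def predict_labels(img_path, labels, thresh=0.1):
    img = load_img(img_path, target_size=(224, 224))
    x = np.expand_dims(img_to_array(img), axis=0)
    preds = model.predict(preprocess_input(x))
    # decode_predictions returns (class_id, class_name, score) tuples
    decoded = decode_predictions(preds, top=10)[0]
    return [(name, float(score)) for _, name, score in decoded
            if name in labels and float(score) >= thresh]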
Example #4
    def __create_hasher(self, config):
        hasher_key = config["method"] if "method" in config else "phash"
        self.logger(f"Compare method is {hasher_key}")
        hasher = {
            "phash": methods.PHash,
            "ahash": methods.AHash,
            "dhash": methods.DHash,
            "whash": methods.WHash,
        }.get(hasher_key)
        if hasher is None:
            raise InvalidAnalyserConfigError(
                f"'{hasher_key}' is not a valid method for imagededup.")

        self.hasher = hasher()

        # super low threshold by default to only remove essentially identical images.
        if "threshold" in config:
            self.threshold = int(config["threshold"])
        else:
            self.threshold = 3

        self.logger(f"Hamming threshold is {self.threshold}")
Example #5
    def pre_analyse(self, config):
        self.logger(config["model"])
        MOD = SUPPORTED_MODELS.get(config["model"])
        if MOD is None:
            raise InvalidAnalyserConfigError(
                f"The module '{config['model']}' either does not exist, or is not yet supported."
            )

        rLabels = config["labels"]

        self.model_module = import_module(f"keras.applications.{MOD['module']}")
        impmodel = getattr(self.model_module, config["model"])
        # NB: this downloads the weights if they don't exist
        self.model = impmodel(weights="imagenet")
        self.THRESH = 0.1

        def get_preds(img_path):
            img = load_img(img_path, target_size=(224, 224))
            x = img_to_array(img)
            x = np.expand_dims(x, axis=0)
            x = self.model_module.preprocess_input(x)
            preds = self.model.predict(x)

            # `top` defaults to 5; request more predictions so that additional
            # labels survive the whitelist filtering below
            decoded = self.model_module.decode_predictions(preds, top=10)

            # filter by labels provided in whitelist
            filteredPreds = [p for p in decoded[0] if p[1] in rLabels]

            # return map(lambda x: (x[1], float(x[2])), filteredPreds)
            return [
                (x[1], float(x[2])) for x in filteredPreds if float(x[2]) >= self.THRESH
            ]

        self.get_preds = get_preds
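
Both pre_analyse examples resolve config["model"] through a SUPPORTED_MODELS table before importing the matching keras.applications submodule via import_module. The table itself is not shown; a hypothetical shape for it might look like this (the entries are assumptions, not mtriage's actual list):

# Hypothetical mapping from user-facing model names to keras.applications
# submodules; the real SUPPORTED_MODELS in mtriage may differ.
SUPPORTED_MODELS = {
    "ResNet50": {"module": "resnet50"},
    "VGG16": {"module": "vgg16"},
    "MobileNet": {"module": "mobilenet"},
}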