示例#1
0
 def process(self):
     """ Run the model restore.

     Calls :func:`validate` first, then restores the model files for ``self.model_name``
     in ``self.model_dir`` through a :class:`Backup` instance. The start and the end of the
     operation are logged at info level.
     """
     logger.info("Starting Model Restore...")
     self.validate()
     restorer = Backup(self.model_dir, self.model_name)
     restorer.restore()
     logger.info("Completed Model Restore")
示例#2
0
文件: io.py 项目: wei/faceswap
 def __init__(self, plugin: "ModelBase", model_dir: str, is_predict: bool,
              save_optimizer: Literal["never", "always", "exit"]) -> None:
     """ Store the I/O settings for a model plugin and create its backup handler.

     Parameters
     ----------
     plugin: :class:`ModelBase`
         The model plugin this object serves. Its ``name`` is used for the backup handler.
     model_dir: str
         The folder that holds the model files
     is_predict: bool
         Stored as ``self._is_predict`` (presumably ``True`` when loading for prediction
         rather than training - confirm against callers)
     save_optimizer: ["never", "always", "exit"]
         Stored as ``self._save_optimizer``. Presumably controls when optimizer state is
         written - confirm against the save logic.
     """
     self._plugin = plugin
     self._model_dir = model_dir
     self._is_predict = is_predict
     self._save_optimizer = save_optimizer
     # Two independent, initially empty loss-history lists, tracked per save iteration
     self._history: List[List[float]] = [[] for _ in range(2)]
     self._backup = Backup(self._model_dir, self._plugin.name)
示例#3
0
    def __init__(self,
                 model_dir,
                 gpus=1,
                 configfile=None,
                 snapshot_interval=0,
                 no_logs=False,
                 warp_to_landmarks=False,
                 augment_color=True,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 pingpong=False,
                 memory_saving_gradients=False,
                 optimizer_savings=False,
                 predict=False):
        """ Initialize the model: load config and state, then build the networks.

        The statement order matters. ``configfile`` must be set before the config is read,
        ``self.state`` must exist before the legacy rename and the state-info load, and
        :func:`build` runs last.

        Parameters
        ----------
        model_dir: path-like
            Folder that holds the model files (presumably a :class:`pathlib.Path` - confirm)
        gpus: int, optional
            Number of GPUs to use. Default: 1
        configfile: str, optional
            Custom config file to read instead of the default. Default: ``None``
        snapshot_interval: int, optional
            Passed to the trainer through :attr:`training_opts`. Default: 0
        no_logs: bool, optional
            Passed to :class:`State`. Default: ``False``
        warp_to_landmarks, augment_color, no_flip: bool, optional
            Passed to the trainer through :attr:`training_opts`
        training_image_size: int, optional
            Passed to :class:`State`. Default: 256
        alignments_paths: optional
            Passed to the trainer through :attr:`training_opts`. Default: ``None``
        preview_scale: int, optional
            Preview scale as a percentage. It is stored divided by 100. Default: 100
        input_shape: tuple, optional
            Model input shape. It may be overridden from the state file. Default: ``None``
        encoder_dim: int, optional
            Encoder dimension. It may be overridden for legacy models. Default: ``None``
        trainer: str, optional
            Name of the trainer. Default: "original"
        pingpong, memory_saving_gradients, optimizer_savings: bool, optional
            VRAM saving options, passed to :class:`VRAMSavings`
        predict: bool, optional
            ``True`` when the model is loaded for prediction. Default: ``False``
        """
        logger.debug(
            "Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, "
            "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: "
            "%s, no_flip: %s, training_image_size: %s, alignments_paths: %s, "
            "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, "
            "pingpong: %s, memory_saving_gradients: %s, optimizer_savings: %s, "
            "predict: %s)", self.__class__.__name__, model_dir, gpus,
            configfile, snapshot_interval, no_logs, warp_to_landmarks,
            augment_color, no_flip, training_image_size, alignments_paths,
            preview_scale, input_shape, encoder_dim, trainer, pingpong,
            memory_saving_gradients, optimizer_savings, predict)

        self.predict = predict
        self.model_dir = model_dir
        self.vram_savings = VRAMSavings(pingpong, optimizer_savings,
                                        memory_saving_gradients)

        self.backup = Backup(self.model_dir, self.name)
        self.gpus = gpus
        self.configfile = configfile
        self.input_shape = input_shape
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        # Load config if plugin has not already referenced it (needs self.configfile)
        self.load_config()

        self.state = State(self.model_dir, self.name,
                           self.config_changeable_items, no_logs,
                           self.vram_savings.pingpong, training_image_size)

        self.blocks = NNBlocks(
            use_subpixel=self.config["subpixel_upscaling"],
            use_icnr_init=self.config["icnr_init"],
            use_convaware_init=self.config["conv_aware_init"],
            use_reflect_padding=self.config["reflect_padding"],
            first_run=self.state.first_run)

        # The legacy rename may rewrite the state file, so it runs before the input shape
        # is read back from the state
        self.is_legacy = False
        self.rename_legacy()
        self.load_state_info()

        self.networks = {}  # Networks for the model
        self.predictors = {}  # Predictors for model
        self.history = {}  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        self.training_opts = {
            "alignments": alignments_paths,
            "preview_scaling": preview_scale / 100,
            "warp_to_landmarks": warp_to_landmarks,
            "augment_color": augment_color,
            "no_flip": no_flip,
            "pingpong": self.vram_savings.pingpong,
            "snapshot_interval": snapshot_interval,
            "training_size": self.state.training_size,
            "no_logs": self.state.current_session["no_logs"],
            "coverage_ratio": self.calculate_coverage_ratio(),
            "mask_type": self.config["mask_type"],
            "mask_blur_kernel": self.config["mask_blur_kernel"],
            "mask_threshold": self.config["mask_threshold"],
            # Mask learning/penalizing only makes sense when a mask type is configured
            "learn_mask": (self.config["learn_mask"]
                           and self.config["mask_type"] is not None),
            "penalized_mask_loss": (self.config["penalized_mask_loss"]
                                    and self.config["mask_type"] is not None)
        }
        logger.debug("training_opts: %s", self.training_opts)

        if self.multiple_models_in_folder:
            deprecation_warning(
                "Support for multiple model types within the same folder",
                additional_info="Please split each model into separate folders to "
                "avoid issues in future.")

        self.build()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)
示例#4
0
class ModelBase():
    """ Base class that all models should inherit from.

    Handles the work that is common to every model plugin:
    - loading the plugin config and the training :class:`State`;
    - registering networks (:func:`add_network`) and predictors (:func:`add_predictor`);
    - compiling the predictors;
    - loading and saving weights, with backups and snapshots;
    - migrating legacy model file names.

    Subclasses must implement :func:`add_networks` and :func:`build_autoencoders`.
    """
    def __init__(self,
                 model_dir,
                 gpus=1,
                 configfile=None,
                 snapshot_interval=0,
                 no_logs=False,
                 warp_to_landmarks=False,
                 augment_color=True,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 pingpong=False,
                 memory_saving_gradients=False,
                 optimizer_savings=False,
                 predict=False):
        """ Initialize the model: load config and state, then build the networks.

        The statement order matters. ``configfile`` must be set before the config is read,
        ``self.state`` must exist before the legacy rename and the state-info load, and
        :func:`build` runs last.

        Parameters
        ----------
        model_dir: path-like
            Folder that holds the model files. It is joined with ``/`` in
            :func:`add_network`, so it is presumably a :class:`pathlib.Path`.
        gpus: int, optional
            Number of GPUs. When greater than 1, predictors are wrapped for multi-GPU use.
            Default: 1
        configfile: str, optional
            Custom config file to read instead of the default. Default: ``None``
        snapshot_interval: int, optional
            Passed to the trainer through :attr:`training_opts`. Default: 0
        no_logs: bool, optional
            Passed to :class:`State`. Default: ``False``
        warp_to_landmarks, augment_color, no_flip: bool, optional
            Passed to the trainer through :attr:`training_opts`
        training_image_size: int, optional
            Passed to :class:`State`. Default: 256
        alignments_paths: optional
            Passed to the trainer through :attr:`training_opts`. Default: ``None``
        preview_scale: int, optional
            Preview scale as a percentage. It is stored divided by 100. Default: 100
        input_shape: tuple, optional
            Model input shape. It may be overridden by :func:`load_state_info`.
            Default: ``None``
        encoder_dim: int, optional
            Encoder dimension. It may be overridden by :func:`rename_legacy`.
            Default: ``None``
        trainer: str, optional
            Name of the trainer. Default: "original"
        pingpong, memory_saving_gradients, optimizer_savings: bool, optional
            VRAM saving options, passed to :class:`VRAMSavings`
        predict: bool, optional
            ``True`` when loading for prediction. A missing model is then fatal and model
            summaries are not logged. Default: ``False``
        """
        logger.debug(
            "Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, "
            "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: "
            "%s, no_flip: %s, training_image_size: %s, alignments_paths: %s, "
            "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, "
            "pingpong: %s, memory_saving_gradients: %s, optimizer_savings: %s, "
            "predict: %s)", self.__class__.__name__, model_dir, gpus,
            configfile, snapshot_interval, no_logs, warp_to_landmarks,
            augment_color, no_flip, training_image_size, alignments_paths,
            preview_scale, input_shape, encoder_dim, trainer, pingpong,
            memory_saving_gradients, optimizer_savings, predict)

        self.predict = predict
        self.model_dir = model_dir
        self.vram_savings = VRAMSavings(pingpong, optimizer_savings,
                                        memory_saving_gradients)

        self.backup = Backup(self.model_dir, self.name)
        self.gpus = gpus
        self.configfile = configfile
        self.input_shape = input_shape
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        # Load config if plugin has not already referenced it (needs self.configfile)
        self.load_config()

        self.state = State(self.model_dir, self.name,
                           self.config_changeable_items, no_logs,
                           self.vram_savings.pingpong, training_image_size)

        self.blocks = NNBlocks(
            use_subpixel=self.config["subpixel_upscaling"],
            use_icnr_init=self.config["icnr_init"],
            use_convaware_init=self.config["conv_aware_init"],
            use_reflect_padding=self.config["reflect_padding"],
            first_run=self.state.first_run)

        # rename_legacy may rewrite the state (including state.inputs), so it runs before
        # load_state_info reads the input shape back from the state
        self.is_legacy = False
        self.rename_legacy()
        self.load_state_info()

        self.networks = {}  # Networks for the model
        self.predictors = {}  # Predictors for model
        self.history = {}  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        self.training_opts = {
            "alignments": alignments_paths,
            "preview_scaling": preview_scale / 100,
            "warp_to_landmarks": warp_to_landmarks,
            "augment_color": augment_color,
            "no_flip": no_flip,
            "pingpong": self.vram_savings.pingpong,
            "snapshot_interval": snapshot_interval,
            "training_size": self.state.training_size,
            "no_logs": self.state.current_session["no_logs"],
            "coverage_ratio": self.calculate_coverage_ratio(),
            "mask_type": self.config["mask_type"],
            "mask_blur_kernel": self.config["mask_blur_kernel"],
            "mask_threshold": self.config["mask_threshold"],
            # Mask learning/penalizing only makes sense when a mask type is configured
            "learn_mask": (self.config["learn_mask"]
                           and self.config["mask_type"] is not None),
            "penalized_mask_loss": (self.config["penalized_mask_loss"]
                                    and self.config["mask_type"] is not None)
        }
        logger.debug("training_opts: %s", self.training_opts)

        if self.multiple_models_in_folder:
            deprecation_warning(
                "Support for multiple model types within the same folder",
                additional_info="Please split each model into separate folders to "
                "avoid issues in future.")

        self.build()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)

    @property
    def config_section(self):
        """ str: The config section name, built from the last two parts of the module path """
        retval = ".".join(self.__module__.split(".")[-2:])
        logger.debug(retval)
        return retval

    @property
    def config(self):
        """ dict: The module-level config for the current plugin, loaded on first access """
        self.load_config()
        return _CONFIG

    @property
    def config_changeable_items(self):
        """ dict: The config items that can be updated after the model has been created """
        return Config(self.config_section,
                      configfile=self.configfile).changeable_items

    @property
    def name(self):
        """ str: The model name: the lower-cased file name (without extension) of the
        subclass's module """
        basename = os.path.basename(sys.modules[self.__module__].__file__)
        retval = os.path.splitext(basename)[0].lower()
        logger.debug("model name: '%s'", retval)
        return retval

    @property
    def models_exist(self):
        """ bool: ``True`` if every registered network's weights file exists on disk """
        retval = all(os.path.isfile(model.filename) for model in self.networks.values())
        logger.debug("Pre-existing models exist: %s", retval)
        return retval

    @property
    def multiple_models_in_folder(self):
        """ bool: ``True`` if the ``.h5`` files in :attr:`model_dir` share no common name
        prefix (i.e. more than one model type lives in the folder), otherwise ``False`` """
        model_files = [fname for fname in os.listdir(str(self.model_dir))
                       if fname.endswith(".h5")]
        retval = bool(model_files) and os.path.commonprefix(model_files) == ""
        logger.debug("model_files: %s, retval: %s", model_files, retval)
        return retval

    @property
    def output_shapes(self):
        """ list: The last three dimensions of each output of the main autoencoder, as tuples.
        Empty if no predictors have been added. """
        # Only the first autoencoder is inspected: the shapes are the same for all of them
        predictor = next(iter(self.predictors.values()), None)
        if predictor is None:
            return []
        return [tuple(K.int_shape(output)[-3:]) for output in predictor.outputs]

    @property
    def output_shape(self):
        """ tuple: The output shape of the model (shape of largest face output) """
        return self.output_shapes[self.largest_face_index]

    def _largest_output_index(self, channels):
        """ Return the index of the largest output with the given channel count.

        Sizes are compared on ``shape[1]``, and ``shape[2]`` must equal ``channels``. When
        several outputs share the largest size, the first one is returned. Returns ``None``
        if no output has ``channels`` channels.
        """
        shapes = self.output_shapes  # computed once; the property queries the backend
        sizes = [shape[1] for shape in shapes if shape[2] == channels]
        if not sizes:
            return None
        max_size = max(sizes)
        retval = next(idx for idx, shape in enumerate(shapes)
                      if shape[1] == max_size and shape[2] == channels)
        logger.debug(retval)
        return retval

    @property
    def largest_face_index(self):
        """ int or ``None``: The index in model.outputs of the largest 3-channel (face) output.

        Required for multi-output model prediction. Where several faces share the largest
        size, the first is returned.
        """
        return self._largest_output_index(3)

    @property
    def largest_mask_index(self):
        """ int or ``None``: The index in model.outputs of the largest 1-channel (mask) output.

        Required for multi-output model prediction. Where several masks share the largest
        size, the first is returned.
        """
        return self._largest_output_index(1)

    @property
    def feed_mask(self):
        """ bool: ``True`` if the model expects a mask to be fed into input otherwise ``False`` """
        return self.config["mask_type"] is not None and (
            self.config["learn_mask"] or self.config["penalized_mask_loss"])

    def load_config(self):
        """ Load the module-level config (``_CONFIG``) if it has not been loaded yet """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = self.config_section
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name,
                             configfile=self.configfile).config_dict

    def calculate_coverage_ratio(self):
        """ Return the coverage ratio, adjusted so the cropped training size is divisible by 2.

        The ``coverage`` config value (a percentage, default 62.5) is converted to a ratio.
        The resulting crop of ``state.training_size`` is rounded down to an even number, and
        the ratio is recomputed from that crop.
        """
        coverage_ratio = self.config.get("coverage", 62.5) / 100
        logger.debug("Requested coverage_ratio: %s", coverage_ratio)
        cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2
        coverage_ratio = cropped_size / self.state.training_size
        logger.debug("Final coverage_ratio: %s", coverage_ratio)
        return coverage_ratio

    def build(self):
        """ Build the model. Override for custom build methods.

        Adds the networks, loads any saved weights, builds and summarizes the autoencoders,
        then compiles the predictors.

        Raises
        ------
        FaceswapError
            If the autoencoders cannot be built because of corrupt saved weights or a
            multi-GPU error
        """
        self.add_networks()
        self.load_models(swapped=False)
        inputs = self.get_inputs()
        try:
            self.build_autoencoders(inputs)
        except ValueError as err:
            if "must be from the same graph" in str(err).lower():
                msg = (
                    "There was an error loading saved weights. This is most likely due to "
                    "model corruption during a previous save."
                    "\nYou should restore weights from a snapshot or from backup files. "
                    "You can use the 'Restore' Tool to restore from backup.")
                raise FaceswapError(msg) from err
            if "multi_gpu_model" in str(err).lower():
                raise FaceswapError(str(err)) from err
            raise
        self.log_summary()
        self.compile_predictors(initialize=True)

    def get_inputs(self):
        """ Return the model inputs: the face input, plus a mask input when
        :attr:`feed_mask` is ``True`` """
        logger.debug("Getting inputs")
        inputs = [Input(shape=self.input_shape, name="face_in")]
        output_network = [
            network for network in self.networks.values() if network.is_output
        ][0]
        if self.feed_mask:
            # TODO penalized mask doesn't have a mask output, so we can't use output shapes
            # mask should always be last output..this needs to be a rule
            mask_shape = output_network.output_shapes[-1]
            inputs.append(
                Input(shape=(mask_shape[1:-1] + (1, )), name="mask_in"))
        logger.debug("Got inputs: %s", inputs)
        return inputs

    def build_autoencoders(self, inputs):
        """ Override for Model Specific autoencoder builds

            Inputs is defined in self.get_inputs() and is standardized for all models
                if will generally be in the order:
                [face (the input for image),
                 mask (the input for mask if it is used)]
        """
        raise NotImplementedError

    def add_networks(self):
        """ Override to add neural networks """
        raise NotImplementedError

    def load_state_info(self):
        """ Set :attr:`input_shape` from the first face shape in the state file, if any """
        logger.debug("Loading Input Shape from State file")
        if not self.state.inputs:
            logger.debug("No input shapes saved. Using model config")
            return
        if not self.state.face_shapes:
            logger.warning(
                "Input shapes stored in State file, but no matches for 'face'."
                "Using model config")
            return
        input_shape = self.state.face_shapes[0]
        logger.debug("Setting input shape from state file: %s", input_shape)
        self.input_shape = input_shape

    def add_network(self, network_type, side, network, is_output=False):
        """ Register a network as an :class:`NNMeta` object in :attr:`networks`.

        The key is ``<type>[_<side>]``, lower-cased. The weights file is
        ``<model name>_<type>[_<SIDE>].h5`` inside :attr:`model_dir`.
        """
        logger.debug(
            "network_type: '%s', side: '%s', network: '%s', is_output: %s",
            network_type, side, network, is_output)
        filename = "{}_{}".format(self.name, network_type.lower())
        name = network_type.lower()
        if side:
            side = side.lower()
            filename += "_{}".format(side.upper())
            name += "_{}".format(side)
        filename += ".h5"
        logger.debug("name: '%s', filename: '%s'", name, filename)
        self.networks[name] = NNMeta(str(self.model_dir / filename),
                                     network_type, side, network, is_output)

    def add_predictor(self, side, model):
        """ Add a predictor for ``side``, wrapped for multi-GPU use when :attr:`gpus` > 1.
        The model's input shapes are stored to the state if none are stored yet. """
        logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
        if self.gpus > 1:
            logger.debug("Converting to multi-gpu: side %s", side)
            model = multi_gpu_model(model, self.gpus)
        self.predictors[side] = model
        if not self.state.inputs:
            self.store_input_shapes(model)

    def store_input_shapes(self, model):
        """ Store the model's input shapes (last three dims, as int tuples) to the state.

        Raises
        ------
        ValueError
            If no input name starts with "face"
        """
        logger.debug("Adding input shapes to state for model")
        inputs = {
            tensor.name: K.int_shape(tensor)[-3:]
            for tensor in model.inputs
        }
        if not any(name.startswith("face") for name in inputs):
            raise ValueError(
                "No input named 'face' was found. Check your input naming. "
                "Current input names: {}".format(inputs))
        # Make sure they are all ints so that it can be json serialized
        inputs = {
            key: tuple(int(i) for i in val)
            for key, val in inputs.items()
        }
        self.state.inputs = inputs
        logger.debug("Added input shapes: %s", self.state.inputs)

    def reset_pingpong(self):
        """ Reset the models for pingpong training.

        Clears the predictors and the backend session, then rebuilds each network from its
        stored config and weights. The autoencoders are rebuilt and the predictors are
        recompiled without resetting the loss history.
        """
        logger.debug("Resetting models")

        # Clear models and graph
        self.predictors = {}
        K.clear_session()

        # Load Models for current training run
        for model in self.networks.values():
            model.network = Model.from_config(model.config)
            model.network.set_weights(model.weights)

        inputs = self.get_inputs()
        self.build_autoencoders(inputs)
        self.compile_predictors(initialize=False)
        logger.debug("Reset models")

    def compile_predictors(self, initialize=True):
        """ Compile every predictor with the configured Adam optimizer and its losses.

        Parameters
        ----------
        initialize: bool, optional
            ``True`` to also register the loss names with the state and reset each side's
            loss history. Default: ``True``
        """
        logger.debug("Compiling Predictors")
        learning_rate = self.config.get("learning_rate", 5e-5)
        optimizer = self.get_optimizer(lr=learning_rate,
                                       beta_1=0.5,
                                       beta_2=0.999)

        # Collected per side so the final log line works even when there are no predictors
        loss_names = {}
        for side, model in self.predictors.items():
            loss = Loss(model.inputs, model.outputs)
            model.compile(optimizer=optimizer, loss=loss.funcs)
            loss_names[side] = loss.names
            if initialize:
                self.state.add_session_loss_names(side, loss.names)
                self.history[side] = []
        logger.debug("Compiled Predictors. Losses: %s", loss_names)

    def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999):  # pylint: disable=invalid-name
        """ Build and return the Adam optimizer.

        ``clipnorm=1.0`` is added when the ``clipnorm`` config option is set, except on the
        PlaidML backend.
        """
        opt_kwargs = dict(lr=lr, beta_1=beta_1, beta_2=beta_2)
        # NB: Clipnorm is ballooning VRAM usage, which is not expected behavior
        # and may be a bug in Keras/TF.
        # PlaidML has a bug regarding the clipnorm parameter
        # See: https://github.com/plaidml/plaidml/issues/228
        # Workaround by simply removing it.
        # TODO: Remove this as soon it is fixed in PlaidML.
        if (self.config.get("clipnorm", False)
                and keras.backend.backend() != "plaidml.keras.backend"):
            opt_kwargs["clipnorm"] = 1.0
        logger.debug("Optimizer kwargs: %s", opt_kwargs)
        return Adam(**opt_kwargs, cpu_mode=self.vram_savings.optimizer_savings)

    def converter(self, swap):
        """ Return the ``predict`` function of the side-"b" predictor, or side "a" when
        ``swap`` is ``True`` """
        logger.debug("Getting Converter: (swap: %s)", swap)
        side = "a" if swap else "b"
        model = self.predictors[side]
        if self.predict:
            # Must compile the model to be thread safe
            model._make_predict_function()  # pylint: disable=protected-access
        retval = model.predict
        logger.debug("Got Converter: %s", retval)
        return retval

    @property
    def iterations(self):
        "Get current training iteration number"
        return self.state.iterations

    def map_models(self, swapped):
        """ Map each side to ``{network type: weights filename}``.

        Networks without a side are not included. When ``swapped`` is ``True``, side "a"
        gets the "b" networks and vice versa.
        """
        logger.debug("Map models: (swapped: %s)", swapped)
        models_map = {"a": {}, "b": {}}
        sides = ("a", "b") if not swapped else ("b", "a")
        for network in self.networks.values():
            if network.side == sides[0]:
                models_map["a"][network.type] = network.filename
            if network.side == sides[1]:
                models_map["b"][network.type] = network.filename
        logger.debug("Mapped models: (models_map: %s)", models_map)
        return models_map

    def log_summary(self):
        """ Log the summary of each predictor and its networks at verbose level.
        Skipped in predict mode. """
        if self.predict:
            return
        for side in sorted(self.predictors):
            logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
            self.predictors[side].summary(
                print_fn=lambda x: logger.verbose("%s", x))
            for name, nnmeta in self.networks.items():
                if nnmeta.side is not None and nnmeta.side != side:
                    continue
                logger.verbose("%s:", name.title())
                nnmeta.network.summary(
                    print_fn=lambda x: logger.verbose("%s", x))

    def do_snapshot(self):
        """ Perform a model snapshot at the current iteration """
        logger.debug("Performing snapshot")
        self.backup.snapshot_models(self.iterations)
        logger.debug("Performed snapshot")

    def load_models(self, swapped):
        """ Load the network weights from disk.

        Parameters
        ----------
        swapped: bool
            ``True`` to load side "a" networks from the "b" files and vice versa

        Returns
        -------
        bool or ``None``
            ``None`` if no saved model exists (a new model will be created). Otherwise
            ``True`` if every network loaded, or ``False`` if one failed or there are no
            networks. If no model exists in predict mode, the process exits.
        """
        logger.debug("Load model: (swapped: %s)", swapped)

        if not self.models_exist and not self.predict:
            logger.info("Creating new '%s' model in folder: '%s'", self.name,
                        self.model_dir)
            return None
        if not self.models_exist and self.predict:
            logger.error("Model could not be found in folder '%s'. Exiting",
                         self.model_dir)
            sys.exit(0)

        if not self.is_legacy or not self.predict:
            K.clear_session()
        model_mapping = self.map_models(swapped)
        is_loaded = False  # Defined even when there are no networks to load
        for network in self.networks.values():
            if not network.side:
                is_loaded = network.load()
            else:
                is_loaded = network.load(
                    fullpath=model_mapping[network.side][network.type])
            if not is_loaded:
                break
        if is_loaded:
            logger.info("Loaded model from disk: '%s'", self.model_dir)
        return is_loaded

    def save_models(self):
        """ Save all networks and the state in parallel and log the average losses.

        The save is also a backup if :func:`should_backup` reports that every loss has
        dropped. Any exception raised by a save thread is re-raised here.
        """
        logger.debug("Backing up and saving models")
        save_averages = self.get_save_averages()
        backup_func = self.backup.backup_model if self.should_backup(
            save_averages) else None
        if backup_func:
            logger.info("Backing up models...")
        # The context manager shuts the pool down, so worker threads are not leaked per save
        with futures.ThreadPoolExecutor() as executor:
            save_threads = [
                executor.submit(network.save, backup_func=backup_func)
                for network in self.networks.values()
            ]
            save_threads.append(
                executor.submit(self.state.save, backup_func=backup_func))
            futures.wait(save_threads)
        for thread in save_threads:
            thread.result()  # re-raise any error from the worker thread
        msg = "[Saved models]"
        if save_averages:
            lossmsg = [
                "{}_{}: {:.5f}".format(self.state.loss_names[side][0],
                                       side.capitalize(), save_averages[side])
                for side in sorted(save_averages)
            ]
            msg += " - Average since last save: {}".format(", ".join(lossmsg))
        logger.info(msg)

    def get_save_averages(self):
        """ Return the average loss per side since the last save, and reset each averaged
        side's loss history """
        logger.debug("Getting save averages")
        avgs = {}
        for side, loss in self.history.items():
            if not loss:
                # NOTE(review): ``break`` (not ``continue``) stops at the first side with no
                # history, so later sides are neither averaged nor reset - confirm intended
                logger.debug("No loss in self.history: %s", side)
                break
            avgs[side] = sum(loss) / len(loss)
            self.history[side] = []  # Reset historical loss
        logger.debug("Average losses since last save: %s", avgs)
        return avgs

    def should_backup(self, save_averages):
        """ Check whether the loss averages for all losses is the lowest that has been seen.

            This protects against model corruption by only backing up the model
            if any of the loss values have fallen.
            TODO This is not a perfect system. If the model corrupts on save_iteration - 1
            then model may still backup
        """
        backup = True

        if not save_averages:
            logger.debug("No save averages. Not backing up")
            return False

        for side, loss in save_averages.items():
            if not self.state.lowest_avg_loss.get(side, None):
                logger.debug(
                    "Setting initial save iteration loss average for '%s': %s",
                    side, loss)
                self.state.lowest_avg_loss[side] = loss
                continue
            if backup:
                # Only run this if backup is true. All losses must have dropped for a valid backup
                backup = self.check_loss_drop(side, loss)

        logger.debug("Lowest historical save iteration loss average: %s",
                     self.state.lowest_avg_loss)

        if backup:  # Update lowest loss values to the state
            for side, avg_loss in save_averages.items():
                logger.debug(
                    "Updating lowest save iteration average for '%s': %s",
                    side, avg_loss)
                self.state.lowest_avg_loss[side] = avg_loss

        logger.debug("Backing up: %s", backup)
        return backup

    def check_loss_drop(self, side, avg):
        """ Return ``True`` if ``avg`` is strictly below the lowest stored average for
        ``side`` """
        if avg < self.state.lowest_avg_loss[side]:
            logger.debug("Loss for '%s' has dropped", side)
            return True
        logger.debug("Loss for '%s' has not dropped", side)
        return False

    def rename_legacy(self):
        """ Legacy Original, LowMem and IAE models had inconsistent naming conventions.

        Rename any such files found in :attr:`model_dir` (without overwriting existing files).
        If any file was renamed, mark the model as legacy, write a legacy state file and set
        :attr:`encoder_dim` (512 for former LowMem models, otherwise 1024).
        """
        legacy_mapping = {
            "iae": [("IAE_decoder.h5", "iae_decoder.h5"),
                    ("IAE_encoder.h5", "iae_encoder.h5"),
                    ("IAE_inter_A.h5", "iae_intermediate_A.h5"),
                    ("IAE_inter_B.h5", "iae_intermediate_B.h5"),
                    ("IAE_inter_both.h5", "iae_inter.h5")],
            "original": [("encoder.h5", "original_encoder.h5"),
                         ("decoder_A.h5", "original_decoder_A.h5"),
                         ("decoder_B.h5", "original_decoder_B.h5"),
                         ("lowmem_encoder.h5", "original_encoder.h5"),
                         ("lowmem_decoder_A.h5", "original_decoder_A.h5"),
                         ("lowmem_decoder_B.h5", "original_decoder_B.h5")]
        }
        if self.name not in legacy_mapping:
            return
        logger.debug("Renaming legacy files")

        set_lowmem = False
        updated = False
        for old_name, new_name in legacy_mapping[self.name]:
            old_path = os.path.join(str(self.model_dir), old_name)
            new_path = os.path.join(str(self.model_dir), new_name)
            if os.path.exists(old_path) and not os.path.exists(new_path):
                logger.info("Updating legacy model name from: '%s' to '%s'",
                            old_name, new_name)
                os.rename(old_path, new_path)
                if old_name.startswith("lowmem"):
                    set_lowmem = True
                updated = True

        if not updated:
            logger.debug("No legacy files to rename")
            return

        self.is_legacy = True
        logger.debug("Creating state file for legacy model")
        self.state.inputs = {"face:0": [64, 64, 3]}
        self.state.training_size = 256
        self.state.config["coverage"] = 62.5
        self.state.config["subpixel_upscaling"] = False
        self.state.config["reflect_padding"] = False
        self.state.config["mask_type"] = None
        self.state.config["mask_blur_kernel"] = 3
        self.state.config["mask_threshold"] = 4
        self.state.config["learn_mask"] = False
        self.state.config["lowmem"] = False
        self.encoder_dim = 1024

        if set_lowmem:
            logger.debug(
                "Setting encoder_dim and lowmem flag for legacy lowmem model")
            self.encoder_dim = 512
            self.state.config["lowmem"] = True

        self.state.replace_config(self.config_changeable_items)
        self.state.save()
示例#5
0
    def __init__(self,
                 model_dir,
                 gpus,
                 configfile=None,
                 snapshot_interval=0,
                 no_logs=False,
                 warp_to_landmarks=False,
                 augment_color=True,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 pingpong=False,
                 pretrain=False,
                 memory_saving_gradients=False,
                 predict=False):
        """ Initialize the model: load config and state, build it and set its training data.

        The statement order matters. ``configfile`` must be set before the config is read,
        ``self.state`` must exist before the legacy rename and the state-info load, and the
        build runs last.

        Parameters
        ----------
        model_dir: path-like
            Folder that holds the model files
        gpus: int
            Number of GPUs to use
        configfile: str, optional
            Custom config file to read instead of the default. Default: ``None``
        snapshot_interval: int, optional
            Passed to the trainer through :attr:`training_opts`. Default: 0
        no_logs: bool, optional
            Passed to :class:`State`. Default: ``False``
        warp_to_landmarks, augment_color, no_flip: bool, optional
            Passed to the trainer through :attr:`training_opts`
        training_image_size: int, optional
            Passed to :class:`State`. Default: 256
        alignments_paths: optional
            Passed to the trainer through :attr:`training_opts`. Default: ``None``
        preview_scale: int, optional
            Preview scale as a percentage. It is stored divided by 100. Default: 100
        input_shape: tuple, optional
            Model input shape. It may be overridden from the state file. Default: ``None``
        encoder_dim: int, optional
            Encoder dimension. It may be overridden for legacy models. Default: ``None``
        trainer: str, optional
            Name of the trainer. Default: "original"
        pingpong: bool, optional
            Passed to :class:`State` and :attr:`training_opts`. Default: ``False``
        pretrain: bool, optional
            Only logged. It is not otherwise referenced in this method. Default: ``False``
        memory_saving_gradients: bool, optional
            Passed to :func:`set_gradient_type`. Default: ``False``
        predict: bool, optional
            ``True`` when the model is loaded for prediction. Default: ``False``
        """
        logger.debug(
            "Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, "
            "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: "
            "%s, no_flip: %s, training_image_size: %s, alignments_paths: %s, "
            "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, "
            "pingpong: %s, pretrain: %s, memory_saving_gradients: %s, predict: %s)",
            self.__class__.__name__, model_dir, gpus, configfile,
            snapshot_interval, no_logs, warp_to_landmarks, augment_color,
            no_flip, training_image_size, alignments_paths, preview_scale,
            input_shape, encoder_dim, trainer, pingpong, pretrain,
            memory_saving_gradients, predict)

        self.predict = predict
        self.model_dir = model_dir
        self.backup = Backup(self.model_dir, self.name)
        self.gpus = gpus
        # configfile must be set before self.config is first read below
        self.configfile = configfile
        self.blocks = NNBlocks(
            use_subpixel=self.config["subpixel_upscaling"],
            use_icnr_init=self.config["icnr_init"],
            use_reflect_padding=self.config["reflect_padding"])
        self.input_shape = input_shape
        self.output_shape = None  # set after model is compiled
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        self.state = State(self.model_dir, self.name,
                           self.config_changeable_items, no_logs, pingpong,
                           training_image_size)
        self.is_legacy = False
        self.rename_legacy()
        self.load_state_info()

        self.networks = {}  # Networks for the model
        self.predictors = {}  # Predictors for model
        self.history = {}  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        self.training_opts = {
            "alignments": alignments_paths,
            "preview_scaling": preview_scale / 100,
            "warp_to_landmarks": warp_to_landmarks,
            "augment_color": augment_color,
            "no_flip": no_flip,
            "pingpong": pingpong,
            "snapshot_interval": snapshot_interval
        }

        self.set_gradient_type(memory_saving_gradients)
        self.build()
        self.set_training_data()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)
示例#6
0
class ModelBase():
    """ Base class that all models should inherit from.

    Registers the model's networks and predictors, handles loading and saving (including backups
    and snapshots), selects loss functions and migrates legacy model files. Subclasses must
    override :func:`add_networks` and :func:`build_autoencoders`.

    Parameters
    ----------
    model_dir: :class:`pathlib.Path`
        The folder that the model files live in. NB: :func:`add_network` joins paths with the
        ``/`` operator, so this is assumed to be a :class:`pathlib.Path` (or compatible) object
    gpus: int
        The number of GPUs. If greater than 1, predictors are wrapped with ``multi_gpu_model``
    configfile: str, optional
        Path to a custom config file. Default: ``None``
    snapshot_interval: int, optional
        Stored in :attr:`training_opts` for the trainer. Default: ``0``
    no_logs: bool, optional
        Passed through to the model :class:`State`. Default: ``False``
    warp_to_landmarks: bool, optional
        Stored in :attr:`training_opts`. Default: ``False``
    augment_color: bool, optional
        Stored in :attr:`training_opts`. Default: ``True``
    no_flip: bool, optional
        Stored in :attr:`training_opts`. Default: ``False``
    training_image_size: int, optional
        Passed through to the model :class:`State`. Default: ``256``
    alignments_paths: optional
        Stored in :attr:`training_opts` under ``"alignments"``. Default: ``None``
    preview_scale: int, optional
        Preview scale as a percentage. Stored in :attr:`training_opts` as a ratio.
        Default: ``100``
    input_shape: tuple, optional
        The model input shape. Replaced by the shape stored in the state file, if one exists.
        Default: ``None``
    encoder_dim: int, optional
        The encoder dimensions. Default: ``None``
    trainer: str, optional
        The name of the trainer. Default: ``"original"``
    pingpong: bool, optional
        Whether pingpong training is enabled. Default: ``False``
    pretrain: bool, optional
        Accepted for interface compatibility. It is not used by this class. Default: ``False``
    memory_saving_gradients: bool, optional
        ``True`` to monkeypatch Keras with memory saving gradients. Default: ``False``
    predict: bool, optional
        ``True`` if the model is being loaded for inference rather than training.
        Default: ``False``
    """
    def __init__(self,
                 model_dir,
                 gpus,
                 configfile=None,
                 snapshot_interval=0,
                 no_logs=False,
                 warp_to_landmarks=False,
                 augment_color=True,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 pingpong=False,
                 pretrain=False,
                 memory_saving_gradients=False,
                 predict=False):
        logger.debug(
            "Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, "
            "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: "
            "%s, no_flip: %s, training_image_size, %s, alignments_paths: %s, "
            "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, "
            "pingpong: %s, memory_saving_gradients: %s, predict: %s)",
            self.__class__.__name__, model_dir, gpus, configfile,
            snapshot_interval, no_logs, warp_to_landmarks, augment_color,
            no_flip, training_image_size, alignments_paths, preview_scale,
            input_shape, encoder_dim, trainer, pingpong,
            memory_saving_gradients, predict)

        self.predict = predict
        self.model_dir = model_dir
        self.backup = Backup(self.model_dir, self.name)
        self.gpus = gpus
        self.configfile = configfile
        self.blocks = NNBlocks(
            use_subpixel=self.config["subpixel_upscaling"],
            use_icnr_init=self.config["icnr_init"],
            use_reflect_padding=self.config["reflect_padding"])
        self.input_shape = input_shape
        self.output_shape = None  # set after model is compiled
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        self.state = State(self.model_dir, self.name,
                           self.config_changeable_items, no_logs, pingpong,
                           training_image_size)
        self.is_legacy = False
        # Legacy renaming may update the state and encoder_dim, so it must run before the
        # input shape is read back from the state file
        self.rename_legacy()
        self.load_state_info()

        self.networks = dict()  # Networks for the model
        self.predictors = dict()  # Predictors for model
        self.history = dict()  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        self.training_opts = {
            "alignments": alignments_paths,
            "preview_scaling": preview_scale / 100,
            "warp_to_landmarks": warp_to_landmarks,
            "augment_color": augment_color,
            "no_flip": no_flip,
            "pingpong": pingpong,
            "snapshot_interval": snapshot_interval
        }

        self.set_gradient_type(memory_saving_gradients)
        self.build()
        self.set_training_data()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)

    @property
    def config_section(self):
        """ str: The section name for loading config: the last two components of this
        plugin's module path, joined with a dot. """
        retval = ".".join(self.__module__.split(".")[-2:])
        logger.debug(retval)
        return retval

    @property
    def config(self):
        """ dict: The config for the current plugin.

        Loaded once and cached in the module level ``_CONFIG`` global, so subsequent
        accesses (from any instance) return the cached dict. """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = self.config_section
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name,
                             configfile=self.configfile).config_dict
        return _CONFIG

    @property
    def config_changeable_items(self):
        """ dict: The config items that can be updated after the model has been created.
        A new :class:`Config` is loaded on every access. """
        return Config(self.config_section,
                      configfile=self.configfile).changeable_items

    @property
    def name(self):
        """ str: The model name, taken from the lower-cased file name (without extension) of
        the module that defines the subclass. """
        basename = os.path.basename(sys.modules[self.__module__].__file__)
        retval = os.path.splitext(basename)[0].lower()
        logger.debug("model name: '%s'", retval)
        return retval

    @property
    def models_exist(self):
        """ bool: ``True`` if the file for every registered network exists on disk (also
        ``True`` when no networks are registered), otherwise ``False``. """
        retval = all(os.path.isfile(model.filename) for model in self.networks.values())
        logger.debug("Pre-existing models exist: %s", retval)
        return retval

    @staticmethod
    def set_gradient_type(memory_saving_gradients):
        """ Monkeypatch Keras' ``gradients`` with Memory Saving Gradients if requested.

        Parameters
        ----------
        memory_saving_gradients: bool
            ``True`` to apply the patch, ``False`` to leave Keras untouched
        """
        if not memory_saving_gradients:
            return
        logger.info("Using Memory Saving Gradients")
        # NB: the import deliberately rebinds the (now unneeded) boolean argument's name
        from lib.model import memory_saving_gradients
        K.__dict__["gradients"] = memory_saving_gradients.gradients_memory

    def set_training_data(self):
        """ Override to set model specific training data.

        super() this method for defaults otherwise be sure to add the same keys to
        :attr:`training_opts`. """
        logger.debug("Setting training data")
        self.training_opts["training_size"] = self.state.training_size
        self.training_opts["no_logs"] = self.state.current_session["no_logs"]
        self.training_opts["mask_type"] = self.config.get("mask_type", None)
        self.training_opts["coverage_ratio"] = self.calculate_coverage_ratio()
        logger.debug("Set training data: %s", self.training_opts)

    def calculate_coverage_ratio(self):
        """ Calculate the coverage ratio.

        The configured coverage percentage is adjusted so that the cropped size
        (``training_size * ratio``) is rounded down to a multiple of 2.

        Returns
        -------
        float
            The adjusted coverage ratio
        """
        coverage_ratio = self.config.get("coverage", 62.5) / 100
        logger.debug("Requested coverage_ratio: %s", coverage_ratio)
        cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2
        coverage_ratio = cropped_size / self.state.training_size
        logger.debug("Final coverage_ratio: %s", coverage_ratio)
        return coverage_ratio

    def build(self):
        """ Build the model. Override for custom build methods.

        Raises
        ------
        FaceswapError
            If the autoencoders fail to build because saved weights come from a different
            graph, which usually indicates a corrupted save
        ValueError
            Any other ``ValueError`` raised while building the autoencoders is re-raised
        """
        self.add_networks()
        self.load_models(swapped=False)
        try:
            self.build_autoencoders()
        except ValueError as err:
            if "must be from the same graph" in str(err).lower():
                msg = (
                    "There was an error loading saved weights. This is most likely due to "
                    "model corruption during a previous save."
                    "\nYou should restore weights from a snapshot or from backup files. "
                    "You can use the 'Restore' Tool to restore from backup.")
                raise FaceswapError(msg) from err
            # Unrelated errors must not be swallowed: continuing would leave a half-built model
            raise
        self.log_summary()
        self.compile_predictors(initialize=True)

    def build_autoencoders(self):
        """ Override for Model Specific autoencoder builds

            NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names
            are expected:
                face (the input for image)
                mask (the input for mask if it is used)

        """
        raise NotImplementedError

    def add_networks(self):
        """ Override to add neural networks """
        raise NotImplementedError

    def load_state_info(self):
        """ Set :attr:`input_shape` from the first 'face' input shape stored in the state
        file, if one exists. Otherwise the shape passed in at init (model config) is kept. """
        logger.debug("Loading Input Shape from State file")
        if not self.state.inputs:
            logger.debug("No input shapes saved. Using model config")
            return
        if not self.state.face_shapes:
            logger.warning(
                "Input shapes stored in State file, but no matches for 'face'. "
                "Using model config")
            return
        input_shape = self.state.face_shapes[0]
        logger.debug("Setting input shape from state file: %s", input_shape)
        self.input_shape = input_shape

    def add_network(self, network_type, side, network):
        """ Register a network as an :class:`NNMeta` object in :attr:`networks`.

        Parameters
        ----------
        network_type: str
            The type of network (e.g. encoder/decoder). Lower-cased for the key and filename
        side: str or ``None``
            The side the network belongs to, or ``None`` for shared networks
        network: :class:`keras.models.Model`
            The network itself
        """
        logger.debug("network_type: '%s', side: '%s', network: '%s'",
                     network_type, side, network)
        filename = "{}_{}".format(self.name, network_type.lower())
        name = network_type.lower()
        if side:
            side = side.lower()
            # Filenames carry an upper-case side suffix, dict keys a lower-case one
            filename += "_{}".format(side.upper())
            name += "_{}".format(side)
        filename += ".h5"
        logger.debug("name: '%s', filename: '%s'", name, filename)
        self.networks[name] = NNMeta(str(self.model_dir / filename),
                                     network_type, side, network)

    def add_predictor(self, side, model):
        """ Add a predictor to :attr:`predictors`, converting it to multi-gpu if required.

        The first predictor added also populates the stored input shapes (when none are held
        in the state file) and the output shape (when not already set). """
        logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
        if self.gpus > 1:
            logger.debug("Converting to multi-gpu: side %s", side)
            model = multi_gpu_model(model, self.gpus)
        self.predictors[side] = model
        if not self.state.inputs:
            self.store_input_shapes(model)
        if not self.output_shape:
            self.set_output_shape(model)

    def store_input_shapes(self, model):
        """ Store the last 3 dimensions of each of the model's inputs in the state, keyed by
        tensor name.

        Raises
        ------
        ValueError
            If no input name starts with ``"face"``
        """
        logger.debug("Adding input shapes to state for model")
        inputs = {
            tensor.name: K.int_shape(tensor)[-3:]
            for tensor in model.inputs
        }
        if not any(name.startswith("face") for name in inputs):
            raise ValueError(
                "No input named 'face' was found. Check your input naming. "
                "Current input names: {}".format(inputs))
        self.state.inputs = inputs
        logger.debug("Added input shapes: %s", self.state.inputs)

    def set_output_shape(self, model):
        """ Set :attr:`output_shape` from the last 3 dimensions of the model's first output.

        Raises
        ------
        ValueError
            If the model has no outputs
        """
        logger.debug("Setting output shape")
        out = [K.int_shape(tensor)[-3:] for tensor in model.outputs]
        if not out:
            raise ValueError("No outputs found! Check your model.")
        self.output_shape = tuple(out[0])
        logger.debug("Added output shape: %s", self.output_shape)

    def reset_pingpong(self):
        """ Reset the models for pingpong training.

        Clears the Keras session, rebuilds every network from its stored config and weights,
        then rebuilds and recompiles the predictors. """
        logger.debug("Resetting models")

        # Clear models and graph
        self.predictors = dict()
        K.clear_session()

        # Load Models for current training run
        for model in self.networks.values():
            model.network = Model.from_config(model.config)
            model.network.set_weights(model.weights)

        self.build_autoencoders()
        self.compile_predictors(initialize=False)
        logger.debug("Reset models")

    def compile_predictors(self, initialize=True):
        """ Compile every predictor with the Adam optimizer and the configured losses.

        Parameters
        ----------
        initialize: bool, optional
            ``True`` on first compile: registers the session loss names with the state and
            creates an empty loss history per side. Default: ``True``
        """
        logger.debug("Compiling Predictors")
        learning_rate = self.config.get("learning_rate", 5e-5)
        optimizer = self.get_optimizer(lr=learning_rate,
                                       beta_1=0.5,
                                       beta_2=0.999)

        loss_names = []  # Defined up front so the final log works with no predictors
        for side, model in self.predictors.items():
            mask = [inp for inp in model.inputs if inp.name.startswith("mask")]
            loss_names = ["loss"]
            loss_funcs = [self.loss_function(mask, side, initialize)]
            if mask:
                loss_names.append("mask_loss")
                loss_funcs.append(self.mask_loss_function(side, initialize))
            model.compile(optimizer=optimizer, loss=loss_funcs)

            if len(loss_names) > 1:
                # Keras reports the combined loss first when there are multiple outputs
                loss_names.insert(0, "total_loss")
            if initialize:
                self.state.add_session_loss_names(side, loss_names)
                self.history[side] = list()
        logger.debug("Compiled Predictors. Losses: %s", loss_names)

    def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999):  # pylint: disable=invalid-name
        """ Build and return an Adam optimizer.

        ``clipnorm=1.0`` is added when enabled in config, except on the PlaidML backend. """
        opt_kwargs = dict(lr=lr, beta_1=beta_1, beta_2=beta_2)
        if (self.config.get("clipnorm", False)
                and keras.backend.backend() != "plaidml.keras.backend"):
            # NB: Clipnorm is ballooning VRAM usage, which is not expected behaviour
            # and may be a bug in Keras/TF.
            # PlaidML has a bug regarding the clipnorm parameter, so it is never set on
            # the PlaidML backend.
            # See: https://github.com/plaidml/plaidml/issues/228
            # TODO: Remove this as soon it is fixed in PlaidML.
            opt_kwargs["clipnorm"] = 1.0
        logger.debug("Optimizer kwargs: %s", opt_kwargs)
        return Adam(**opt_kwargs)

    def loss_function(self, mask, side, initialize):
        """ Select the face loss function.

        DSSIM if enabled in config, otherwise Mean Absolute Error, optionally wrapped in a
        penalized mask loss using the first mask input. Selection is only logged for side "a"
        during the initial training compile, so it is logged once. """
        if self.config.get("dssim_loss", False):
            if side == "a" and not self.predict and initialize:
                logger.verbose("Using DSSIM Loss")
            loss_func = DSSIMObjective()
        else:
            if side == "a" and not self.predict and initialize:
                logger.verbose("Using Mean Absolute Error Loss")
            loss_func = losses.mean_absolute_error

        if mask and self.config.get("penalized_mask_loss", False):
            loss_mask = mask[0]
            if side == "a" and not self.predict and initialize:
                logger.verbose("Penalizing mask for Loss")
            loss_func = PenalizedLoss(loss_mask, loss_func)
        return loss_func

    def mask_loss_function(self, side, initialize):
        """ Return the mask loss function (Mean Squared Error).

        Only logged for side "a" during the initial training compile. """
        if side == "a" and not self.predict and initialize:
            logger.verbose("Using Mean Squared Error Loss for mask")
        mask_loss_func = losses.mean_squared_error
        return mask_loss_func

    def converter(self, swap):
        """ Return the ``predict`` function used for conversion.

        Parameters
        ----------
        swap: bool
            ``True`` to use the side "a" predictor, ``False`` for side "b"
        """
        logger.debug("Getting Converter: (swap: %s)", swap)
        if swap:
            model = self.predictors["a"]
        else:
            model = self.predictors["b"]
        if self.predict:
            # Must compile the model to be thread safe
            model._make_predict_function()  # pylint: disable=protected-access
        retval = model.predict
        logger.debug("Got Converter: %s", retval)
        return retval

    @property
    def iterations(self):
        "Get current training iteration number"
        return self.state.iterations

    def map_models(self, swapped):
        """ Map network filenames to sides, by network type.

        Parameters
        ----------
        swapped: bool
            ``True`` to map side "b" files to side "a" and vice versa

        Returns
        -------
        dict
            ``{"a": {network_type: filename}, "b": {network_type: filename}}``. Networks
            without a side are not included.
        """
        logger.debug("Map models: (swapped: %s)", swapped)
        models_map = {"a": dict(), "b": dict()}
        sides = ("a", "b") if not swapped else ("b", "a")
        for network in self.networks.values():
            if network.side == sides[0]:
                models_map["a"][network.type] = network.filename
            if network.side == sides[1]:
                models_map["b"][network.type] = network.filename
        logger.debug("Mapped models: (models_map: %s)", models_map)
        return models_map

    def log_summary(self):
        """ Verbose log the predictor and network summaries (training only) """
        if self.predict:
            return
        for side in sorted(self.predictors):
            logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
            self.predictors[side].summary(
                print_fn=lambda x: logger.verbose("R|%s", x))
            for name, nnmeta in self.networks.items():
                # Shared (side-less) networks are logged under every side
                if nnmeta.side is not None and nnmeta.side != side:
                    continue
                logger.verbose("%s:", name.title())
                nnmeta.network.summary(
                    print_fn=lambda x: logger.verbose("R|%s", x))

    def do_snapshot(self):
        """ Perform a model snapshot at the current iteration """
        logger.debug("Performing snapshot")
        self.backup.snapshot_models(self.iterations)
        logger.debug("Performed snapshot")

    def load_models(self, swapped):
        """ Load the networks from disk.

        Parameters
        ----------
        swapped: bool
            ``True`` to load side "a" files into side "b" networks and vice versa

        Returns
        -------
        bool or ``None``
            ``None`` if no saved models exist and a new model is being created. Otherwise
            ``True`` if every network loaded and ``False`` if any network failed to load (or
            there were no networks to load).

        Notes
        -----
        When loading for inference and no saved model exists, an error is logged and the
        process exits with a non-zero status.
        """
        logger.debug("Load model: (swapped: %s)", swapped)

        if not self.models_exist and not self.predict:
            logger.info("Creating new '%s' model in folder: '%s'", self.name,
                        self.model_dir)
            return None
        if not self.models_exist and self.predict:
            logger.error("Model could not be found in folder '%s'. Exiting",
                         self.model_dir)
            sys.exit(1)

        if not self.is_legacy:
            K.clear_session()
        model_mapping = self.map_models(swapped)
        is_loaded = False  # Avoids an unbound name when there are no networks
        for network in self.networks.values():
            if not network.side:
                is_loaded = network.load()
            else:
                is_loaded = network.load(
                    fullpath=model_mapping[network.side][network.type])
            if not is_loaded:
                break
        if is_loaded:
            logger.info("Loaded model from disk: '%s'", self.model_dir)
        return is_loaded

    def save_models(self):
        """ Save every network and the state file in parallel threads.

        The previous save is backed up first if :func:`should_backup` reports that the
        average losses have fallen. """
        logger.debug("Backing up and saving models")
        save_averages = self.get_save_averages()
        backup_func = self.backup.backup_model if self.should_backup(
            save_averages) else None
        if backup_func:
            logger.info("Backing up models...")
        save_threads = list()
        for network in self.networks.values():
            name = "save_{}".format(network.name)
            save_threads.append(
                MultiThread(network.save, name=name, backup_func=backup_func))
        save_threads.append(
            MultiThread(self.state.save,
                        name="save_state",
                        backup_func=backup_func))
        for thread in save_threads:
            thread.start()
        for thread in save_threads:
            if thread.has_error:
                logger.error(thread.errors[0])
            thread.join()
        msg = "[Saved models]"
        if save_averages:
            lossmsg = [
                "{}_{}: {:.5f}".format(self.state.loss_names[side][0],
                                       side.capitalize(), save_averages[side])
                for side in sorted(list(save_averages.keys()))
            ]
            msg += " - Average since last save: {}".format(", ".join(lossmsg))
        logger.info(msg)

    def get_save_averages(self):
        """ Return the average loss since the last save iteration and reset historical loss.

        Sides are processed in :attr:`history` order. Processing stops at the first side with
        no recorded loss, so sides from that point on are neither averaged nor reset.

        Returns
        -------
        dict
            Average loss per side
        """
        logger.debug("Getting save averages")
        avgs = dict()
        for side, loss in self.history.items():
            if not loss:
                logger.debug("No loss in self.history: %s", side)
                break
            avgs[side] = sum(loss) / len(loss)
            self.history[side] = list()  # Reset historical loss
        logger.debug("Average losses since last save: %s", avgs)
        return avgs

    def should_backup(self, save_averages):
        """ Check whether the loss averages for all losses are the lowest that have been seen.

        This protects against model corruption by only backing up the model if every side
        that already has a recorded lowest average has seen its loss fall. Sides without a
        recorded average are initialized to the current value. On a backup, the lowest
        averages in the state are updated.
        TODO This is not a perfect system. If the model corrupts on save_iteration - 1
        then model may still backup

        Parameters
        ----------
        save_averages: dict
            Average loss per side for this save iteration

        Returns
        -------
        bool
            ``True`` if the model should be backed up, otherwise ``False``
        """
        backup = True

        if not save_averages:
            logger.debug("No save averages. Not backing up")
            return False

        for side, loss in save_averages.items():
            # A stored lowest average of 0.0 is valid, so only a missing value means "unset"
            if self.state.lowest_avg_loss.get(side, None) is None:
                logger.debug(
                    "Setting initial save iteration loss average for '%s': %s",
                    side, loss)
                self.state.lowest_avg_loss[side] = loss
                continue
            if backup:
                # Only run this if backup is true. All losses must have dropped for a valid backup
                backup = self.check_loss_drop(side, loss)

        logger.debug("Lowest historical save iteration loss average: %s",
                     self.state.lowest_avg_loss)

        if backup:  # Update lowest loss values to the state
            for side, avg_loss in save_averages.items():
                logger.debug(
                    "Updating lowest save iteration average for '%s': %s",
                    side, avg_loss)
                self.state.lowest_avg_loss[side] = avg_loss

        logger.debug("Backing up: %s", backup)
        return backup

    def check_loss_drop(self, side, avg):
        """ Return ``True`` if ``avg`` is strictly lower than the lowest recorded average
        loss for ``side``. """
        if avg < self.state.lowest_avg_loss[side]:
            logger.debug("Loss for '%s' has dropped", side)
            return True
        logger.debug("Loss for '%s' has not dropped", side)
        return False

    def rename_legacy(self):
        """ Legacy Original, LowMem and IAE models had inconsistent naming conventions.

        Rename them if they are found (and the new name does not already exist). If any file
        was renamed, flag the model as legacy and write a state file with the legacy
        defaults. """
        legacy_mapping = {
            "iae": [("IAE_decoder.h5", "iae_decoder.h5"),
                    ("IAE_encoder.h5", "iae_encoder.h5"),
                    ("IAE_inter_A.h5", "iae_intermediate_A.h5"),
                    ("IAE_inter_B.h5", "iae_intermediate_B.h5"),
                    ("IAE_inter_both.h5", "iae_inter.h5")],
            "original": [("encoder.h5", "original_encoder.h5"),
                         ("decoder_A.h5", "original_decoder_A.h5"),
                         ("decoder_B.h5", "original_decoder_B.h5"),
                         ("lowmem_encoder.h5", "original_encoder.h5"),
                         ("lowmem_decoder_A.h5", "original_decoder_A.h5"),
                         ("lowmem_decoder_B.h5", "original_decoder_B.h5")]
        }
        if self.name not in legacy_mapping:
            return
        logger.debug("Renaming legacy files")

        set_lowmem = False
        updated = False
        for old_name, new_name in legacy_mapping[self.name]:
            old_path = os.path.join(str(self.model_dir), old_name)
            new_path = os.path.join(str(self.model_dir), new_name)
            if os.path.exists(old_path) and not os.path.exists(new_path):
                logger.info("Updating legacy model name from: '%s' to '%s'",
                            old_name, new_name)
                os.rename(old_path, new_path)
                if old_name.startswith("lowmem"):
                    set_lowmem = True
                updated = True

        if not updated:
            logger.debug("No legacy files to rename")
            return

        self.is_legacy = True
        logger.debug("Creating state file for legacy model")
        self.state.inputs = {"face:0": [64, 64, 3]}
        self.state.training_size = 256
        self.state.config["coverage"] = 62.5
        self.state.config["subpixel_upscaling"] = False
        self.state.config["reflect_padding"] = False
        self.state.config["mask_type"] = None
        self.state.config["lowmem"] = False
        self.encoder_dim = 1024

        if set_lowmem:
            logger.debug(
                "Setting encoder_dim and lowmem flag for legacy lowmem model")
            self.encoder_dim = 512
            self.state.config["lowmem"] = True

        self.state.replace_config(self.config_changeable_items)
        self.state.save()
示例#7
0
 def __init__(self, plugin: "ModelBase", model_dir: str, is_predict: bool) -> None:
     """ Record the owning plugin and model location, and create the backup handler.

     Parameters
     ----------
     plugin: :class:`ModelBase`
         The parent plugin that owns these functions. Its ``name`` is passed to the backup
     model_dir: str
         The full path to the model save location
     is_predict: bool
         ``True`` if the model is being loaded for inference, ``False`` for training
     """
     self._plugin = plugin
     self._model_dir = model_dir
     self._is_predict = is_predict
     # Two empty loss lists, accumulated per save iteration (presumably one per side)
     self._history: List[List[float]] = [[], []]
     self._backup = Backup(model_dir, plugin.name)
示例#8
0
class IO():
    """ Model saving and loading functions.

    Handles the loading and saving of the plugin model from disk as well as the model backup and
    snapshot functions.

    Parameters
    ----------
    plugin: :class:`Model`
        The parent plugin class that owns the IO functions.
    model_dir: str
        The full path to the model save location
    is_predict: bool
        ``True`` if the model is being loaded for inference. ``False`` if the model is being loaded
        for training.
    """
    def __init__(self, plugin: "ModelBase", model_dir: str, is_predict: bool) -> None:
        self._plugin = plugin
        self._is_predict = is_predict
        self._model_dir = model_dir
        self._history: List[List[float]] = [[], []]  # Loss histories per save iteration
        self._backup = Backup(self._model_dir, self._plugin.name)

    @property
    def _filename(self) -> str:
        """str: The full path to the ``<plugin name>.h5`` file for this model."""
        return os.path.join(self._model_dir, f"{self._plugin.name}.h5")

    @property
    def model_exists(self) -> bool:
        """ bool: ``True`` if a model of the type being loaded exists within the model folder
        location otherwise ``False``.
        """
        return os.path.isfile(self._filename)

    @property
    def history(self) -> List[List[float]]:
        """ list: list of loss histories per side for the current save iteration. """
        return self._history

    @property
    def multiple_models_in_folder(self) -> Optional[List[str]]:
        """ :list: or ``None`` If there are multiple model types in the requested folder, or model
        types that don't correspond to the requested plugin type, then returns the list of plugin
        names that exist in the folder, otherwise returns ``None``

        Notes
        -----
        The test is character based: a list is only returned when the ``.h5`` file names (without
        extension) and the plugin name share no common leading characters at all.
        """
        plugins = [os.path.splitext(fname)[0]
                   for fname in os.listdir(self._model_dir)
                   if fname.endswith(".h5")]
        # Always contains at least the plugin name, so commonprefix is never called on []
        test_names = plugins + [self._plugin.name]
        test = os.path.commonprefix(test_names) == ""
        retval = plugins if test else None
        logger.debug("plugin name: %s, plugins: %s, test result: %s, retval: %s",
                     self._plugin.name, plugins, test, retval)
        return retval

    def _corrupted_model_error(self, err: Exception) -> FaceswapError:
        """ Build the user-facing error for a model file that cannot be read.

        Parameters
        ----------
        err: Exception
            The original error raised by Keras/h5py when loading the model

        Returns
        -------
        :class:`FaceswapError`
            The error to raise, including the model path and the original error text
        """
        msg = (f"Unable to load the model from '{self._filename}'. This may be a "
               "temporary error but most likely means that your model has corrupted.\n"
               "You can try to load the model again but if the problem persists you "
               "should use the Restore Tool to restore your model from backup.\n"
               f"Original error: {str(err)}")
        return FaceswapError(msg)

    def _load(self) -> keras.models.Model:
        """ Loads the model from disk (without compiling it).

        If the predict function is to be called and the model cannot be found in the model folder
        then an error is logged and the process exits.

        Returns
        -------
        :class:`keras.models.Model`
            The saved model loaded from disk

        Raises
        ------
        FaceswapError
            If the load fails with an error that indicates a corrupted model file. Any other
            ``RuntimeError``/``KeyError`` is re-raised unchanged.
        """
        logger.debug("Loading model: %s", self._filename)
        if self._is_predict and not self.model_exists:
            logger.error("Model could not be found in folder '%s'. Exiting", self._model_dir)
            sys.exit(1)

        try:
            model = load_model(self._filename, compile=False)
        except RuntimeError as err:
            if "unable to get link info" in str(err).lower():
                raise self._corrupted_model_error(err) from err
            raise
        except KeyError as err:
            if "unable to open object" in str(err).lower():
                raise self._corrupted_model_error(err) from err
            raise

        logger.info("Loaded model from disk: '%s'", self._filename)
        return model

    def save(self) -> None:
        """ Backup and save the model and state file.

        Notes
        -----
        The backup function actually backups the model from the previous save iteration rather than
        the current save iteration. This is not a bug, but protection against long save times, as
        models can get quite large, so renaming the current model file rather than copying it can
        save substantial amount of time.
        """
        logger.debug("Backing up and saving models")
        print("")  # Insert a new line to avoid spamming the same row as loss output
        save_averages = self._get_save_averages()
        if save_averages and self._should_backup(save_averages):
            self._backup.backup_model(self._filename)
            # pylint:disable=protected-access
            self._backup.backup_model(self._plugin.state._filename)

        self._plugin.model.save(self._filename, include_optimizer=False)
        self._plugin.state.save()

        msg = "[Saved models]"
        if save_averages:
            lossmsg = [f"face_{side}: {avg:.5f}"
                       for side, avg in zip(("a", "b"), save_averages)]
            msg += f" - Average loss since last save: {', '.join(lossmsg)}"
        logger.info(msg)

    def _get_save_averages(self) -> List[float]:
        """ Return the average loss since the last save iteration and reset historical loss.

        Returns
        -------
        list
            The average loss for each side, or an empty list (with history left untouched) if
            any side has no recorded loss
        """
        logger.debug("Getting save averages")
        if not all(self._history):
            logger.debug("No loss in history")
            retval = []
        else:
            retval = [sum(loss) / len(loss) for loss in self._history]
            self._history = [[], []]  # Reset historical loss
        logger.debug("Average losses since last save: %s", retval)
        return retval

    def _should_backup(self, save_averages: List[float]) -> bool:
        """ Check whether the loss averages for this save iteration is the lowest that has been
        seen.

        This protects against model corruption by only backing up the model if both sides have
        seen a total fall in loss. Sides with no recorded lowest average are initialized to the
        current value.

        Notes
        -----
        This is by no means a perfect system. If the model corrupts at an iteration close
        to a save iteration, then the averages may still be pushed lower than a previous
        save average, resulting in backing up a corrupted model.

        Parameters
        ----------
        save_averages: list
            The average loss for each side for this save iteration

        Returns
        -------
        bool
            ``True`` if the model should be backed up, otherwise ``False``
        """
        lowest = self._plugin.state.lowest_avg_loss
        backup = True
        for side, loss in zip(("a", "b"), save_averages):
            # A stored lowest average of 0.0 is valid, so only a missing value means "unset"
            if lowest.get(side, None) is None:
                logger.debug("Set initial save iteration loss average for '%s': %s", side, loss)
                lowest[side] = loss
                continue
            if backup:  # Once one side fails to drop, the rest need not be checked
                backup = loss < lowest[side]

        if backup:  # Update lowest loss values to the state file
            old_avgs = dict(lowest)
            lowest["a"] = save_averages[0]
            lowest["b"] = save_averages[1]
            logger.debug("Updated lowest historical save iteration averages from: %s to: %s",
                         old_avgs, lowest)

        logger.debug("Should backup: %s", backup)
        return backup

    def snapshot(self) -> None:
        """ Perform a model snapshot.

        Notes
        -----
        Snapshot function is called 1 iteration after the model was saved, so that it is built from
        the latest save, hence iteration being reduced by 1.
        """
        logger.debug("Performing snapshot. Iterations: %s", self._plugin.iterations)
        self._backup.snapshot_models(self._plugin.iterations - 1)
        logger.debug("Performed snapshot")