def process(self):
    """ Perform the Restore process """
    logger.info("Starting Model Restore...")
    self.validate()
    backup = Backup(self.model_dir, self.model_name)
    backup.restore()
    logger.info("Completed Model Restore")
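# Minimal usage sketch (not part of the tool itself): the ``process`` method above
# drives a restore by constructing ``Backup`` for the model folder and name and
# calling ``restore()``.  The import path and the example arguments below are
# assumptions for illustration only.
from lib.model.backup_restore import Backup  # assumed location of the Backup helper


def restore_model(model_dir: str, model_name: str) -> None:
    """ Restore ``model_name`` inside ``model_dir`` from its backup files. """
    backup = Backup(model_dir, model_name)
    backup.restore()

# restore_model("/path/to/model_folder", "original")  # hypothetical example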
def __init__(self,
             plugin: "ModelBase",
             model_dir: str,
             is_predict: bool,
             save_optimizer: Literal["never", "always", "exit"]) -> None:
    self._plugin = plugin
    self._is_predict = is_predict
    self._model_dir = model_dir
    self._save_optimizer = save_optimizer
    self._history: List[List[float]] = [[], []]  # Loss histories per save iteration
    self._backup = Backup(self._model_dir, self._plugin.name)
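# Illustrative sketch only: ``save_optimizer`` above appears to control when the
# optimizer state is written out with the model.  The helper below is a hypothetical
# reading of that flag, not the plugin's actual implementation.
from typing import Literal


def include_optimizer(save_optimizer: Literal["never", "always", "exit"],
                      is_exit_save: bool) -> bool:
    """ Return ``True`` if optimizer state should be included in this save. """
    if save_optimizer == "always":
        return True
    if save_optimizer == "exit":
        return is_exit_save  # only include on the final save when exiting
    return False  # "never"

# include_optimizer("exit", is_exit_save=False)  -> False
# include_optimizer("exit", is_exit_save=True)   -> True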
class ModelBase(): """ Base class that all models should inherit from """ def __init__(self, model_dir, gpus=1, configfile=None, snapshot_interval=0, no_logs=False, warp_to_landmarks=False, augment_color=True, no_flip=False, training_image_size=256, alignments_paths=None, preview_scale=100, input_shape=None, encoder_dim=None, trainer="original", pingpong=False, memory_saving_gradients=False, optimizer_savings=False, predict=False): logger.debug( "Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, " "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: " "%s, no_flip: %s, training_image_size, %s, alignments_paths: %s, " "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, " "pingpong: %s, memory_saving_gradients: %s, optimizer_savings: %s, " "predict: %s)", self.__class__.__name__, model_dir, gpus, configfile, snapshot_interval, no_logs, warp_to_landmarks, augment_color, no_flip, training_image_size, alignments_paths, preview_scale, input_shape, encoder_dim, trainer, pingpong, memory_saving_gradients, optimizer_savings, predict) self.predict = predict self.model_dir = model_dir self.vram_savings = VRAMSavings(pingpong, optimizer_savings, memory_saving_gradients) self.backup = Backup(self.model_dir, self.name) self.gpus = gpus self.configfile = configfile self.input_shape = input_shape self.encoder_dim = encoder_dim self.trainer = trainer self.load_config( ) # Load config if plugin has not already referenced it self.state = State(self.model_dir, self.name, self.config_changeable_items, no_logs, self.vram_savings.pingpong, training_image_size) self.blocks = NNBlocks( use_subpixel=self.config["subpixel_upscaling"], use_icnr_init=self.config["icnr_init"], use_convaware_init=self.config["conv_aware_init"], use_reflect_padding=self.config["reflect_padding"], first_run=self.state.first_run) self.is_legacy = False self.rename_legacy() self.load_state_info() self.networks = dict() # Networks for the model self.predictors = dict() # Predictors for model self.history = dict() # Loss history per save iteration) # Training information specific to the model should be placed in this # dict for reference by the trainer. 
self.training_opts = { "alignments": alignments_paths, "preview_scaling": preview_scale / 100, "warp_to_landmarks": warp_to_landmarks, "augment_color": augment_color, "no_flip": no_flip, "pingpong": self.vram_savings.pingpong, "snapshot_interval": snapshot_interval, "training_size": self.state.training_size, "no_logs": self.state.current_session["no_logs"], "coverage_ratio": self.calculate_coverage_ratio(), "mask_type": self.config["mask_type"], "mask_blur_kernel": self.config["mask_blur_kernel"], "mask_threshold": self.config["mask_threshold"], "learn_mask": (self.config["learn_mask"] and self.config["mask_type"] is not None), "penalized_mask_loss": (self.config["penalized_mask_loss"] and self.config["mask_type"] is not None) } logger.debug("training_opts: %s", self.training_opts) if self.multiple_models_in_folder: deprecation_warning( "Support for multiple model types within the same folder", additional_info= "Please split each model into separate folders to " "avoid issues in future.") self.build() logger.debug("Initialized ModelBase (%s)", self.__class__.__name__) @property def config_section(self): """ The section name for loading config """ retval = ".".join(self.__module__.split(".")[-2:]) logger.debug(retval) return retval @property def config(self): """ Return config dict for current plugin """ global _CONFIG # pylint: disable=global-statement if not _CONFIG: model_name = self.config_section logger.debug("Loading config for: %s", model_name) _CONFIG = Config(model_name, configfile=self.configfile).config_dict return _CONFIG @property def config_changeable_items(self): """ Return the dict of config items that can be updated after the model has been created """ return Config(self.config_section, configfile=self.configfile).changeable_items @property def name(self): """ Set the model name based on the subclass """ basename = os.path.basename(sys.modules[self.__module__].__file__) retval = os.path.splitext(basename)[0].lower() logger.debug("model name: '%s'", retval) return retval @property def models_exist(self): """ Return if all files exist and clear session """ retval = all([ os.path.isfile(model.filename) for model in self.networks.values() ]) logger.debug("Pre-existing models exist: %s", retval) return retval @property def multiple_models_in_folder(self): """ Return true if there are multiple model types in the same folder, else false """ model_files = [ fname for fname in os.listdir(str(self.model_dir)) if fname.endswith(".h5") ] retval = False if not model_files else os.path.commonprefix( model_files) == "" logger.debug("model_files: %s, retval: %s", model_files, retval) return retval @property def output_shapes(self): """ Return the output shapes from the main AutoEncoder """ out = list() for predictor in self.predictors.values(): out.extend( [K.int_shape(output)[-3:] for output in predictor.outputs]) break # Only get output from one autoencoder. Shapes are the same return [tuple(shape) for shape in out] @property def output_shape(self): """ The output shape of the model (shape of largest face output) """ return self.output_shapes[self.largest_face_index] @property def largest_face_index(self): """ Return the index from model.outputs of the largest face Required for multi-output model prediction. 
The largest face is assumed to be the final output """ sizes = [shape[1] for shape in self.output_shapes if shape[2] == 3] if not sizes: return None max_face = max(sizes) retval = [ idx for idx, shape in enumerate(self.output_shapes) if shape[1] == max_face and shape[2] == 3 ][0] logger.debug(retval) return retval @property def largest_mask_index(self): """ Return the index from model.outputs of the largest mask Required for multi-output model prediction. The largest face is assumed to be the final output """ sizes = [shape[1] for shape in self.output_shapes if shape[2] == 1] if not sizes: return None max_mask = max(sizes) retval = [ idx for idx, shape in enumerate(self.output_shapes) if shape[1] == max_mask and shape[2] == 1 ][0] logger.debug(retval) return retval @property def feed_mask(self): """ bool: ``True`` if the model expects a mask to be fed into input otherwise ``False`` """ return self.config["mask_type"] is not None and ( self.config["learn_mask"] or self.config["penalized_mask_loss"]) def load_config(self): """ Load the global config for reference in self.config """ global _CONFIG # pylint: disable=global-statement if not _CONFIG: model_name = self.config_section logger.debug("Loading config for: %s", model_name) _CONFIG = Config(model_name, configfile=self.configfile).config_dict def calculate_coverage_ratio(self): """ Coverage must be a ratio, leading to a cropped shape divisible by 2 """ coverage_ratio = self.config.get("coverage", 62.5) / 100 logger.debug("Requested coverage_ratio: %s", coverage_ratio) cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2 coverage_ratio = cropped_size / self.state.training_size logger.debug("Final coverage_ratio: %s", coverage_ratio) return coverage_ratio def build(self): """ Build the model. Override for custom build methods """ self.add_networks() self.load_models(swapped=False) inputs = self.get_inputs() try: self.build_autoencoders(inputs) except ValueError as err: if "must be from the same graph" in str(err).lower(): msg = ( "There was an error loading saved weights. This is most likely due to " "model corruption during a previous save." "\nYou should restore weights from a snapshot or from backup files. 
" "You can use the 'Restore' Tool to restore from backup.") raise FaceswapError(msg) from err if "multi_gpu_model" in str(err).lower(): raise FaceswapError(str(err)) from err raise err self.log_summary() self.compile_predictors(initialize=True) def get_inputs(self): """ Return the inputs for the model """ logger.debug("Getting inputs") inputs = [Input(shape=self.input_shape, name="face_in")] output_network = [ network for network in self.networks.values() if network.is_output ][0] if self.feed_mask: # TODO penalized mask doesn't have a mask output, so we can't use output shapes # mask should always be last output..this needs to be a rule mask_shape = output_network.output_shapes[-1] inputs.append( Input(shape=(mask_shape[1:-1] + (1, )), name="mask_in")) logger.debug("Got inputs: %s", inputs) return inputs def build_autoencoders(self, inputs): """ Override for Model Specific autoencoder builds Inputs is defined in self.get_inputs() and is standardized for all models if will generally be in the order: [face (the input for image), mask (the input for mask if it is used)] """ raise NotImplementedError def add_networks(self): """ Override to add neural networks """ raise NotImplementedError def load_state_info(self): """ Load the input shape from state file if it exists """ logger.debug("Loading Input Shape from State file") if not self.state.inputs: logger.debug("No input shapes saved. Using model config") return if not self.state.face_shapes: logger.warning( "Input shapes stored in State file, but no matches for 'face'." "Using model config") return input_shape = self.state.face_shapes[0] logger.debug("Setting input shape from state file: %s", input_shape) self.input_shape = input_shape def add_network(self, network_type, side, network, is_output=False): """ Add a NNMeta object """ logger.debug( "network_type: '%s', side: '%s', network: '%s', is_output: %s", network_type, side, network, is_output) filename = "{}_{}".format(self.name, network_type.lower()) name = network_type.lower() if side: side = side.lower() filename += "_{}".format(side.upper()) name += "_{}".format(side) filename += ".h5" logger.debug("name: '%s', filename: '%s'", name, filename) self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network, is_output) def add_predictor(self, side, model): """ Add a predictor to the predictors dictionary """ logger.debug("Adding predictor: (side: '%s', model: %s)", side, model) if self.gpus > 1: logger.debug("Converting to multi-gpu: side %s", side) model = multi_gpu_model(model, self.gpus) self.predictors[side] = model if not self.state.inputs: self.store_input_shapes(model) def store_input_shapes(self, model): """ Store the input and output shapes to state """ logger.debug("Adding input shapes to state for model") inputs = { tensor.name: K.int_shape(tensor)[-3:] for tensor in model.inputs } if not any(inp for inp in inputs.keys() if inp.startswith("face")): raise ValueError( "No input named 'face' was found. Check your input naming. 
" "Current input names: {}".format(inputs)) # Make sure they are all ints so that it can be json serialized inputs = { key: tuple(int(i) for i in val) for key, val in inputs.items() } self.state.inputs = inputs logger.debug("Added input shapes: %s", self.state.inputs) def reset_pingpong(self): """ Reset the models for pingpong training """ logger.debug("Resetting models") # Clear models and graph self.predictors = dict() K.clear_session() # Load Models for current training run for model in self.networks.values(): model.network = Model.from_config(model.config) model.network.set_weights(model.weights) inputs = self.get_inputs() self.build_autoencoders(inputs) self.compile_predictors(initialize=False) logger.debug("Reset models") def compile_predictors(self, initialize=True): """ Compile the predictors """ logger.debug("Compiling Predictors") learning_rate = self.config.get("learning_rate", 5e-5) optimizer = self.get_optimizer(lr=learning_rate, beta_1=0.5, beta_2=0.999) for side, model in self.predictors.items(): loss = Loss(model.inputs, model.outputs) model.compile(optimizer=optimizer, loss=loss.funcs) if initialize: self.state.add_session_loss_names(side, loss.names) self.history[side] = list() logger.debug("Compiled Predictors. Losses: %s", loss.names) def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999): # pylint: disable=invalid-name """ Build and return Optimizer """ opt_kwargs = dict(lr=lr, beta_1=beta_1, beta_2=beta_2) if (self.config.get("clipnorm", False) and keras.backend.backend() != "plaidml.keras.backend"): # NB: Clipnorm is ballooning VRAM usage, which is not expected behavior # and may be a bug in Keras/TF. # PlaidML has a bug regarding the clipnorm parameter # See: https://github.com/plaidml/plaidml/issues/228 # Workaround by simply removing it. # TODO: Remove this as soon it is fixed in PlaidML. 
opt_kwargs["clipnorm"] = 1.0 logger.debug("Optimizer kwargs: %s", opt_kwargs) return Adam(**opt_kwargs, cpu_mode=self.vram_savings.optimizer_savings) def converter(self, swap): """ Converter for autoencoder models """ logger.debug("Getting Converter: (swap: %s)", swap) side = "a" if swap else "b" model = self.predictors[side] if self.predict: # Must compile the model to be thread safe model._make_predict_function() # pylint: disable=protected-access retval = model.predict logger.debug("Got Converter: %s", retval) return retval @property def iterations(self): "Get current training iteration number" return self.state.iterations def map_models(self, swapped): """ Map the models for A/B side for swapping """ logger.debug("Map models: (swapped: %s)", swapped) models_map = {"a": dict(), "b": dict()} sides = ("a", "b") if not swapped else ("b", "a") for network in self.networks.values(): if network.side == sides[0]: models_map["a"][network.type] = network.filename if network.side == sides[1]: models_map["b"][network.type] = network.filename logger.debug("Mapped models: (models_map: %s)", models_map) return models_map def log_summary(self): """ Verbose log the model summaries """ if self.predict: return for side in sorted(list(self.predictors.keys())): logger.verbose("[%s %s Summary]:", self.name.title(), side.upper()) self.predictors[side].summary( print_fn=lambda x: logger.verbose("%s", x)) for name, nnmeta in self.networks.items(): if nnmeta.side is not None and nnmeta.side != side: continue logger.verbose("%s:", name.title()) nnmeta.network.summary( print_fn=lambda x: logger.verbose("%s", x)) def do_snapshot(self): """ Perform a model snapshot """ logger.debug("Performing snapshot") self.backup.snapshot_models(self.iterations) logger.debug("Performed snapshot") def load_models(self, swapped): """ Load models from file """ logger.debug("Load model: (swapped: %s)", swapped) if not self.models_exist and not self.predict: logger.info("Creating new '%s' model in folder: '%s'", self.name, self.model_dir) return None if not self.models_exist and self.predict: logger.error("Model could not be found in folder '%s'. 
Exiting", self.model_dir) exit(0) if not self.is_legacy or not self.predict: K.clear_session() model_mapping = self.map_models(swapped) for network in self.networks.values(): if not network.side: is_loaded = network.load() else: is_loaded = network.load( fullpath=model_mapping[network.side][network.type]) if not is_loaded: break if is_loaded: logger.info("Loaded model from disk: '%s'", self.model_dir) return is_loaded def save_models(self): """ Backup and save the models """ logger.debug("Backing up and saving models") save_averages = self.get_save_averages() backup_func = self.backup.backup_model if self.should_backup( save_averages) else None if backup_func: logger.info("Backing up models...") executor = futures.ThreadPoolExecutor() save_threads = [ executor.submit(network.save, backup_func=backup_func) for network in self.networks.values() ] save_threads.append( executor.submit(self.state.save, backup_func=backup_func)) futures.wait(save_threads) # call result() to capture errors _ = [thread.result() for thread in save_threads] msg = "[Saved models]" if save_averages: lossmsg = [ "{}_{}: {:.5f}".format(self.state.loss_names[side][0], side.capitalize(), save_averages[side]) for side in sorted(list(save_averages.keys())) ] msg += " - Average since last save: {}".format(", ".join(lossmsg)) logger.info(msg) def get_save_averages(self): """ Return the average loss since the last save iteration and reset historical loss """ logger.debug("Getting save averages") avgs = dict() for side, loss in self.history.items(): if not loss: logger.debug("No loss in self.history: %s", side) break avgs[side] = sum(loss) / len(loss) self.history[side] = list() # Reset historical loss logger.debug("Average losses since last save: %s", avgs) return avgs def should_backup(self, save_averages): """ Check whether the loss averages for all losses is the lowest that has been seen. This protects against model corruption by only backing up the model if any of the loss values have fallen. TODO This is not a perfect system. If the model corrupts on save_iteration - 1 then model may still backup """ backup = True if not save_averages: logger.debug("No save averages. Not backing up") return False for side, loss in save_averages.items(): if not self.state.lowest_avg_loss.get(side, None): logger.debug( "Setting initial save iteration loss average for '%s': %s", side, loss) self.state.lowest_avg_loss[side] = loss continue if backup: # Only run this if backup is true. 
All losses must have dropped for a valid backup backup = self.check_loss_drop(side, loss) logger.debug("Lowest historical save iteration loss average: %s", self.state.lowest_avg_loss) if backup: # Update lowest loss values to the state for side, avg_loss in save_averages.items(): logger.debug( "Updating lowest save iteration average for '%s': %s", side, avg_loss) self.state.lowest_avg_loss[side] = avg_loss logger.debug("Backing up: %s", backup) return backup def check_loss_drop(self, side, avg): """ Check whether total loss has dropped since lowest loss """ if avg < self.state.lowest_avg_loss[side]: logger.debug("Loss for '%s' has dropped", side) return True logger.debug("Loss for '%s' has not dropped", side) return False def rename_legacy(self): """ Legacy Original, LowMem and IAE models had inconsistent naming conventions Rename them if they are found and update """ legacy_mapping = { "iae": [("IAE_decoder.h5", "iae_decoder.h5"), ("IAE_encoder.h5", "iae_encoder.h5"), ("IAE_inter_A.h5", "iae_intermediate_A.h5"), ("IAE_inter_B.h5", "iae_intermediate_B.h5"), ("IAE_inter_both.h5", "iae_inter.h5")], "original": [("encoder.h5", "original_encoder.h5"), ("decoder_A.h5", "original_decoder_A.h5"), ("decoder_B.h5", "original_decoder_B.h5"), ("lowmem_encoder.h5", "original_encoder.h5"), ("lowmem_decoder_A.h5", "original_decoder_A.h5"), ("lowmem_decoder_B.h5", "original_decoder_B.h5")] } if self.name not in legacy_mapping.keys(): return logger.debug("Renaming legacy files") set_lowmem = False updated = False for old_name, new_name in legacy_mapping[self.name]: old_path = os.path.join(str(self.model_dir), old_name) new_path = os.path.join(str(self.model_dir), new_name) if os.path.exists(old_path) and not os.path.exists(new_path): logger.info("Updating legacy model name from: '%s' to '%s'", old_name, new_name) os.rename(old_path, new_path) if old_name.startswith("lowmem"): set_lowmem = True updated = True if not updated: logger.debug("No legacy files to rename") return self.is_legacy = True logger.debug("Creating state file for legacy model") self.state.inputs = {"face:0": [64, 64, 3]} self.state.training_size = 256 self.state.config["coverage"] = 62.5 self.state.config["subpixel_upscaling"] = False self.state.config["reflect_padding"] = False self.state.config["mask_type"] = None self.state.config["mask_blur_kernel"] = 3 self.state.config["mask_threshold"] = 4 self.state.config["learn_mask"] = False self.state.config["lowmem"] = False self.encoder_dim = 1024 if set_lowmem: logger.debug( "Setting encoder_dim and lowmem flag for legacy lowmem model") self.encoder_dim = 512 self.state.config["lowmem"] = True self.state.replace_config(self.config_changeable_items) self.state.save()
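# Standalone illustration of ``calculate_coverage_ratio`` above: the requested
# coverage percentage is snapped so that the cropped training patch has an even
# pixel size, then converted back into a ratio.  The function mirrors the
# arithmetic in the method; the numeric examples are illustrative only.


def coverage_ratio(training_size: int, coverage_pct: float = 62.5) -> float:
    """ Return the coverage ratio that yields an even cropped size. """
    requested = coverage_pct / 100
    cropped_size = (training_size * requested) // 2 * 2  # force an even pixel size
    return cropped_size / training_size

# coverage_ratio(256, 62.5) -> 0.625     (cropped size 160)
# coverage_ratio(256, 55.0) -> 0.546875  (140.8 px snapped down to 140)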
class ModelBase(): """ Base class that all models should inherit from """ def __init__(self, model_dir, gpus, configfile=None, snapshot_interval=0, no_logs=False, warp_to_landmarks=False, augment_color=True, no_flip=False, training_image_size=256, alignments_paths=None, preview_scale=100, input_shape=None, encoder_dim=None, trainer="original", pingpong=False, pretrain=False, memory_saving_gradients=False, predict=False): logger.debug( "Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, " "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: " "%s, no_flip: %s, training_image_size, %s, alignments_paths: %s, " "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, " "pingpong: %s, memory_saving_gradients: %s, predict: %s)", self.__class__.__name__, model_dir, gpus, configfile, snapshot_interval, no_logs, warp_to_landmarks, augment_color, no_flip, training_image_size, alignments_paths, preview_scale, input_shape, encoder_dim, trainer, pingpong, memory_saving_gradients, predict) self.predict = predict self.model_dir = model_dir self.backup = Backup(self.model_dir, self.name) self.gpus = gpus self.configfile = configfile self.blocks = NNBlocks( use_subpixel=self.config["subpixel_upscaling"], use_icnr_init=self.config["icnr_init"], use_reflect_padding=self.config["reflect_padding"]) self.input_shape = input_shape self.output_shape = None # set after model is compiled self.encoder_dim = encoder_dim self.trainer = trainer self.state = State(self.model_dir, self.name, self.config_changeable_items, no_logs, pingpong, training_image_size) self.is_legacy = False self.rename_legacy() self.load_state_info() self.networks = dict() # Networks for the model self.predictors = dict() # Predictors for model self.history = dict() # Loss history per save iteration) # Training information specific to the model should be placed in this # dict for reference by the trainer. 
self.training_opts = { "alignments": alignments_paths, "preview_scaling": preview_scale / 100, "warp_to_landmarks": warp_to_landmarks, "augment_color": augment_color, "no_flip": no_flip, "pingpong": pingpong, "snapshot_interval": snapshot_interval } self.set_gradient_type(memory_saving_gradients) self.build() self.set_training_data() logger.debug("Initialized ModelBase (%s)", self.__class__.__name__) @property def config_section(self): """ The section name for loading config """ retval = ".".join(self.__module__.split(".")[-2:]) logger.debug(retval) return retval @property def config(self): """ Return config dict for current plugin """ global _CONFIG # pylint: disable=global-statement if not _CONFIG: model_name = self.config_section logger.debug("Loading config for: %s", model_name) _CONFIG = Config(model_name, configfile=self.configfile).config_dict return _CONFIG @property def config_changeable_items(self): """ Return the dict of config items that can be updated after the model has been created """ return Config(self.config_section, configfile=self.configfile).changeable_items @property def name(self): """ Set the model name based on the subclass """ basename = os.path.basename(sys.modules[self.__module__].__file__) retval = os.path.splitext(basename)[0].lower() logger.debug("model name: '%s'", retval) return retval @property def models_exist(self): """ Return if all files exist and clear session """ retval = all([ os.path.isfile(model.filename) for model in self.networks.values() ]) logger.debug("Pre-existing models exist: %s", retval) return retval @staticmethod def set_gradient_type(memory_saving_gradients): """ Monkeypatch Memory Saving Gradients if requested """ if not memory_saving_gradients: return logger.info("Using Memory Saving Gradients") from lib.model import memory_saving_gradients K.__dict__["gradients"] = memory_saving_gradients.gradients_memory def set_training_data(self): """ Override to set model specific training data. super() this method for defaults otherwise be sure to add """ logger.debug("Setting training data") # Force number of preview images to between 2 and 16 self.training_opts["training_size"] = self.state.training_size self.training_opts["no_logs"] = self.state.current_session["no_logs"] self.training_opts["mask_type"] = self.config.get("mask_type", None) self.training_opts["coverage_ratio"] = self.calculate_coverage_ratio() logger.debug("Set training data: %s", self.training_opts) def calculate_coverage_ratio(self): """ Coverage must be a ratio, leading to a cropped shape divisible by 2 """ coverage_ratio = self.config.get("coverage", 62.5) / 100 logger.debug("Requested coverage_ratio: %s", coverage_ratio) cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2 coverage_ratio = cropped_size / self.state.training_size logger.debug("Final coverage_ratio: %s", coverage_ratio) return coverage_ratio def build(self): """ Build the model. Override for custom build methods """ self.add_networks() self.load_models(swapped=False) try: self.build_autoencoders() except ValueError as err: if "must be from the same graph" in str(err).lower(): msg = ( "There was an error loading saved weights. This is most likely due to " "model corruption during a previous save." "\nYou should restore weights from a snapshot or from backup files. 
" "You can use the 'Restore' Tool to restore from backup.") raise FaceswapError(msg) from err self.log_summary() self.compile_predictors(initialize=True) def build_autoencoders(self): """ Override for Model Specific autoencoder builds NB! ENSURE YOU NAME YOUR INPUTS. At least the following input names are expected: face (the input for image) mask (the input for mask if it is used) """ raise NotImplementedError def add_networks(self): """ Override to add neural networks """ raise NotImplementedError def load_state_info(self): """ Load the input shape from state file if it exists """ logger.debug("Loading Input Shape from State file") if not self.state.inputs: logger.debug("No input shapes saved. Using model config") return if not self.state.face_shapes: logger.warning( "Input shapes stored in State file, but no matches for 'face'." "Using model config") return input_shape = self.state.face_shapes[0] logger.debug("Setting input shape from state file: %s", input_shape) self.input_shape = input_shape def add_network(self, network_type, side, network): """ Add a NNMeta object """ logger.debug("network_type: '%s', side: '%s', network: '%s'", network_type, side, network) filename = "{}_{}".format(self.name, network_type.lower()) name = network_type.lower() if side: side = side.lower() filename += "_{}".format(side.upper()) name += "_{}".format(side) filename += ".h5" logger.debug("name: '%s', filename: '%s'", name, filename) self.networks[name] = NNMeta(str(self.model_dir / filename), network_type, side, network) def add_predictor(self, side, model): """ Add a predictor to the predictors dictionary """ logger.debug("Adding predictor: (side: '%s', model: %s)", side, model) if self.gpus > 1: logger.debug("Converting to multi-gpu: side %s", side) model = multi_gpu_model(model, self.gpus) self.predictors[side] = model if not self.state.inputs: self.store_input_shapes(model) if not self.output_shape: self.set_output_shape(model) def store_input_shapes(self, model): """ Store the input and output shapes to state """ logger.debug("Adding input shapes to state for model") inputs = { tensor.name: K.int_shape(tensor)[-3:] for tensor in model.inputs } if not any(inp for inp in inputs.keys() if inp.startswith("face")): raise ValueError( "No input named 'face' was found. Check your input naming. " "Current input names: {}".format(inputs)) self.state.inputs = inputs logger.debug("Added input shapes: %s", self.state.inputs) def set_output_shape(self, model): """ Set the output shape for use in training and convert """ logger.debug("Setting output shape") out = [K.int_shape(tensor)[-3:] for tensor in model.outputs] if not out: raise ValueError("No outputs found! 
Check your model.") self.output_shape = tuple(out[0]) logger.debug("Added output shape: %s", self.output_shape) def reset_pingpong(self): """ Reset the models for pingpong training """ logger.debug("Resetting models") # Clear models and graph self.predictors = dict() K.clear_session() # Load Models for current training run for model in self.networks.values(): model.network = Model.from_config(model.config) model.network.set_weights(model.weights) self.build_autoencoders() self.compile_predictors(initialize=False) logger.debug("Reset models") def compile_predictors(self, initialize=True): """ Compile the predictors """ logger.debug("Compiling Predictors") learning_rate = self.config.get("learning_rate", 5e-5) optimizer = self.get_optimizer(lr=learning_rate, beta_1=0.5, beta_2=0.999) for side, model in self.predictors.items(): mask = [inp for inp in model.inputs if inp.name.startswith("mask")] loss_names = ["loss"] loss_funcs = [self.loss_function(mask, side, initialize)] if mask: loss_names.append("mask_loss") loss_funcs.append(self.mask_loss_function(side, initialize)) model.compile(optimizer=optimizer, loss=loss_funcs) if len(loss_names) > 1: loss_names.insert(0, "total_loss") if initialize: self.state.add_session_loss_names(side, loss_names) self.history[side] = list() logger.debug("Compiled Predictors. Losses: %s", loss_names) def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999): # pylint: disable=invalid-name """ Build and return Optimizer """ opt_kwargs = dict(lr=lr, beta_1=beta_1, beta_2=beta_2) if (self.config.get("clipnorm", False) and keras.backend.backend() != "plaidml.keras.backend"): # NB: Clipnorm is ballooning VRAM useage, which is not expected behaviour # and may be a bug in Keras/TF. # PlaidML has a bug regarding the clipnorm parameter # See: https://github.com/plaidml/plaidml/issues/228 # Workaround by simply removing it. # TODO: Remove this as soon it is fixed in PlaidML. 
opt_kwargs["clipnorm"] = 1.0 logger.debug("Optimizer kwargs: %s", opt_kwargs) return Adam(**opt_kwargs) def loss_function(self, mask, side, initialize): """ Set the loss function Side is input so we only log once """ if self.config.get("dssim_loss", False): if side == "a" and not self.predict and initialize: logger.verbose("Using DSSIM Loss") loss_func = DSSIMObjective() else: if side == "a" and not self.predict and initialize: logger.verbose("Using Mean Absolute Error Loss") loss_func = losses.mean_absolute_error if mask and self.config.get("penalized_mask_loss", False): loss_mask = mask[0] if side == "a" and not self.predict and initialize: logger.verbose("Penalizing mask for Loss") loss_func = PenalizedLoss(loss_mask, loss_func) return loss_func def mask_loss_function(self, side, initialize): """ Set the mask loss function Side is input so we only log once """ if side == "a" and not self.predict and initialize: logger.verbose("Using Mean Squared Error Loss for mask") mask_loss_func = losses.mean_squared_error return mask_loss_func def converter(self, swap): """ Converter for autoencoder models """ logger.debug("Getting Converter: (swap: %s)", swap) if swap: model = self.predictors["a"] else: model = self.predictors["b"] if self.predict: # Must compile the model to be thread safe model._make_predict_function() # pylint: disable=protected-access retval = model.predict logger.debug("Got Converter: %s", retval) return retval @property def iterations(self): "Get current training iteration number" return self.state.iterations def map_models(self, swapped): """ Map the models for A/B side for swapping """ logger.debug("Map models: (swapped: %s)", swapped) models_map = {"a": dict(), "b": dict()} sides = ("a", "b") if not swapped else ("b", "a") for network in self.networks.values(): if network.side == sides[0]: models_map["a"][network.type] = network.filename if network.side == sides[1]: models_map["b"][network.type] = network.filename logger.debug("Mapped models: (models_map: %s)", models_map) return models_map def log_summary(self): """ Verbose log the model summaries """ if self.predict: return for side in sorted(list(self.predictors.keys())): logger.verbose("[%s %s Summary]:", self.name.title(), side.upper()) self.predictors[side].summary( print_fn=lambda x: logger.verbose("R|%s", x)) for name, nnmeta in self.networks.items(): if nnmeta.side is not None and nnmeta.side != side: continue logger.verbose("%s:", name.title()) nnmeta.network.summary( print_fn=lambda x: logger.verbose("R|%s", x)) def do_snapshot(self): """ Perform a model snapshot """ logger.debug("Performing snapshot") self.backup.snapshot_models(self.iterations) logger.debug("Performed snapshot") def load_models(self, swapped): """ Load models from file """ logger.debug("Load model: (swapped: %s)", swapped) if not self.models_exist and not self.predict: logger.info("Creating new '%s' model in folder: '%s'", self.name, self.model_dir) return None if not self.models_exist and self.predict: logger.error("Model could not be found in folder '%s'. 
Exiting", self.model_dir) exit(0) if not self.is_legacy: K.clear_session() model_mapping = self.map_models(swapped) for network in self.networks.values(): if not network.side: is_loaded = network.load() else: is_loaded = network.load( fullpath=model_mapping[network.side][network.type]) if not is_loaded: break if is_loaded: logger.info("Loaded model from disk: '%s'", self.model_dir) return is_loaded def save_models(self): """ Backup and save the models """ logger.debug("Backing up and saving models") save_averages = self.get_save_averages() backup_func = self.backup.backup_model if self.should_backup( save_averages) else None if backup_func: logger.info("Backing up models...") save_threads = list() for network in self.networks.values(): name = "save_{}".format(network.name) save_threads.append( MultiThread(network.save, name=name, backup_func=backup_func)) save_threads.append( MultiThread(self.state.save, name="save_state", backup_func=backup_func)) for thread in save_threads: thread.start() for thread in save_threads: if thread.has_error: logger.error(thread.errors[0]) thread.join() msg = "[Saved models]" if save_averages: lossmsg = [ "{}_{}: {:.5f}".format(self.state.loss_names[side][0], side.capitalize(), save_averages[side]) for side in sorted(list(save_averages.keys())) ] msg += " - Average since last save: {}".format(", ".join(lossmsg)) logger.info(msg) def get_save_averages(self): """ Return the average loss since the last save iteration and reset historical loss """ logger.debug("Getting save averages") avgs = dict() for side, loss in self.history.items(): if not loss: logger.debug("No loss in self.history: %s", side) break avgs[side] = sum(loss) / len(loss) self.history[side] = list() # Reset historical loss logger.debug("Average losses since last save: %s", avgs) return avgs def should_backup(self, save_averages): """ Check whether the loss averages for all losses is the lowest that has been seen. This protects against model corruption by only backing up the model if any of the loss values have fallen. TODO This is not a perfect system. If the model corrupts on save_iteration - 1 then model may still backup """ backup = True if not save_averages: logger.debug("No save averages. Not backing up") return for side, loss in save_averages.items(): if not self.state.lowest_avg_loss.get(side, None): logger.debug( "Setting initial save iteration loss average for '%s': %s", side, loss) self.state.lowest_avg_loss[side] = loss continue if backup: # Only run this if backup is true. 
All losses must have dropped for a valid backup backup = self.check_loss_drop(side, loss) logger.debug("Lowest historical save iteration loss average: %s", self.state.lowest_avg_loss) if backup: # Update lowest loss values to the state for side, avg_loss in save_averages.items(): logger.debug( "Updating lowest save iteration average for '%s': %s", side, avg_loss) self.state.lowest_avg_loss[side] = avg_loss logger.debug("Backing up: %s", backup) return backup def check_loss_drop(self, side, avg): """ Check whether total loss has dropped since lowest loss """ if avg < self.state.lowest_avg_loss[side]: logger.debug("Loss for '%s' has dropped", side) return True logger.debug("Loss for '%s' has not dropped", side) return False def rename_legacy(self): """ Legacy Original, LowMem and IAE models had inconsistent naming conventions Rename them if they are found and update """ legacy_mapping = { "iae": [("IAE_decoder.h5", "iae_decoder.h5"), ("IAE_encoder.h5", "iae_encoder.h5"), ("IAE_inter_A.h5", "iae_intermediate_A.h5"), ("IAE_inter_B.h5", "iae_intermediate_B.h5"), ("IAE_inter_both.h5", "iae_inter.h5")], "original": [("encoder.h5", "original_encoder.h5"), ("decoder_A.h5", "original_decoder_A.h5"), ("decoder_B.h5", "original_decoder_B.h5"), ("lowmem_encoder.h5", "original_encoder.h5"), ("lowmem_decoder_A.h5", "original_decoder_A.h5"), ("lowmem_decoder_B.h5", "original_decoder_B.h5")] } if self.name not in legacy_mapping.keys(): return logger.debug("Renaming legacy files") set_lowmem = False updated = False for old_name, new_name in legacy_mapping[self.name]: old_path = os.path.join(str(self.model_dir), old_name) new_path = os.path.join(str(self.model_dir), new_name) if os.path.exists(old_path) and not os.path.exists(new_path): logger.info("Updating legacy model name from: '%s' to '%s'", old_name, new_name) os.rename(old_path, new_path) if old_name.startswith("lowmem"): set_lowmem = True updated = True if not updated: logger.debug("No legacy files to rename") return self.is_legacy = True logger.debug("Creating state file for legacy model") self.state.inputs = {"face:0": [64, 64, 3]} self.state.training_size = 256 self.state.config["coverage"] = 62.5 self.state.config["subpixel_upscaling"] = False self.state.config["reflect_padding"] = False self.state.config["mask_type"] = None self.state.config["lowmem"] = False self.encoder_dim = 1024 if set_lowmem: logger.debug( "Setting encoder_dim and lowmem flag for legacy lowmem model") self.encoder_dim = 512 self.state.config["lowmem"] = True self.state.replace_config(self.config_changeable_items) self.state.save()
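# Minimal sketch of the backup decision in ``should_backup`` above: the model is
# only backed up when every side's average loss since the last save has fallen
# below the lowest average previously recorded.  ``lowest_avg_loss`` is a plain
# dict here; in the plugin it is held on the state file.
from typing import Dict


def should_backup(save_averages: Dict[str, float],
                  lowest_avg_loss: Dict[str, float]) -> bool:
    """ Return ``True`` and update the historical lows if all sides improved. """
    if not save_averages:
        return False
    backup = True
    for side, loss in save_averages.items():
        if side not in lowest_avg_loss:
            lowest_avg_loss[side] = loss  # first save iteration: just record it
            continue
        backup = backup and loss < lowest_avg_loss[side]
    if backup:
        lowest_avg_loss.update(save_averages)
    return backup

# lows = {}
# should_backup({"a": 0.050, "b": 0.060}, lows)  -> True   (initial values recorded)
# should_backup({"a": 0.048, "b": 0.061}, lows)  -> False  (side "b" did not drop)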
class IO(): """ Model saving and loading functions. Handles the loading and saving of the plugin model from disk as well as the model backup and snapshot functions. Parameters ---------- plugin: :class:`Model` The parent plugin class that owns the IO functions. model_dir: str The full path to the model save location is_predict: bool ``True`` if the model is being loaded for inference. ``False`` if the model is being loaded for training. """ def __init__(self, plugin: "ModelBase", model_dir: str, is_predict: bool) -> None: self._plugin = plugin self._is_predict = is_predict self._model_dir = model_dir self._history: List[List[float]] = [[], []] # Loss histories per save iteration self._backup = Backup(self._model_dir, self._plugin.name) @property def _filename(self) -> str: """str: The filename for this model.""" return os.path.join(self._model_dir, f"{self._plugin.name}.h5") @property def model_exists(self) -> bool: """ bool: ``True`` if a model of the type being loaded exists within the model folder location otherwise ``False``. """ return os.path.isfile(self._filename) @property def history(self) -> List[List[float]]: """ list: list of loss histories per side for the current save iteration. """ return self._history @property def multiple_models_in_folder(self) -> Optional[List[str]]: """ :list: or ``None`` If there are multiple model types in the requested folder, or model types that don't correspond to the requested plugin type, then returns the list of plugin names that exist in the folder, otherwise returns ``None`` """ plugins = [fname.replace(".h5", "") for fname in os.listdir(self._model_dir) if fname.endswith(".h5")] test_names = plugins + [self._plugin.name] test = False if not test_names else os.path.commonprefix(test_names) == "" retval = None if not test else plugins logger.debug("plugin name: %s, plugins: %s, test result: %s, retval: %s", self._plugin.name, plugins, test, retval) return retval def _load(self) -> keras.models.Model: """ Loads the model from disk If the predict function is to be called and the model cannot be found in the model folder then an error is logged and the process exits. When loading the model, the plugin model folder is scanned for custom layers which are added to Keras' custom objects. Returns ------- :class:`keras.models.Model` The saved model loaded from disk """ logger.debug("Loading model: %s", self._filename) if self._is_predict and not self.model_exists: logger.error("Model could not be found in folder '%s'. Exiting", self._model_dir) sys.exit(1) try: model = load_model(self._filename, compile=False) except RuntimeError as err: if "unable to get link info" in str(err).lower(): msg = (f"Unable to load the model from '{self._filename}'. This may be a " "temporary error but most likely means that your model has corrupted.\n" "You can try to load the model again but if the problem persists you " "should use the Restore Tool to restore your model from backup.\n" f"Original error: {str(err)}") raise FaceswapError(msg) from err raise err except KeyError as err: if "unable to open object" in str(err).lower(): msg = (f"Unable to load the model from '{self._filename}'. 
This may be a " "temporary error but most likely means that your model has corrupted.\n" "You can try to load the model again but if the problem persists you " "should use the Restore Tool to restore your model from backup.\n" f"Original error: {str(err)}") raise FaceswapError(msg) from err raise err logger.info("Loaded model from disk: '%s'", self._filename) return model def save(self) -> None: """ Backup and save the model and state file. Notes ----- The backup function actually backups the model from the previous save iteration rather than the current save iteration. This is not a bug, but protection against long save times, as models can get quite large, so renaming the current model file rather than copying it can save substantial amount of time. """ logger.debug("Backing up and saving models") print("") # Insert a new line to avoid spamming the same row as loss output save_averages = self._get_save_averages() if save_averages and self._should_backup(save_averages): self._backup.backup_model(self._filename) # pylint:disable=protected-access self._backup.backup_model(self._plugin.state._filename) self._plugin.model.save(self._filename, include_optimizer=False) self._plugin.state.save() msg = "[Saved models]" if save_averages: lossmsg = [f"face_{side}: {avg:.5f}" for side, avg in zip(("a", "b"), save_averages)] msg += f" - Average loss since last save: {', '.join(lossmsg)}" logger.info(msg) def _get_save_averages(self) -> List[float]: """ Return the average loss since the last save iteration and reset historical loss """ logger.debug("Getting save averages") if not all(loss for loss in self._history): logger.debug("No loss in history") retval = [] else: retval = [sum(loss) / len(loss) for loss in self._history] self._history = [[], []] # Reset historical loss logger.debug("Average losses since last save: %s", retval) return retval def _should_backup(self, save_averages: List[float]) -> bool: """ Check whether the loss averages for this save iteration is the lowest that has been seen. This protects against model corruption by only backing up the model if both sides have seen a total fall in loss. Notes ----- This is by no means a perfect system. If the model corrupts at an iteration close to a save iteration, then the averages may still be pushed lower than a previous save average, resulting in backing up a corrupted model. Parameters ---------- save_averages: list The average loss for each side for this save iteration """ backup = True for side, loss in zip(("a", "b"), save_averages): if not self._plugin.state.lowest_avg_loss.get(side, None): logger.debug("Set initial save iteration loss average for '%s': %s", side, loss) self._plugin.state.lowest_avg_loss[side] = loss continue backup = loss < self._plugin.state.lowest_avg_loss[side] if backup else backup if backup: # Update lowest loss values to the state file # pylint:disable=unnecessary-comprehension old_avgs = {key: val for key, val in self._plugin.state.lowest_avg_loss.items()} self._plugin.state.lowest_avg_loss["a"] = save_averages[0] self._plugin.state.lowest_avg_loss["b"] = save_averages[1] logger.debug("Updated lowest historical save iteration averages from: %s to: %s", old_avgs, self._plugin.state.lowest_avg_loss) logger.debug("Should backup: %s", backup) return backup def snapshot(self) -> None: """ Perform a model snapshot. Notes ----- Snapshot function is called 1 iteration after the model was saved, so that it is built from the latest save, hence iteration being reduced by 1. """ logger.debug("Performing snapshot. 
Iterations: %s", self._plugin.iterations)
        self._backup.snapshot_models(self._plugin.iterations - 1)
        logger.debug("Performed snapshot")
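# Illustrative sketch of the per-side loss history behind ``_get_save_averages``
# above: losses are appended for sides "a" and "b" each iteration, averaged at
# save time, and the history is reset for the next save interval.  This is a
# simplified stand-alone version, not the class's own code.
from typing import List

_history: List[List[float]] = [[], []]  # [side "a" losses, side "b" losses]


def record(loss_a: float, loss_b: float) -> None:
    """ Append one iteration's loss value for each side. """
    _history[0].append(loss_a)
    _history[1].append(loss_b)


def save_averages() -> List[float]:
    """ Return per-side averages since the last save, then reset the history. """
    global _history
    if not all(_history):  # either side empty: nothing to average
        return []
    averages = [sum(side) / len(side) for side in _history]
    _history = [[], []]
    return averages

# record(0.05, 0.06); record(0.04, 0.05)
# save_averages()  -> [0.045, 0.055]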