Exemplo n.º 1
0
    def __init__(self, saved_model: keras.models.Model, switch_sides: bool) -> None:
        """ Store the saved model's config and build an inference model from it.

        Parameters
        ----------
        saved_model: :class:`keras.models.Model`
            The saved model whose config is read and which is passed on to
            ``_make_inference_model``
        switch_sides: bool
            When truthy, the input index becomes ``1`` and the output index ``0``; otherwise the
            input index is ``0`` and the output index ``1``
        """
        cls_name = self.__class__.__name__
        logger.debug("Initializing: %s (saved_model: %s, switch_sides: %s)",
                     cls_name, saved_model, switch_sides)
        model_config = saved_model.get_config()
        self._config = model_config

        # Input/output indices are mirrored when the sides are switched
        self._input_idx, self._output_idx = (1, 0) if switch_sides else (0, 1)

        # First element of every "input_layers" entry (presumably the input layer's name)
        self._input_names = [layer_spec[0] for layer_spec in model_config["input_layers"]]

        # Built last: the helper may rely on the attributes assigned above
        self._model = self._make_inference_model(saved_model)
        logger.debug("Initialized: %s", cls_name)
Exemplo n.º 2
0
    def check_model_precision(self,
                              model: keras.models.Model,
                              state: "State") -> keras.models.Model:
        """ Check the model's precision and rewrite it to match the requested precision.

        Rewrite an existing model's training precision mode from mixed-float16 to float32 or
        vice versa.

        This is not easy to do in keras, so we edit the model's config to change the dtype policy
        for compatible layers. Create a new model from this config, then port the weights from the
        old model to the new model.

        The original model is returned unchanged when:

            * the backend is ``"amd"`` (mixed precision is not supported there)
            * mixed precision is requested but ``state`` holds no stored list of compatible
              layers (older model files). A warning is logged in this case.

        When full precision is requested and ``state`` does not yet hold the compatible layer
        names, they are collected from the model's config and stored in ``state`` (via
        ``state.add_mixed_precision_layers``) before the rewrite.

        Parameters
        ----------
        model: :class:`keras.models.Model`
            The original saved keras model to rewrite the dtype
        state: ~:class:`plugins.train.model._base.model.State`
            The State information for the model

        Returns
        -------
        :class:`keras.models.Model`
            Either the original model (see the early-return cases above) or a newly built model
            with the datatype updated and the original model's weights loaded
        """
        if get_backend() == "amd":  # Mixed precision not supported on amd
            return model

        if self.use_mixed_precision and not state.mixed_precision_layers:
            # Switching to mixed precision on a model which was started in FP32 prior to the
            # ability to switch between precisions on a saved model is not supported as we
            # do not have the compatible layer names
            logger.warning("Switching from Full Precision to Mixed Precision is not supported on "
                           "older model files. Reverting to Full Precision.")
            return model

        config = model.get_config()

        if not self.use_mixed_precision and not state.mixed_precision_layers:
            # Switched to Full Precision, get compatible layers from model if not already stored
            state.add_mixed_precision_layers(self._get_mixed_precision_layers(config["layers"]))

        # Return value is unused, so this presumably edits config["layers"] in place
        self._switch_precision(config["layers"], state.mixed_precision_layers)

        # ``from_config`` is a classmethod: call it on the class rather than on a throwaway,
        # empty ``Model()`` instance
        new_model = keras.models.Model.from_config(config)
        new_model.set_weights(model.get_weights())
        logger.info("Mixed precision has been updated from '%s' to '%s'",
                    not self.use_mixed_precision, self.use_mixed_precision)
        # Only drops this function's local reference; any caller-held reference is unaffected
        del model
        return new_model