Example #1
0
    def _validate_encoder_architecture(self):
        """ Validate that the requested architecture is a valid choice for the running system
        configuration.

        The architecture is read from ``self.config["enc_architecture"]`` (case-insensitive) and
        looked up in ``_MODEL_MAPPING``.

        Raises
        ------
        FaceswapError
            If the requested architecture is not a recognised choice, if it is flagged as not
            supporting the AMD backend while that backend is active, or if (on a non-AMD backend)
            the installed Tensorflow version is lower than the architecture's ``tf_min``
            (default 2.0).
        """
        arch = self.config["enc_architecture"].lower()
        model = _MODEL_MAPPING.get(arch)
        if not model:
            raise FaceswapError(
                f"'{arch}' is not a valid choice for encoder architecture. Choose "
                f"one of {list(_MODEL_MAPPING.keys())}.")

        is_amd = get_backend() == "amd"
        if is_amd and model.get("no_amd"):
            valid = [k for k, v in _MODEL_MAPPING.items() if not v.get("no_amd")]
            raise FaceswapError(
                f"'{arch}' is not compatible with the AMD backend. Choose one of "
                f"{valid}.")

        tf_ver = get_tf_version()
        tf_min = model.get("tf_min", 2.0)
        # The minimum Tensorflow version requirement is only enforced for non-AMD backends
        if not is_amd and tf_ver < tf_min:
            raise FaceswapError(
                f"'{arch}' is not compatible with your version of Tensorflow. The "
                f"minimum version required is {tf_min} whilst you have version "
                f"{tf_ver} installed.")
Example #2
0
    def _test_for_tf_version(self):
        """ Check that the required Tensorflow version is installed.

        For the AMD backend exactly Tensorflow 2.2 is required. For every other backend the
        installed version must lie between 2.4 and 2.8 (inclusive). Any problem found is passed
        to :func:`_handle_import_error` as a human readable message.

        Raises
        ------
        FaceswapError
            If Tensorflow is not found, or is not a supported version for the active backend
        """
        amd_ver = 2.2
        min_ver = 2.4
        max_ver = 2.8
        try:
            # Ensure tensorflow doesn't pin all threads to one core when using Math Kernel Library
            os.environ["TF_MIN_GPU_MULTIPROCESSOR_COUNT"] = "4"
            os.environ["KMP_AFFINITY"] = "disabled"
            import tensorflow as tf  # noqa pylint:disable=import-outside-toplevel,unused-import
        except ImportError as err:
            if "DLL load failed while importing" in str(err):
                msg = (
                    "A DLL library file failed to load. Make sure that you have Microsoft Visual "
                    "C++ Redistributable (2015, 2017, 2019) installed for your machine from: "
                    "https://support.microsoft.com/en-gb/help/2977003. Original error: "
                    f"{err}")
            else:
                msg = (
                    "There was an error importing Tensorflow. This is most likely because you do "
                    "not have TensorFlow installed, or you are trying to run tensorflow-gpu on a "
                    "system without an Nvidia graphics card. Original import "
                    f"error: {err}")
            self._handle_import_error(msg)

        tf_ver = get_tf_version()
        backend = get_backend()
        if backend == "amd":
            # The AMD backend is pinned to a single Tensorflow release
            if tf_ver != amd_ver:
                msg = (
                    f"The supported Tensorflow version for AMD cards is {amd_ver} but you have "
                    f"version {tf_ver} installed. Please install the correct version."
                )
                self._handle_import_error(msg)
        else:
            if tf_ver < min_ver:
                msg = (
                    f"The minimum supported Tensorflow is version {min_ver} but you have version "
                    f"{tf_ver} installed. Please upgrade Tensorflow.")
                self._handle_import_error(msg)
            if tf_ver > max_ver:
                msg = (
                    f"The maximum supported Tensorflow is version {max_ver} but you have version "
                    f"{tf_ver} installed. Please downgrade Tensorflow.")
                self._handle_import_error(msg)
        logger.debug("Installed Tensorflow Version: %s", tf_ver)
Example #3
0
    def _log_tensorboard(self, loss):
        """ Log current loss to Tensorboard log files

        Does nothing when Tensorboard logging is not enabled (``self._tensorboard`` is falsy).

        Parameters
        ----------
        loss: list
            The list of loss ``floats`` output from the model. Values are paired, in order, with
            ``self._model.state.loss_names``; ``zip`` stops at the shorter of the two sequences
        """
        if not self._tensorboard:
            return
        logger.trace("Updating TensorBoard log")
        logs = dict(zip(self._model.state.loss_names, loss))

        self._tensorboard.on_train_batch_end(self._model.iterations, logs=logs)
        if get_tf_version() == 2.8:
            # Bug in TF 2.8 where batch recording got deleted, so the per-batch scalars are
            # written manually against the model's internal train counter.
            # ref: https://github.com/keras-team/keras/issues/16173
            for name, value in logs.items():
                tf.summary.scalar("batch_" + name,
                                  value,
                                  step=self._model._model._train_counter)  # pylint:disable=protected-access
Example #4
0
import tensorflow as tf

from lib.model import losses, optimizers
from lib.utils import get_backend, get_tf_version

# Select the Keras implementation for the active backend: the AMD backend imports the standalone
# ``keras`` package, every other backend uses the Keras bundled with Tensorflow. Both branches
# bind the same names (``keras``, ``k_losses``, ``K``) so the rest of the module does not need to
# know which backend is in use.
if get_backend() == "amd":
    import keras
    from keras import losses as k_losses
    from keras import backend as K
else:
    # Ignore linting errors from Tensorflow's thoroughly broken import system
    from tensorflow import keras
    from tensorflow.keras import losses as k_losses  # pylint:disable=import-error
    from tensorflow.keras import backend as K  # pylint:disable=import-error

# Tensorflow versions below 2.4 are imported from the ``experimental`` mixed precision namespace,
# later versions from ``tensorflow.keras.mixed_precision``; either way the module is bound to
# ``mixedprecision``.
# NOTE(review): this choice does not consult the backend, so the tensorflow.keras path is imported
# even when the AMD backend (standalone keras) is active — confirm that is intended.
if get_tf_version() < 2.4:
    import tensorflow.keras.mixed_precision.experimental as mixedprecision  # noqa pylint:disable=import-error,no-name-in-module
else:
    import tensorflow.keras.mixed_precision as mixedprecision  # noqa pylint:disable=import-error,no-name-in-module

# Imports needed only for type annotations; not executed at runtime
if TYPE_CHECKING:
    from argparse import Namespace

# Module level logger
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@dataclass
class LossClass:
    """ Typing class for holding loss functions.

    Parameters