예제 #1
0
def _find_lower_precision_dtype(model):
    """Return the first float16/bfloat16 compute dtype found in ``model``, else None.

    Layers are visited in order; a layer that is itself a ``tf.keras.Model``
    is searched recursively, so the dtype reported is the one of the
    lower-precision layer and not the enclosing sub-model's own compute dtype.
    """
    for layer in model.layers:
        if layer.compute_dtype in [tf.float16, tf.bfloat16]:
            return layer.compute_dtype
        if isinstance(layer, tf.keras.Model):
            nested_dtype = _find_lower_precision_dtype(layer)
            if nested_dtype is not None:
                return nested_dtype
    return None


def lower_precision_dtype(model):
    """Return the lower-precision compute dtype used inside ``model``.

    Args:
        model (tf.keras.Model): A model instance.

    Returns:
        The compute dtype of the first layer (searched depth-first through
        nested models) whose compute dtype is float16 or bfloat16. When no
        such layer exists, or TensorFlow is older than 2.4.0, ``model.dtype``.
    """
    if version(tf.version.VERSION) >= version("2.4.0"):
        found = _find_lower_precision_dtype(model)
        if found is not None:
            return found
    return model.dtype  # pragma: no cover
예제 #2
0
def is_mixed_precision(model) -> bool:
    """Check whether the model has any lower precision variable or not.

    A layer counts as lower precision when its compute dtype is float16 or
    bfloat16 (the same set of dtypes ``lower_precision_dtype`` looks for).
    Layers that are themselves ``tf.keras.Model`` instances are searched
    recursively. On TensorFlow older than 2.4.0 this always returns False.

    Args:
        model (tf.keras.Model): A model instance.

    Returns:
        bool: When the model has any lower precision variable, True.
    """
    if version(tf.version.VERSION) < version("2.4.0"):
        # Layer compute dtypes are only inspected from TF 2.4.0 onwards.
        return False
    return any(
        layer.compute_dtype in [tf.float16, tf.bfloat16]
        or (isinstance(layer, tf.keras.Model) and is_mixed_precision(layer))
        for layer in model.layers)
예제 #3
0
def num_of_gpus() -> Tuple[int, int]:
    """Count the GPUs visible to TensorFlow.

    For TensorFlow older than 2.1.0 the device-listing functions are taken
    from ``tf.config.experimental``; otherwise from ``tf.config``.

    Returns:
        Tuple[int, int]: ``(physical, logical)`` GPU counts, or ``(0, 0)``
        when no physical GPU is reported.
    """
    if version(tf.version.VERSION) < version("2.1.0"):
        device_api = tf.config.experimental
    else:
        device_api = tf.config

    gpus = device_api.list_physical_devices('GPU')
    if not gpus:
        return 0, 0
    return len(gpus), len(device_api.list_logical_devices('GPU'))
예제 #4
0
    def test_num_of_gpus(self, monkeypatch):
        """num_of_gpus reports 2 physical and 4 logical GPUs from stubbed listings."""
        def fake_physical(name):
            return ['dummy-a', 'dummy-b']

        def fake_logical(name):
            return ['a1', 'a2', 'b1', 'b2']

        # Patch whichever namespace num_of_gpus reads for this TF version.
        if version(tf.version.VERSION) < version("2.1.0"):
            target = tf.config.experimental
        else:
            target = tf.config
        monkeypatch.setattr(target, "list_physical_devices", fake_physical)
        monkeypatch.setattr(target, "list_logical_devices", fake_logical)

        physical, logical = num_of_gpus()
        assert physical == 2
        assert logical == 4
예제 #5
0
    def test_num_of_gpus_if_no_gpus(self, monkeypatch):
        """num_of_gpus reports (0, 0) when the device listings return None."""
        def fake_physical(name):
            return None

        def fake_logical(name):
            return None

        # Patch whichever namespace num_of_gpus reads for this TF version.
        if version(tf.version.VERSION) < version("2.1.0"):
            target = tf.config.experimental
        else:
            target = tf.config
        monkeypatch.setattr(target, "list_physical_devices", fake_physical)
        monkeypatch.setattr(target, "list_logical_devices", fake_logical)

        physical, logical = num_of_gpus()
        assert physical == 0
        assert logical == 0
예제 #6
0
from abc import ABC, abstractmethod
from typing import Union

import tensorflow as tf
from packaging.version import parse as version

if version(tf.version.VERSION) < version("2.6.0rc0"):
    from tensorflow.python.keras.layers.convolutional import Conv
else:
    from keras.layers.convolutional import Conv

from . import find_layer


class ModelModifier(ABC):
    """Base class for objects that modify a model before gradient descent.

    Concrete modifiers must override ``__call__``.
    """
    @abstractmethod
    def __call__(self, model) -> Union[None, tf.keras.Model]:
        """Apply a modification to ``model`` before gradient descent is processed.

        Args:
            model: A model instance.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.

        Returns:
            The modified model, or None.
        """
        raise NotImplementedError
예제 #7
0
import numpy as np
import pytest
import tensorflow as tf
from packaging.version import parse as version
from tensorflow.keras.models import load_model

from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.gradcam_plus_plus import GradcamPlusPlus
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.scorecam import Scorecam
from tf_keras_vis.utils.scores import BinaryScore, CategoricalScore
from tf_keras_vis.utils.test import (NO_ERROR, assert_raises, dummy_sample, mock_conv_model,
                                     mock_conv_model_with_float32_output, mock_multiple_io_model,
                                     score_with_list, score_with_tensor, score_with_tuple)

if version(tf.version.VERSION) >= version("2.4.0"):
    from tensorflow.keras.mixed_precision import set_global_policy


@pytest.fixture(scope='function', params=[Gradcam, GradcamPlusPlus, Scorecam])
def xcam(request):
    """Expose each CAM class in turn through the module-global ``Xcam``.

    The global is set to ``None`` again after the test that used it.
    """
    global Xcam
    Xcam = request.param  # class chosen by this parametrization
    yield
    Xcam = None  # teardown: clear the global


@pytest.fixture(scope='function')
def saliency():
    """Point the module-global ``Xcam`` at ``Saliency``."""
    # NOTE(review): unlike ``xcam`` this fixture never resets ``Xcam``; the
    # snippet may have been truncated — confirm against the full file.
    global Xcam
    Xcam = Saliency
예제 #8
0
def _run_install_step(args, failure_message):
    """Run one install/update command and abort the setup if it fails.

    The command is executed without a shell so that version specs such as
    ``"pip>=20.2.4"`` reach the tool unchanged (in a shell command line,
    ``>`` would be parsed as an output redirection).

    Args:
        args (list): Executable followed by its arguments.
        failure_message (str): Message for the ``SystemError`` raised on failure.

    Raises:
        SystemError: If the command cannot be started or exits with a
            non-zero status.
    """
    try:
        subprocess.run(args, check=True)
    except (OSError, subprocess.CalledProcessError) as e:
        print(e)
        raise SystemError(failure_message) from e


def _reported_version(args):
    """Return the second whitespace-separated token of a command's stdout.

    Used for ``<tool> -V`` commands whose output starts with the tool name
    followed by its version.
    """
    output = subprocess.check_output(args).decode(errors="replace")
    return output.split()[1]


def check_update_dependencies():
    """Check and, where needed, install or update the setup's dependencies.

    Steps:
      * install ``packaging`` into the base environment when ``MC list``
        does not mention it;
      * update conda when it is older than 4.10.3;
      * install python>=3.8 into base when python is older than 3.7.11;
      * reinstall pip (>=20.2.4) and update setuptools/wheel when pip is
        older than 20.2.4;
      * install git when ``git --version`` fails.

    ``MC`` is defined elsewhere in this module (presumably the conda
    executable — confirm).

    Returns:
        int: 0 once every check has passed.

    Raises:
        SystemError: If any install/update command cannot be run or fails.
    """
    # Check if Git and conda are installed
    print("\n\nINFO: Checking for dependencies:")
    print("N" * 79)

    conda_packages = subprocess.check_output(
        [f"{MC}", "list", "--json"]).decode(errors="replace")
    if "packaging" not in conda_packages:
        # In case of miniconda install packaging
        print("\n\n--->Installing packaging")
        _run_install_step(
            [f"{MC}", "install", "packaging", "-q", "-y", "-n", "base",
             "-c", "defaults"],
            "Could not install packaging, aborting...")

    conda_version = _reported_version([f"{MC}", "-V"])
    python_version = _reported_version(["python", "-V"])
    pip_version = _reported_version(["pip", "-V"])

    if version(conda_version) < version("4.10.3"):
        print("\n\n--->Updating base conda")
        _run_install_step(
            [f"{MC}", "update", "-q", "-y", "-n", "base", "-c", "defaults",
             "conda"],
            "Could not update conda, aborting install...")
        print("\n--->conda update... OK")

    if version(python_version) < version("3.7.11"):
        print("\n\n--->Updating base environment python")
        # NOTE(review): `install` instead of `update` because `conda update`
        # is believed to reject version-constrained specs like "python>=3.8";
        # confirm against the conda version in use.
        _run_install_step(
            [f"{MC}", "install", "-q", "-y", "-n", "base", "-c", "defaults",
             "python>=3.8"],
            "Could not update python, aborting install...")
        print("\n--->python update... OK")

    if version(pip_version) < version("20.2.4"):
        print("\n\n--->Reinstalling pip, setuptools, wheel...")
        pip_failure = (
            "Could not reinstall pip, setuptools, wheel aborting install...")
        _run_install_step(
            [f"{MC}", "install", "-q", "-y", "-n", "base", "-c", "defaults",
             "pip>=20.2.4", "--force-reinstall"],
            pip_failure)
        _run_install_step(
            [f"{MC}", "update", "-q", "-y", "-n", "base", "-c", "defaults",
             "setuptools", "wheel"],
            pip_failure)
        print("\n--->pip, setuptools, wheel upgrade... OK")

    try:
        subprocess.check_output(["git", "--version"])
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError covers "git executable not found".
        print(e, "\ngit not found trying to install...")
        print("\n\n--->Installing git")
        _run_install_step([f"{MC}", "install", "-q", "-y", "git"],
                          "Could not install git, aborting install...")
        print("\n\n--->git... OK")

    # try:
    #     print("\n\n--->Updating remaning base packages...")
    #     os.system(f"{MC} update -q -y -n base -c defaults --all")
    #     print("\n--->Update of remaining packages... OK")
    # except BaseException as e:
    #     print(e)
    #     print("Could not update remaining packages, trying to continue install...")

    print("N" * 79)
    print("All dependencies OK.")
    return 0
예제 #9
0
def mixed_precision(request):
    """Apply the dtype policy ``request.param`` for the duration of a test.

    On TensorFlow 2.4.0 or newer, the global Keras policy is set to
    ``request.param`` before the test and back to ``"float32"`` afterwards;
    on older versions the fixture does nothing.
    """
    policy_supported = version(tf.version.VERSION) >= version("2.4.0")
    if policy_supported:
        tf.keras.mixed_precision.set_global_policy(request.param)
    yield
    if policy_supported:
        tf.keras.mixed_precision.set_global_policy("float32")
예제 #10
0
def _get_supported_policies():
    """Return the Keras dtype policy names to test on this TensorFlow version.

    ``"mixed_float16"`` is included only on TensorFlow 2.4.0 or newer.
    """
    policies = ["float32"]
    if version(tf.version.VERSION) >= version("2.4.0"):
        policies.append("mixed_float16")
    return policies
def test_library_versions():
    """Each runtime dependency meets its minimum supported version."""
    minimum_versions = [
        (np, "1.15"),
        (scipy, "1.5"),
        (dask, "2.20"),
        (snakemake, "5.28"),
        (numba, "0.50"),
    ]
    for module, minimum in minimum_versions:
        assert version(module.__version__) >= version(minimum), module.__name__
예제 #12
0
import pytest
import tensorflow as tf
from packaging.version import parse as version
from tensorflow.keras.models import load_model

from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils.scores import BinaryScore, CategoricalScore
from tf_keras_vis.utils.test import (dummy_sample, mock_conv_model,
                                     mock_conv_model_with_float32_output,
                                     mock_multiple_io_model)

if version(tf.version.VERSION) >= version("2.4.0"):
    from tensorflow.keras.mixed_precision import set_global_policy


class TestSaliency():
    @pytest.mark.parametrize("keepdims,expected", [
        (False, (1, 8, 8)),
        (True, (1, 8, 8, 3)),
    ])
    @pytest.mark.usefixtures("mixed_precision")
    def test__call__if_keepdims_is_(self, keepdims, expected, conv_model):
        """The saliency map keeps the trailing channel axis only with keepdims=True."""
        explainer = Saliency(conv_model)
        saliency_map = explainer(CategoricalScore(0),
                                 dummy_sample((1, 8, 8, 3)),
                                 keepdims=keepdims)
        assert saliency_map.shape == expected

    @pytest.mark.parametrize("smooth_samples", [1, 3, 100])
    @pytest.mark.usefixtures("mixed_precision")
    def test__call__if_smoothing_is_active(self, smooth_samples, conv_model):
예제 #13
0
 def __init__(self, *args, **kwargs):
     """Optionally register the filter event, then run the parent initializer.

     ``init_filter_event(self.property_name)`` is called only when the
     installed SQLAlchemy version is at least ``self.max_sqla_version``.
     """
     installed = version(sqlalchemy.__version__)
     if installed >= version(self.max_sqla_version):
         init_filter_event(self.property_name)
     super(PublicQuery, self).__init__(*args, **kwargs)