예제 #1
0
 def __init__(self, padding=None):
     """Initialise the base class, then store a validated padding setting.

     Args:
         padding: Requested padding mode, or ``None``. It is passed to
             ``validate_parameter`` together with ``ALLOWED_PADDINGS`` and the
             fallback default ``"z"``; whatever that helper returns becomes the
             ``padding`` attribute.
     """
     super(PaddingPropertyHolder, self).__init__()
     # Validation is delegated entirely to the shared helper so every holder
     # applies the same rules for allowed modes and defaults.
     checked_padding = validate_parameter(padding, ALLOWED_PADDINGS, "z")
     self.padding = checked_padding
예제 #2
0
 def __init__(self, interpolation=None):
     """Initialise the base class, then store a validated interpolation setting.

     Args:
         interpolation: Requested interpolation mode, or ``None``. It is passed
             to ``validate_parameter`` together with ``ALLOWED_INTERPOLATIONS``
             and the fallback default ``"bilinear"``; the helper's result
             becomes the ``interpolation`` attribute.
     """
     super(InterpolationPropertyHolder, self).__init__()
     allowed_modes = ALLOWED_INTERPOLATIONS
     fallback_mode = "bilinear"
     self.interpolation = validate_parameter(interpolation, allowed_modes,
                                             fallback_mode)
예제 #3
0
def test_parameter_validation_raises_error_when_default_value_is_wrong_type():
    """A ``str`` default must be rejected with ``TypeError`` when ``int`` is required.

    The parameter is ``None`` and the default's first element ``"10"`` is a
    string, while ``int`` is passed as the fourth (type) argument.
    """
    string_default = ("10", "inherit")
    with pytest.raises(TypeError):
        slu.validate_parameter(None, {1, 2}, string_default, int)
예제 #4
0
def test_validate_parameter_raises_value_errors(parameter):
    """Each supplied ``parameter`` must be rejected with ``ValueError``.

    ``parameter`` is injected by pytest (presumably through a
    ``parametrize`` decorator or fixture that is not shown here). The call uses
    allowed modes ``{1, 2}``, default ``1`` and ``basic_type=int``.
    """
    permitted = {1, 2}
    with pytest.raises(ValueError):
        slu.validate_parameter(parameter, permitted, 1, basic_type=int)
예제 #5
0
def test_parameter_validation_raises_error_when_default_type_is_wrong():
    """An invalid default ``(10, "12345")`` must be rejected with ``ValueError``.

    NOTE(review): despite the test name, the default's first element ``10``
    *is* an ``int`` (matching the type argument). The ``ValueError``
    presumably comes from ``10`` not being in ``{1, 2}`` and/or from the
    unrecognised mode string ``"12345"``. Confirm which case this test is
    meant to pin.
    """
    invalid_default = (10, "12345")
    with pytest.raises(ValueError):
        slu.validate_parameter(None, {1, 2}, invalid_default, int)
예제 #6
0
def test_parameter_validation_raises_error_when_types_dont_match():
    """Passing a ``set`` as the parameter must raise ``NotImplementedError``.

    NOTE(review): elsewhere ``validate_parameter`` is called as
    ``(parameter, allowed_modes, default, basic_type)``. Here only three
    positional arguments are given, so ``10`` lands in the allowed-modes slot
    and ``int`` in the default-value slot. Confirm this is intended.
    """
    set_parameter = {1, 2}
    with pytest.raises(NotImplementedError):
        slu.validate_parameter(set_parameter, 10, int)
예제 #7
0
    def __init__(self, data, fmt, transform_settings=None):
        """Validate and store data items together with their format codes.

        Args:
            data: The items to wrap. Must be a tuple, except when ``fmt`` has
                length 1: then a single non-tuple, non-list object is accepted
                and wrapped into a 1-tuple.
            fmt: Sequence of per-item type codes. Its length must equal
                ``len(data)`` and every code must be in ``ALLOWED_TYPES``.
            transform_settings: Optional dict mapping an item index to that
                item's settings dict. Only items with code ``"I"`` or ``"M"``
                may carry ``"interpolation"`` / ``"padding"`` entries; those
                are (re)validated and filled in with defaults. The caller's
                dict and its nested dicts are copied, never modified.

        Raises:
            TypeError: If ``data`` has the wrong container type, if
                ``transform_settings`` is not a dict, if an item whose code is
                not ``"I"``/``"M"`` carries interpolation or padding settings,
                or if a code in ``fmt`` is not in ``ALLOWED_TYPES``.
            ValueError: If ``data`` and ``fmt`` differ in length, or if
                ``transform_settings`` has keys other than the item indices.
            Whatever ``validate_parameter`` raises for invalid interpolation or
            padding values is propagated unchanged.
        """
        if len(fmt) == 1 and not isinstance(data, tuple):
            # A single item may be passed bare; a list is rejected rather than
            # being interpreted as a sequence of items.
            if isinstance(data, list):
                raise TypeError
            data = (data, )

        if not isinstance(data, tuple):
            raise TypeError

        if len(data) != len(fmt):
            raise ValueError

        if transform_settings is None:
            transform_settings = {}
        elif not isinstance(transform_settings, dict):
            raise TypeError
        else:
            # Work on copies of the outer dict and each per-item dict. The loop
            # below overwrites entries with validated values; doing that in
            # place would leak into the caller's dict. Reusing that dict for
            # another container would then wrap the already-validated value
            # in a second ``(value, "strict")`` tuple.
            transform_settings = {
                key: dict(value) if isinstance(value, dict) else value
                for key, value in transform_settings.items()
            }

        # Element-wise settings: items without an entry get an empty dict,
        # then "I"/"M" items get validated interpolation and padding values.
        for idx, code in enumerate(fmt):
            settings = transform_settings.setdefault(idx, {})
            if code in ("I", "M"):
                if "interpolation" in settings:
                    # Explicitly supplied values are always marked "strict".
                    interpolation = (settings["interpolation"], "strict")
                elif code == "M":
                    interpolation = ("nearest", "strict")
                else:
                    # None lets validate_parameter fall back to "bilinear".
                    interpolation = None
                settings["interpolation"] = validate_parameter(
                    interpolation, ALLOWED_INTERPOLATIONS, "bilinear", str,
                    True)

                padding = ((settings["padding"], "strict")
                           if "padding" in settings else None)
                settings["padding"] = validate_parameter(
                    padding, ALLOWED_PADDINGS, "z", str, True)
            elif "interpolation" in settings or "padding" in settings:
                # Only "I"/"M" items accept interpolation/padding settings.
                raise TypeError

        # Every index 0..len(data)-1 is now present, so any extra key makes
        # the dict longer than the data.
        if len(data) != len(transform_settings):
            raise ValueError

        for t in fmt:
            if t not in ALLOWED_TYPES:
                raise TypeError(
                    f"The found type was {t}, but needs to be one of {ALLOWED_TYPES}"
                )

        self.__data = data
        self.__fmt = fmt
        self.__transform_settings = transform_settings
예제 #8
0
def test_parameter_validation_raises_error_when_default_value_is_wrong_type():
    """Reject a default whose first element is a ``str`` when ``int`` is required.

    With a ``None`` parameter, the default ``('10', 'inherit')`` is checked
    against the type argument ``int``; a ``TypeError`` is expected.
    """
    mistyped_default = ('10', 'inherit')
    allowed_values = {1, 2}
    with pytest.raises(TypeError):
        validate_parameter(None, allowed_values, mistyped_default, int)
예제 #9
0
파일: _data.py 프로젝트: imelekhov/solt
    def __init__(self, data, fmt, transform_settings=None):
        """Validate and store data items together with their format codes.

        Args:
            data: The items to wrap. Must be a tuple, except when ``fmt`` has
                length 1: then a single non-tuple, non-list object is accepted
                and wrapped into a 1-tuple.
            fmt: Sequence of per-item type codes. Its length must equal
                ``len(data)`` and every code must be in ``ALLOWED_TYPES``.
            transform_settings: Optional dict mapping an item index to that
                item's settings dict. Only items with code ``"I"`` or ``"M"``
                may carry ``"interpolation"`` / ``"padding"`` entries; those
                are (re)validated and filled in with defaults. The caller's
                dict and its nested dicts are copied, never modified.

        Raises:
            TypeError: If ``data`` has the wrong container type, if
                ``transform_settings`` is not a dict, if an item whose code is
                not ``"I"``/``"M"`` carries interpolation or padding settings,
                or if a code in ``fmt`` is not in ``ALLOWED_TYPES``.
            ValueError: If ``data`` and ``fmt`` differ in length, or if
                ``transform_settings`` has keys other than the item indices.
            Whatever ``validate_parameter`` raises for invalid interpolation or
            padding values is propagated unchanged.
        """
        if len(fmt) == 1 and not isinstance(data, tuple):
            # A single item may be passed bare; a list is rejected rather than
            # being interpreted as a sequence of items.
            if isinstance(data, list):
                raise TypeError
            data = (data, )

        if not isinstance(data, tuple):
            raise TypeError

        if len(data) != len(fmt):
            raise ValueError

        if transform_settings is None:
            transform_settings = {}
        elif not isinstance(transform_settings, dict):
            raise TypeError
        else:
            # Work on copies of the outer dict and each per-item dict. The loop
            # below overwrites entries with validated values; doing that in
            # place would leak into the caller's dict. Reusing that dict for
            # another container would then wrap the already-validated value
            # in a second ``(value, "strict")`` tuple.
            transform_settings = {
                key: dict(value) if isinstance(value, dict) else value
                for key, value in transform_settings.items()
            }

        # Element-wise settings: items without an entry get an empty dict,
        # then "I"/"M" items get validated interpolation and padding values.
        for idx, code in enumerate(fmt):
            settings = transform_settings.setdefault(idx, {})
            if code in ("I", "M"):
                if "interpolation" in settings:
                    # Explicitly supplied values are always marked "strict".
                    interpolation = (settings["interpolation"], "strict")
                elif code == "M":
                    interpolation = ("nearest", "strict")
                else:
                    # None lets validate_parameter fall back to "bilinear".
                    interpolation = None
                settings["interpolation"] = validate_parameter(
                    interpolation, ALLOWED_INTERPOLATIONS, "bilinear", str,
                    True)

                padding = ((settings["padding"], "strict")
                           if "padding" in settings else None)
                settings["padding"] = validate_parameter(
                    padding, ALLOWED_PADDINGS, "z", str, True)
            elif "interpolation" in settings or "padding" in settings:
                # Only "I"/"M" items accept interpolation/padding settings.
                raise TypeError

        # Every index 0..len(data)-1 is now present, so any extra key makes
        # the dict longer than the data.
        if len(data) != len(transform_settings):
            raise ValueError

        for t in fmt:
            if t not in ALLOWED_TYPES:
                raise TypeError(
                    f"The found type was {t}, but needs to be one of {ALLOWED_TYPES}"
                )

        self.__data = data
        self.__fmt = fmt
        self.__transform_settings = transform_settings

        # Per-channel mean/std values (named for ImageNet) stored as
        # (3, 1, 1) tensors. NOTE(review): the shape presumably lets them
        # broadcast over channel-first (C, H, W) images; confirm against the
        # code that uses them.
        self.__imagenet_mean = torch.tensor(
            (0.485, 0.456, 0.406)).view(3, 1, 1)
        self.__imagenet_std = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)