Example #1
0
 def test_not_in_range(self):
     # Strings listed in ``not_in_range`` must be rejected by the
     # validator, while any other string is expected to pass.
     forbidden = ['CPython', 'PyPy', 'IronPython', 'Jython', 'Cython']
     is_valid = StringValidator(not_in_range=forbidden)
     for accepted in ('Ruby', 'Java'):
         self.assertTrue(is_valid(accepted))
     self.assertFalse(is_valid('CPython'))
Example #2
0
 def test_re(self):
     # Scenario 1: the pattern only admits latin letters and digits.
     alnum_only = StringValidator(re_pattern='^[a-zA-Z0-9]+$')
     for accepted in ('pyvalid', '42'):
         self.assertTrue(alnum_only(accepted))
     self.assertFalse(alnum_only('__pyvalid__'))

     # Scenario 2: the pattern is malformed (unbalanced parenthesis);
     # the validator is expected to reject every value, including the
     # pattern text itself.
     broken = StringValidator(re_pattern=':)')
     for candidate in ('pyvalid', ':)'):
         self.assertFalse(broken(candidate))

     # Scenario 3: regular expression flags are forwarded, so matching
     # becomes case-insensitive.
     case_insensitive = StringValidator(
         re_pattern='^pyvalid$',
         re_flags=re.IGNORECASE,
     )
     for accepted in ('pyvalid', 'PyValid'):
         self.assertTrue(case_insensitive(accepted))
     self.assertFalse(case_insensitive('42'))
Example #3
0
 def test_mixed(self):
     # Combines the length bounds (6..64 characters) with a deny-list of
     # well-known weak passwords.
     weak_passwords = ['password', 'qwerty', '123456789', 'sunshine']
     is_valid = StringValidator(
         min_len=6,
         max_len=64,
         not_in_range=weak_passwords,
     )
     # Long enough, short enough and not on the deny-list.
     for accepted in ('Super_Mega_Strong_Password_2000', '_' * 6):
         self.assertTrue(is_valid(accepted))
     # Too short, too long, and deny-listed, in that order.
     for rejected in ('_' * 3, '_' * 128, 'sunshine'):
         self.assertFalse(is_valid(rejected))
Example #4
0
 def test_min_len(self):
     # The lower bound is inclusive: a string of exactly ``min_len``
     # characters passes, a shorter one does not.
     at_least_two = StringValidator(min_len=2)
     for accepted in ('Python', 'Py'):
         self.assertTrue(at_least_two(accepted))
     self.assertFalse(at_least_two('P'))
Example #5
0
 def test_max_len(self):
     # The upper bound is inclusive: the empty string and a string of
     # exactly ``max_len`` characters pass, a longer one does not.
     at_most_six = StringValidator(max_len=6)
     for accepted in ('', 'Python'):
         self.assertTrue(at_most_six(accepted))
     self.assertFalse(at_most_six('Python3'))
Example #6
0
class TensorValidator(AbstractValidator):
    """
    Validates ``torch.Tensor`` values against a configurable set of checks.

    Each check is configured through a keyword argument of the constructor:

    ``tensor_type`` (str)
        Expected result of ``tensor.type()``, e.g. ``"torch.FloatTensor"``.
    ``dim`` (int)
        Expected number of dimensions (``tensor.dim()``).
    ``empty_allowed`` (bool)
        ``False`` rejects tensors without elements; ``True`` accepts them
        and only emits a warning.
    ``nans_allowed`` (bool)
        ``False`` rejects tensors containing NaN values; ``True`` accepts
        them and only emits a warning.

    Example:

    .. code-block:: python

        X = torch.Tensor([
            [0.0500, 0.0000, 0.0688, 0.0843, 0.0000, 0.0000, 0.1896, 0.0105],
            [0.0500, 0.0000, 0.0528, 0.1810, 0.0000, 0.0000, 0.1470, 0.0000]
        ])
        validator = TensorValidator(
            tensor_type="torch.FloatTensor", dim=2,
            empty_allowed=False, nans_allowed=False
        )

        @accepts(validator)
        def example(X):
            pass

    """
    @classmethod
    def tensor_type_checker(cls, val, tensor_type):
        """Checks whether the tensor's type string equals ``tensor_type``.

        The comparison is an exact string match against ``val.type()``.

        Args:
            val (torch.Tensor):
                Tensor whose type is to be validated.
            tensor_type (str):
                Expected type of tensor.
                Ex: "torch.IntTensor", "torch.FloatTensor", "torch.ByteTensor".

        Returns (bool):
            True:
                If the type of given tensor matches the required type.
            False:
                If the type of given tensor does not match the required type.

        """
        return val.type() == tensor_type

    @classmethod
    def dimension_checker(cls, val, dim):
        """Checks if given tensor is of dimension <dim>.

        Args:
            val (torch.Tensor):
                Tensor whose dimension is to be validated.
            dim (int):
                Expected dimension of the tensor.

        Returns (bool):
            True:
                If the given tensor is of required dimension.
            False:
                If the given tensor is not of required dimension.

        """
        return val.dim() == dim

    @classmethod
    def empty_checker(cls, val, empty_allowed):
        """Checks if the tensor is empty or not.

        Args:
            val (torch.Tensor):
                Tensor whose contents needs to be validated.
            empty_allowed (bool):
                If this flag is set to ``False``, an empty tensor makes the
                check fail (``False`` is returned).
                If set to ``True``, an empty tensor only triggers a warning
                and the check passes.

        Returns (bool):
            True:
                If the tensor is not empty, or it is empty and
                ``empty_allowed`` is ``True``.
            False:
                If the tensor is empty and ``empty_allowed`` is not ``True``.

        """
        if val.nelement() != 0:
            return True
        if empty_allowed:
            # Only warn when the tensor really is empty; warning for every
            # validated tensor would be misleading noise.
            warnings.warn(
                "Tensor is empty, but does not impact the execution.")
            return True
        return False

    @classmethod
    def nan_checker(cls, val, nans_allowed):
        """Checks if the tensor has NaN values or not.

        Args:
            val (torch.Tensor):
                Tensor to be validated.
            nans_allowed (bool):
                If this flag is set to ``False``, a tensor containing NaNs
                makes the check fail (``False`` is returned).
                If set to ``True``, NaNs only trigger a warning and the
                check passes.

        Returns (bool):
            True:
                If the given tensor is free of NaNs, or it contains NaNs and
                ``nans_allowed`` is ``True``.
            False:
                If the given tensor contains NaNs and ``nans_allowed`` is
                not ``True``.

        """
        if torch.isnan(val).sum().item() == 0:
            return True
        if nans_allowed:
            # Only warn when NaNs are actually present in the tensor.
            warnings.warn(
                "Tensor contains NaN values, but does not impact the execution."
            )
            return True
        return False

    @property
    def checkers(self):
        """Mapping of checker callables to their list of extra arguments."""
        return self.__checkers

    @accepts(object,
             tensor_type=StringValidator(in_range=[
                 "torch.CharTensor", "torch.IntTensor", "torch.ShortTensor",
                 "torch.LongTensor", "torch.FloatTensor", "torch.DoubleTensor",
                 "torch.ByteTensor", "torch.BoolTensor", "torch.HalfTensor"
             ]),
             dim=int,
             empty_allowed=bool,
             nans_allowed=bool,
             empty_check=bool,
             nan_check=bool)
    def __init__(self, **kwargs):
        """Registers the checkers together with their configured arguments.

        Args:
            **kwargs:
                ``tensor_type``, ``dim``, ``empty_allowed`` and
                ``nans_allowed`` (see the class documentation). An option
                that is not supplied is stored as ``None``.
                NOTE(review): the base validator presumably skips checkers
                whose argument is ``None`` -- confirm in AbstractValidator.

                ``empty_check`` and ``nan_check`` are still accepted for
                backward compatibility, but they were never read by this
                class and remain without effect; a warning points callers
                to the options that are actually honoured.

        """
        legacy_options = (
            ('empty_check', 'empty_allowed'),
            ('nan_check', 'nans_allowed'),
        )
        for legacy_name, current_name in legacy_options:
            if legacy_name in kwargs:
                warnings.warn(
                    "TensorValidator option '%s' has no effect; "
                    "use '%s' instead." % (legacy_name, current_name)
                )
        self.__checkers = {
            TensorValidator.tensor_type_checker:
            [kwargs.get('tensor_type', None)],
            TensorValidator.dimension_checker: [kwargs.get('dim', None)],
            TensorValidator.empty_checker: [kwargs.get('empty_allowed', None)],
            TensorValidator.nan_checker: [kwargs.get('nans_allowed', None)],
        }
        AbstractValidator.__init__(self, allowed_types=torch.Tensor)