def normalize_dtype(dtype):
    """Return a ``numpy.dtype`` describing a single dtype-like input.

    Inputs for which ``ia.is_np_array`` or ``ia.is_np_scalar`` is true
    contribute their own ``.dtype`` attribute. Any other value is handed to
    ``np.dtype(...)`` for conversion.

    Parameters
    ----------
    dtype : numpy.ndarray or numpy scalar or dtype-like
        The value to normalize. Must not be a ``list``; lists are rejected
        (via ``assert``) so that several dtypes are not mistakenly treated
        as a single one.

    Returns
    -------
    numpy.dtype
        The normalized dtype.

    """
    # Reject lists explicitly: np.dtype() would otherwise try to interpret
    # a list as a structured-dtype specification.
    assert not isinstance(dtype, list), (
        "Expected a single dtype-like, got a list instead.")

    carries_own_dtype = ia.is_np_array(dtype) or ia.is_np_scalar(dtype)
    if carries_own_dtype:
        return dtype.dtype
    return np.dtype(dtype)
def _gate_dtypes(dtypes, allowed, disallowed, augmenter=None):
    """Verify that input dtypes are among allowed and not disallowed dtypes.

    Added in 0.5.0.

    Parameters
    ----------
    dtypes : numpy.ndarray or numpy scalar or numpy.dtype or iterable of numpy.ndarray or iterable of numpy.dtype
        One or more inputs whose dtypes are to be verified.

        * A single array or numpy scalar contributes its ``.dtype``.
        * A single ``numpy.dtype`` is used as-is.
        * A ``set``/``frozenset`` is assumed to already contain dtypes and
          is used as-is.
        * Any other iterable (``list``, ``tuple``, generator, ...) may mix
          arrays, numpy scalars and dtypes. Arrays and scalars are replaced
          by their ``.dtype``; all other elements are used as-is.

        Dtype entries must not be dtype functions (like ``np.int64``),
        only proper dtypes (like ``np.dtype("int64")``). For performance
        reasons this is not validated.

    allowed : set of numpy.dtype
        One or more allowed dtypes.

    disallowed : None or set of numpy.dtype
        Any number of disallowed dtypes. Should not intersect with allowed
        dtypes.

    augmenter : None or imgaug.augmenters.meta.Augmenter, optional
        If the gating happens for an augmenter, it should be provided
        here. This information will be used to improve output error
        messages and warnings.

    Raises
    ------
    ValueError
        If any input dtype is contained in `disallowed`. If several are,
        the one with the smallest ``str()`` representation is reported, so
        the message is deterministic.

    """
    if isinstance(dtypes, np.ndarray) or ia.is_np_scalar(dtypes):
        dtypes = {dtypes.dtype}
    elif isinstance(dtypes, np.dtype):
        dtypes = {dtypes}
    elif not isinstance(dtypes, (set, frozenset)):
        # Normalize any non-set iterable element-wise. Previously only lists
        # of arrays were handled, so a list of dtypes (no `.dtype` attribute)
        # or a tuple of arrays (no `-` with a set) crashed with an unrelated
        # AttributeError/TypeError.
        dtypes = {
            item.dtype
            if isinstance(item, np.ndarray) or ia.is_np_scalar(item)
            else item
            for item in dtypes
        }

    dts_not_explicitly_allowed = dtypes - allowed
    if not dts_not_explicitly_allowed:
        # Fast path: every input dtype is explicitly allowed.
        return

    if disallowed is None:
        disallowed = set()

    dts_explicitly_disallowed = dts_not_explicitly_allowed.intersection(
        disallowed
    )
    dts_undefined = dts_not_explicitly_allowed - disallowed

    if dts_explicitly_disallowed:
        # Sets have no stable iteration order; pick a deterministic
        # representative so the error message is reproducible.
        dtype = min(dts_explicitly_disallowed, key=str)
        if augmenter is None:
            raise ValueError(
                "Got dtype '%s', which is a forbidden dtype (%s)." % (
                    np.dtype(dtype).name,
                    _dtype_names_to_string(disallowed)
                ))

        raise ValueError(
            "Got dtype '%s' in augmenter '%s' (class '%s'), which "
            "is a forbidden dtype (%s)." % (
                np.dtype(dtype).name,
                augmenter.name,
                augmenter.__class__.__name__,
                _dtype_names_to_string(disallowed),
            ))

    # Dtypes that are neither allowed nor disallowed only produce warnings,
    # one per dtype, emitted in a deterministic order.
    for dtype in sorted(dts_undefined, key=str):
        if augmenter is None:
            ia.warn(
                "Got dtype '%s', which was neither explicitly allowed "
                "(%s), nor explicitly disallowed (%s). Generated "
                "outputs may contain errors."
                % (
                    dtype.name,
                    _dtype_names_to_string(allowed),
                    _dtype_names_to_string(disallowed),
                )
            )
        else:
            ia.warn(
                "Got dtype '%s' in augmenter '%s' (class '%s'), which was "
                "neither explicitly allowed (%s), nor explicitly "
                "disallowed (%s). Generated outputs may contain "
                "errors."
                % (
                    dtype.name,
                    augmenter.name,
                    augmenter.__class__.__name__,
                    _dtype_names_to_string(allowed),
                    _dtype_names_to_string(disallowed),
                )
            )