def _get_device_tags():
  """Return the set of tags describing the device under test.

  The set always contains ``device_under_test()``. It also contains "rocm"
  when ``is_device_rocm()`` is true, or otherwise "cuda" when
  ``is_device_cuda()`` is true.
  """
  # Probe the GPU backend first, in the same order as before, so that the
  # ROCm check takes precedence over the CUDA check.
  if is_device_rocm():
    backend_tags = ("rocm",)
  elif is_device_cuda():
    backend_tags = ("cuda",)
  else:
    backend_tags = ()
  return {device_under_test(), *backend_tags}
def supported_dtypes():
  """Return the set of scalar dtypes considered supported on the test device.

  - "tpu": booleans, 8/16/32-bit signed and unsigned ints, bfloat16,
    float16, float32 and complex64.
  - "iree": booleans, 8/16/32-bit signed and unsigned ints, and float32.
  - any other device: everything listed for "tpu" plus int64, uint64,
    float64 and complex128.

  When ``config.x64_enabled`` is false, the 64-bit types (int64, uint64,
  float64, complex128) are removed from the result.
  """
  # Types shared by every device configuration.
  narrow_types = {np.bool_, np.int8, np.int16, np.int32,
                  np.uint8, np.uint16, np.uint32}
  # Types that are only meaningful when 64-bit mode is on.
  wide_types = {np.uint64, np.int64, np.float64, np.complex128}

  if device_under_test() == "tpu":
    result = narrow_types | {_dtypes.bfloat16, np.float16, np.float32,
                             np.complex64}
  elif device_under_test() == "iree":
    result = narrow_types | {np.float32}
  else:
    result = (narrow_types
              | {_dtypes.bfloat16, np.float16, np.float32, np.complex64}
              | wide_types)

  if not config.x64_enabled:
    result = result - wide_types
  return result
def if_device_under_test(device_type: Union[str, Sequence[str]],
                         if_true, if_false):
  """Choose `if_true` or `if_false` depending on the device under test.

  Args:
    device_type: a single device name, or a sequence of device names.
    if_true: value returned when ``device_under_test()`` is one of the
      given device names.
    if_false: value returned otherwise.
  """
  # A bare string is treated as one name, not as a sequence of characters.
  if isinstance(device_type, str):
    candidates = (device_type,)
  else:
    candidates = device_type
  return if_true if device_under_test() in candidates else if_false