예제 #1
0
def _get_device_tags():
    """Return the set of tags that identify the device under test.

    The result always contains the value of ``device_under_test()``.
    A platform tag is added as well: "rocm" when ``is_device_rocm()`` is
    true, otherwise "cuda" when ``is_device_cuda()`` is true.
    ROCm is checked first, so it wins if both predicates hold.
    """
    if is_device_rocm():
        platform_tag = "rocm"
    elif is_device_cuda():
        platform_tag = "cuda"
    else:
        platform_tag = None
    # Queried after the platform predicates, matching the original call order.
    base_tag = device_under_test()
    if platform_tag is None:
        return {base_tag}
    return {base_tag, platform_tag}
예제 #2
0
파일: test_util.py 프로젝트: romanngg/jax
def supported_dtypes():
  """Return the set of scalar dtypes supported on the device under test.

  TPU and IREE backends get explicit, narrower sets. Every other backend
  gets the full set, including 64-bit and complex128 types. When
  ``config.x64_enabled`` is false, uint64, int64, float64 and complex128 are
  removed from whichever set was chosen.

  Returns:
    A freshly built ``set`` of scalar types (a new object on every call).
  """
  # Query the device once so both comparisons below see the same value.
  device = device_under_test()
  if device == "tpu":
    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
             np.uint32, _dtypes.bfloat16, np.float16, np.float32, np.complex64}
  elif device == "iree":
    types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
             np.uint32, np.float32}
  else:
    types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64,
             _dtypes.bfloat16, np.float16, np.float32, np.float64,
             np.complex64, np.complex128}
  if not config.x64_enabled:
    # Without x64 mode the 64-bit widths are not available.
    types -= {np.uint64, np.int64, np.float64, np.complex128}
  return types
예제 #3
0
def if_device_under_test(device_type: Union[str, Sequence[str]], if_true,
                         if_false):
    """Choose `if_true` or `if_false` based on device_under_test.

    Args:
      device_type: a single device name, or a sequence of device names,
        to compare against ``device_under_test()``.
      if_true: value returned when the current device is among `device_type`.
      if_false: value returned otherwise.

    Returns:
      `if_true` on a match, else `if_false`.
    """
    current = device_under_test()
    # A bare string is one name, not a sequence of characters to search.
    candidates = [device_type] if isinstance(device_type, str) else device_type
    return if_true if current in candidates else if_false