Exemple #1
0
    def test_type(self, out_dtype, input_data, expected_type):
        """Check the dtype produced by ``CastToType`` for both ways of passing the dtype.

        Args:
            out_dtype: dtype handed to ``CastToType``.
            input_data: data to be cast.
            expected_type: dtype the result should carry; it is translated with
                ``get_equivalent_dtype`` for the type of the returned object
                before being compared.
        """
        # Deferred calls so each cast is executed and checked before the next one runs.
        invocations = (
            # dtype fixed when the transform is constructed
            lambda: CastToType(dtype=out_dtype)(input_data),
            # dtype supplied at call time
            lambda: CastToType()(input_data, out_dtype),
        )
        for invoke in invocations:
            casted = invoke()
            self.assertEqual(casted.dtype, get_equivalent_dtype(expected_type, type(casted)))
Exemple #2
0
    def test_type_cupy(self, out_dtype, input_data, expected_type):
        """Check that ``CastToType`` keeps ``cp.ndarray`` inputs as ``cp.ndarray`` with the right dtype.

        Args:
            out_dtype: dtype handed to ``CastToType``.
            input_data: data to be cast; converted with ``cp.asarray`` first.
            expected_type: dtype the result should carry; it is translated with
                ``get_equivalent_dtype`` for the type of the returned object
                before being compared.
        """
        input_data = cp.asarray(input_data)

        # dtype fixed when the transform is constructed
        result = CastToType(dtype=out_dtype)(input_data)
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...))
        self.assertIsInstance(result, cp.ndarray)
        self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result)))

        # dtype supplied at call time
        result = CastToType()(input_data, out_dtype)
        self.assertIsInstance(result, cp.ndarray)
        self.assertEqual(result.dtype, get_equivalent_dtype(expected_type, type(result)))
 def test_native_type(self):
     """Python built-in types must come back from ``get_equivalent_dtype`` unchanged."""
     for builtin_type in (float, int, bool):
         for sample_dtype in DTYPES:
             # The second argument is the type of each DTYPES entry, not the entry itself.
             container_type = type(sample_dtype)
             converted = get_equivalent_dtype(builtin_type, container_type)
             self.assertEqual(converted, builtin_type)
 def test_from_string(self, dtype_str, expected_np):
     """Check that a dtype string resolves to the expected dtype for both numpy and torch.

     Args:
         dtype_str: string passed to the two ``get_*_dtype_from_string`` helpers.
         expected_np: dtype the numpy lookup should return; the torch expectation
             is derived from it via ``get_equivalent_dtype``.
     """
     torch_expected = get_equivalent_dtype(expected_np, torch.Tensor)
     # numpy lookup is compared with the expected dtype as given
     self.assertEqual(get_numpy_dtype_from_string(dtype_str), expected_np)
     # torch lookup is compared with the derived torch counterpart
     self.assertEqual(get_torch_dtype_from_string(dtype_str), torch_expected)
Exemple #5
0
    def __init__(
        self,
        orig_labels: Sequence,
        target_labels: Sequence,
        dtype: DtypeLike = np.float32,
    ) -> None:
        """
        Validate and store a label mapping.

        Args:
            orig_labels: original labels that map to others.
            target_labels: expected label values, mapped 1:1 onto `orig_labels`.
            dtype: dtype to convert the output data to, defaults to float32.
                It is stored after passing through `get_equivalent_dtype`
                with `data_type=np.ndarray`.

        Raises:
            ValueError: when the two sequences differ in length, or when every
                paired element compares equal (which includes two empty sequences).

        """
        lengths_match = len(orig_labels) == len(target_labels)
        if not lengths_match:
            raise ValueError("orig_labels and target_labels must have the same length.")

        # A mapping in which every label maps to itself is rejected.
        is_identity = all(src == dst for src, dst in zip(orig_labels, target_labels))
        if is_identity:
            raise ValueError("orig_labels and target_labels are exactly the same, should be different to map.")

        self.orig_labels = orig_labels
        self.target_labels = target_labels
        self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)
 def test_get_equivalent_dtype(self, im, input_dtype):
     """Check that converting ``input_dtype`` for the type of ``im`` yields ``im.dtype``.

     Args:
         im: object whose type selects the target framework and whose ``dtype``
             is the expected result.
         input_dtype: dtype passed to ``get_equivalent_dtype``.
     """
     container_type = type(im)
     converted = get_equivalent_dtype(input_dtype, container_type)
     self.assertEqual(converted, im.dtype)