Beispiel #1
0
    def test_log(self):
        """Check ``ht.log2`` against ``torch.log2`` for float and integer inputs.

        Despite the method name, this exercises the base-2 logarithm. Float
        inputs must keep their dtype, int32/int64 inputs must be promoted to
        ``ht.float64``, and non-tensor inputs must raise ``TypeError``.
        """
        elements = 15
        # float64 reference values log2(1), ..., log2(elements - 1)
        comparison = torch.arange(1, elements, dtype=torch.float64).log2()

        # logarithm of float32
        float32_tensor = ht.arange(1, elements, dtype=ht.float32)
        float32_log2 = ht.log2(float32_tensor)
        self.assertIsInstance(float32_log2, ht.tensor)
        self.assertEqual(float32_log2.dtype, ht.float32)
        # abs() is required: without it, results far *below* the reference
        # would produce a negative difference and silently pass
        in_range = (
            float32_log2._tensor__array - comparison.type(torch.float32)
        ).abs() < FLOAT_EPSILON
        self.assertTrue(in_range.all())

        # logarithm of float64
        float64_tensor = ht.arange(1, elements, dtype=ht.float64)
        float64_log2 = ht.log2(float64_tensor)
        self.assertIsInstance(float64_log2, ht.tensor)
        self.assertEqual(float64_log2.dtype, ht.float64)
        in_range = (float64_log2._tensor__array - comparison).abs() < FLOAT_EPSILON
        self.assertTrue(in_range.all())

        # logarithm of ints, automatic conversion to intermediate floats
        int32_tensor = ht.arange(1, elements, dtype=ht.int32)
        int32_log2 = ht.log2(int32_tensor)
        self.assertIsInstance(int32_log2, ht.tensor)
        self.assertEqual(int32_log2.dtype, ht.float64)
        in_range = (int32_log2._tensor__array - comparison).abs() < FLOAT_EPSILON
        self.assertTrue(in_range.all())

        # logarithm of longs, automatic conversion to intermediate floats
        int64_tensor = ht.arange(1, elements, dtype=ht.int64)
        int64_log2 = ht.log2(int64_tensor)
        self.assertIsInstance(int64_log2, ht.tensor)
        self.assertEqual(int64_log2.dtype, ht.float64)
        in_range = (int64_log2._tensor__array - comparison).abs() < FLOAT_EPSILON
        self.assertTrue(in_range.all())

        # check exceptions: plain Python objects are not accepted
        with self.assertRaises(TypeError):
            ht.log2([1, 2, 3])
        with self.assertRaises(TypeError):
            ht.log2('hello world')
Beispiel #2
0
    def test_log2(self):
        """Check ``ht.log2`` / ``DNDarray.log2`` against ``torch.log2``.

        Float inputs must keep their dtype and int32/int64 inputs must be
        promoted to ``ht.float64``. The elementwise result must keep the input
        shape, and non-DNDarray inputs must raise ``TypeError``.
        """
        elements = 15
        # float64 reference values log2(1), ..., log2(elements - 1), created on
        # the torch device this test case is configured for
        tmp = torch.arange(
            1, elements, dtype=torch.float64, device=self.device.torch_device
        ).log2()
        comparison = ht.array(tmp)

        # logarithm of float32
        float32_tensor = ht.arange(1, elements, dtype=ht.float32)
        float32_log2 = ht.log2(float32_tensor)
        self.assertIsInstance(float32_log2, ht.DNDarray)
        self.assertEqual(float32_log2.dtype, ht.float32)
        self.assertEqual(float32_log2.shape, float32_tensor.shape)
        self.assertTrue(ht.allclose(float32_log2, comparison.astype(ht.float32)))

        # logarithm of float64
        float64_tensor = ht.arange(1, elements, dtype=ht.float64)
        float64_log2 = ht.log2(float64_tensor)
        self.assertIsInstance(float64_log2, ht.DNDarray)
        self.assertEqual(float64_log2.dtype, ht.float64)
        self.assertEqual(float64_log2.shape, float64_tensor.shape)
        self.assertTrue(ht.allclose(float64_log2, comparison))

        # logarithm of ints, automatic conversion to intermediate floats
        int32_tensor = ht.arange(1, elements, dtype=ht.int32)
        int32_log2 = ht.log2(int32_tensor)
        self.assertIsInstance(int32_log2, ht.DNDarray)
        self.assertEqual(int32_log2.dtype, ht.float64)
        self.assertEqual(int32_log2.shape, int32_tensor.shape)
        self.assertTrue(ht.allclose(int32_log2, comparison))

        # logarithm of longs, automatic conversion to intermediate floats;
        # uses the method form to cover DNDarray.log2 as well as ht.log2
        int64_tensor = ht.arange(1, elements, dtype=ht.int64)
        int64_log2 = int64_tensor.log2()
        self.assertIsInstance(int64_log2, ht.DNDarray)
        self.assertEqual(int64_log2.dtype, ht.float64)
        self.assertEqual(int64_log2.shape, int64_tensor.shape)
        self.assertTrue(ht.allclose(int64_log2, comparison))

        # check exceptions: plain Python objects are not accepted
        with self.assertRaises(TypeError):
            ht.log2([1, 2, 3])
        with self.assertRaises(TypeError):
            ht.log2("hello world")