def test_clear(self):
    """Workspace.clear() should release the data of a tensor created in it.

    The tensor reports size 1 before the clear and size 0 afterwards.
    """
    workspace = dragon.Workspace()
    with workspace.as_default():
        tensor = dragon.EagerTensor(1)
        self.assertEqual(tensor.size, 1)
        workspace.clear()
        self.assertEqual(tensor.size, 0)
def test_reset_tensor(self):
    """Workspace.reset_tensor() should empty exactly the given tensor.

    The tensor's size goes from 1 to 0 after the reset.
    """
    workspace = dragon.Workspace()
    with workspace.as_default():
        tensor = dragon.EagerTensor(1)
        self.assertEqual(tensor.size, 1)
        workspace.reset_tensor(tensor)
        self.assertEqual(tensor.size, 0)
def test_properties(self):
    """Check shape/size/dtype/device/name/grad properties of Tensor and EagerTensor.

    Also verifies that the read-only or invalid property assignments raise.
    Before this change they were wrapped in ``try/except: pass``, which also
    passed silently when no error was raised.
    """
    a, b = dragon.Tensor(()), dragon.EagerTensor(0)

    # Symbolic Tensor: rank, shape, size and dtype.
    self.assertEqual(dragon.Tensor(()).ndim, 0)
    self.assertEqual(dragon.Tensor(shape=(2, )).ndim, 1)
    self.assertIsNone(dragon.Tensor(None).shape)
    self.assertEqual(dragon.Tensor(shape=(2, )).shape, (2, ))
    self.assertEqual(dragon.Tensor(None).size, 0)
    self.assertEqual(dragon.Tensor(()).size, 1)
    # An unknown (None) dimension makes the size unbounded.
    self.assertEqual(dragon.Tensor(shape=(2, None)).size, math.inf)
    self.assertEqual(dragon.Tensor(shape=(2, )).size, 2)
    self.assertIsNone(dragon.Tensor(None, None).dtype)
    self.assertEqual(dragon.Tensor(None, dtype='float32').dtype, 'float32')

    # EagerTensor: rank, shape, size, dtype and device.
    self.assertEqual(dragon.EagerTensor(shape=(2, )).ndim, 1)
    self.assertEqual(dragon.EagerTensor(shape=(2, )).shape, (2, ))
    self.assertEqual(dragon.EagerTensor(shape=(2, )).size, 2)
    self.assertEqual(
        dragon.EagerTensor(shape=(2, ), dtype='float32').dtype, 'float32')
    self.assertEqual(dragon.EagerTensor().device, dragon.EagerTensor().device)

    # Distinct tensors must hash and print differently.
    self.assertNotEqual(a.__hash__(), b.__hash__())
    self.assertNotEqual(a.__repr__(), b.__repr__())
    self.assertNotEqual(b.__repr__(), dragon.EagerTensor((2, )).__repr__())

    # Value setters/getters and scalar conversions.
    self.assertEqual(int(a.constant().set_value(1)), 1)
    self.assertEqual(float(dragon.Tensor.from_value(1)), 1.)
    self.assertEqual(float(dragon.EagerTensor.from_value(1)), 1.)
    self.assertEqual(int(b.set_value(1)), 1)
    self.assertEqual(float(b), 1.)
    self.assertEqual(int(b.get_value()), 1)

    # Invalid assignments / constructions must raise, not silently succeed.
    with self.assertRaises(TypeError):
        a.shape = 1
    with self.assertRaises(RuntimeError):
        b.shape = (2, 3)
    with self.assertRaises(RuntimeError):
        b.dtype = 'float64'
    with self.assertRaises(ValueError):
        # Not bound to ``b``: a construction that fails to raise must not
        # silently replace the tensor used by the checks below.
        dragon.EagerTensor(0, 0)

    # Names are prefixed by the active name scope; an empty scope adds none.
    with dragon.name_scope('a'):
        a.name = 'a'
        self.assertEqual(a.name, 'a/a')
    with dragon.name_scope(''):
        b.name = 'b'
        self.assertEqual(b.name, 'b')

    b.requires_grad = True
    self.assertEqual(b.requires_grad, True)
def test_feed_tensor(self):
    """Feed a named symbolic tensor from an EagerTensor and then a NumPy array.

    After each feed, the tensor's scalar value must match the fed value.
    """
    workspace = dragon.Workspace()
    with workspace.as_default():
        eager_value = dragon.EagerTensor(1)
        array_value = np.array(2)
        target = dragon.Tensor((), name='test_feed_tensor/x')
        for fed, expected in ((eager_value, 1), (array_value, 2)):
            workspace.feed_tensor(target, fed)
            self.assertEqual(int(target), expected)
def test_dlpack_converter(self):
    """Round-trip a CPU EagerTensor through DLPack.

    The restored tensor must keep the source's shape, dtype and values.
    """
    expected = np.array([0., 1., 2.], 'float32')
    with dragon.device('cpu'), dragon.eager_scope():
        source = dragon.EagerTensor(expected, copy=True)
        capsule = dragon.dlpack.to_dlpack(source)
        restored = dragon.dlpack.from_dlpack(capsule)
        self.assertEqual(restored.shape, expected.shape)
        self.assertEqual(restored.dtype, str(expected.dtype))
        max_abs_err = np.abs(restored.numpy() - expected).max()
        self.assertLessEqual(max_abs_err, 1e-5)
def test_dlpack_converter_cuda(self):
    """Round-trip an EagerTensor on CUDA device 0 through DLPack.

    The restored tensor must stay on ``cuda:0`` and keep the source's shape,
    dtype and values.
    """
    expected = np.array([0., 1., 2.], 'float32')
    with dragon.device('cuda', 0), execution_context().mode('EAGER_MODE'):
        # NOTE(review): ``+ 0`` presumably runs an eager op so the data is
        # materialized on the CUDA device before export; confirm.
        source = dragon.EagerTensor(expected, copy=True) + 0
        capsule = dragon.dlpack.to_dlpack(source)
        restored = dragon.dlpack.from_dlpack(capsule)
        device = restored.device
        self.assertEqual(device.type, 'cuda')
        self.assertEqual(device.index, 0)
        self.assertEqual(restored.shape, expected.shape)
        self.assertEqual(restored.dtype, str(expected.dtype))
        max_abs_err = np.abs(restored.numpy() - expected).max()
        self.assertLessEqual(max_abs_err, 1e-5)
def test_create_function(self):
    """Build a function computing ``one + 1`` and evaluate it with ``one`` given as 2.

    Two argument combinations are attempted first and may fail with
    ValueError. A ValueError from that attempt is tolerated.
    """
    one = dragon.Tensor((), dtype='int32').set_value(1)
    two = dragon.Tensor((), dtype='int32').set_value(2)
    output = one + 1

    # Each attempt is built lazily so objects are created in the same order
    # as calls are made; a ValueError from any of them is tolerated.
    attempts = (
        lambda: dragon.create_function(
            outputs=output, optimizer=dragon.optimizers.SGD()),
        lambda: dragon.create_function(outputs=dragon.EagerTensor(1)),
    )
    for attempt in attempts:
        try:
            attempt()
        except ValueError:
            pass

    # A plain Python value as a given may be rejected; fall back to a tensor.
    try:
        func = dragon.create_function(outputs=output, givens={one: 1})
    except ValueError:
        func = dragon.create_function(outputs=output, givens={one: two})
    self.assertEqual(int(func()), 3)
def test_register_alias(self):
    """A tensor registered under an alias can be fetched back by that alias."""
    workspace = dragon.Workspace()
    alias = 'test_register_alias/y'
    with workspace.as_default():
        source = dragon.EagerTensor(1)
        workspace.register_alias(source.id, alias)
        fetched = workspace.fetch_tensor(alias)
        self.assertEqual(int(fetched), 1)