def test_static_invoke(self):
    """Invoke ``simple_mul`` on two float32 vectors and check the result.

    NOTE(review): despite the "static" in the name, this test uses the
    default (dynamic) context and registers the module via ``add_module``.
    """
    context = pyiree.SystemContext()
    self.assertTrue(context.is_dynamic)
    context.add_module(create_simple_mul_module())
    self.assertEqual(context.modules.arithmetic.name, "arithmetic")
    simple_mul = context.modules.arithmetic["simple_mul"]

    # The expected values below are the element-wise products of lhs and rhs.
    lhs = np.array([1., 2., 3., 4.], dtype=np.float32)
    rhs = np.array([4., 5., 6., 7.], dtype=np.float32)
    product = simple_mul(lhs, rhs)
    np.testing.assert_allclose(product, [4., 10., 18., 28.])
def test_custom_dynamic(self):
    """Add the arithmetic module to a dynamic context and check the
    signature shown in the repr of its ``simple_mul`` function."""
    ctx = pyiree.SystemContext()
    self.assertTrue(ctx.is_dynamic)
    ctx.add_module(create_simple_mul_module())
    self.assertEqual(ctx.modules.arithmetic.name, "arithmetic")
    f = ctx.modules.arithmetic["simple_mul"]
    # A literal substring check: equivalent to assertRegex with a fully
    # re.escape'd pattern, but clearer. The failure message shows the repr.
    self.assertIn(
        "(Buffer<float32[4]>, Buffer<float32[4]>) -> (Buffer<float32[4]>)",
        repr(f))
def test_duplicate_module(self):
    """Adding a second module with an already-registered name raises."""
    system_ctx = pyiree.SystemContext()
    self.assertTrue(system_ctx.is_dynamic)
    system_ctx.add_module(create_simple_mul_module())
    # Registering another "arithmetic" module must fail with a ValueError
    # whose message mentions the clashing module name.
    with self.assertRaisesRegex(ValueError, "arithmetic"):
        system_ctx.add_module(create_simple_mul_module())
def test_empty_static(self):
    """A context built with an explicit empty module tuple is not dynamic,
    and ``hal`` is still present among its modules."""
    static_ctx = pyiree.SystemContext(modules=())
    self.assertFalse(static_ctx.is_dynamic)
    self.assertIn("hal", static_ctx.modules)
    hal_module = static_ctx.modules.hal
    self.assertEqual(hal_module.name, "hal")
def test_empty_dynamic(self):
    """A default-constructed context is dynamic and already lists ``hal``
    among its modules."""
    dynamic_ctx = pyiree.SystemContext()
    self.assertTrue(dynamic_ctx.is_dynamic)
    self.assertIn("hal", dynamic_ctx.modules)
    hal_module = dynamic_ctx.modules.hal
    self.assertEqual(hal_module.name, "hal")