Exemplo n.º 1
0
    def test_add_one(self):
        ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)

        @computations.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        comp_val = asyncio.get_event_loop().run_until_complete(
            ex.create_value(comp))

        self.assertEqual(str(comp_val.type_signature), '(float32 -> float32)')
        self.assertTrue(callable(comp_val.internal_representation))
        # NOTE: The internal representation is a functions that takes a parameter
        # kwarg and returns a dict with a 'result' key.
        result = comp_val.internal_representation(
            parameter=np.float32(5.0))['result']
        self.assertEqual(result, 6.0)

        arg_val = asyncio.get_event_loop().run_until_complete(
            ex.create_value(10.0, tf.float32))
        result_val = asyncio.get_event_loop().run_until_complete(
            ex.create_call(comp_val, arg_val))

        self.assertIsInstance(result_val, executor.IreeValue)
        self.assertEqual(str(result_val.type_signature), 'float32')
        result = asyncio.get_event_loop().run_until_complete(
            result_val.compute())
        self.assertEqual(result, 11.0)
Exemplo n.º 2
0
    def test_constant_computation(self):
        ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)

        @computations.tf_computation
        def comp():
            return 1000.0

        comp_val = asyncio.get_event_loop().run_until_complete(
            ex.create_value(comp))

        self.assertIsInstance(comp_val, executor.IreeValue)
        self.assertEqual(str(comp_val.type_signature), '( -> float32)')
        self.assertTrue(callable(comp_val.internal_representation))
        # NOTE: The internal representation is a functions that takes a parameter
        # kwarg and returns a dict with a 'result' key.
        result = comp_val.internal_representation()['result']
        self.assertEqual(result, 1000.0)

        with self.assertRaises(TypeError):
            asyncio.get_event_loop().run_until_complete(comp_val.compute())

        result_val = asyncio.get_event_loop().run_until_complete(
            ex.create_call(comp_val))
        self.assertIsInstance(result_val, executor.IreeValue)
        self.assertEqual(str(result_val.type_signature), 'float32')
        self.assertIsInstance(result_val.internal_representation, np.float32)
        self.assertEqual(result_val.internal_representation, 1000.0)
        result = asyncio.get_event_loop().run_until_complete(
            result_val.compute())
        self.assertEqual(result, 1000.0)
Exemplo n.º 3
0
    def test_float(self):
        ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
        float_val = asyncio.run(ex.create_value(10.0, tf.float32))

        self.assertIsInstance(float_val, executor.IreeValue)
        self.assertEqual(str(float_val.type_signature), 'float32')
        self.assertIsInstance(float_val.internal_representation, np.float32)
        self.assertEqual(float_val.internal_representation, 10.0)
        result = asyncio.run(float_val.compute())
        self.assertEqual(result, 10.0)
Exemplo n.º 4
0
    def test_as_default_executor(self):
        ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
        default_executor.set_default_executor(
            executor_factory.create_executor_factory(lambda _: ex))

        @computations.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        self.assertEqual(comp(10.0), 11.0)
Exemplo n.º 5
0
    def test_as_default_context(self):
        ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
        factory = executor_stacks.ResourceManagingExecutorFactory(
            executor_stack_fn=lambda _: ex)
        context = execution_context.ExecutionContext(factory)
        set_default_context.set_default_context(context)

        @computations.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        self.assertEqual(comp(10.0), 11.0)