Ejemplo n.º 1
0
    def test_import_custom_eps(self):
        """Register the test execution-provider library and bind it to a device.

        ORT device 0 is first bound to the built-in 'CPUExecutionProvider'.
        The custom provider library is then registered under the name
        'TestExecutionProvider' and attached to ORT device 1 with a small
        options dict. There is no explicit assertion: the test passes as long
        as none of these calls raises.
        """
        cpu_index = 0
        custom_index = 1
        torch_ort.set_device(cpu_index, 'CPUExecutionProvider', {})

        library_path = self.get_test_execution_provider_path()
        # NOTE(review): here the third argument is a string (presumably the
        # exported entry-point symbol); other revisions of this test pass an
        # options dict instead — confirm against the torch_ort under test.
        torch_ort._register_provider_lib('TestExecutionProvider',
                                         library_path,
                                         'ProviderEntryPoint')

        custom_options = {'device_id': '0', 'some_config': 'val'}
        torch_ort.set_device(custom_index, 'TestExecutionProvider',
                             custom_options)
        resolved_device = torch_ort.device(custom_index)

    def test_register_custom_eps(self):
        """Register the test EP via the C extension and open a session on it.

        After registering the provider library under the name
        'TestExecutionProvider', this checks that the name is listed by
        ``C.get_available_providers()``. It then builds a
        ``C.InferenceSession`` over
        ``testdata/custom_execution_provider_library/test_model.onnx``
        (relative to this file) and initializes it with that provider.

        Raises:
            FileNotFoundError: if the test model file does not exist.
        """
        library_path = self.get_test_execution_provider_path()
        C._register_provider_lib('TestExecutionProvider', library_path,
                                 {'some_config': 'val'})

        registered = C.get_available_providers()
        assert 'TestExecutionProvider' in registered

        here = os.path.dirname(__file__)
        model_path = os.path.join(here, "testdata",
                                  "custom_execution_provider_library",
                                  "test_model.onnx")
        if not os.path.exists(model_path):
            missing_msg = "Unable to find '{0}'".format(model_path)
            raise FileNotFoundError(missing_msg)

        options = C.get_default_session_options()
        # NOTE(review): the two positional True flags (and the empty set
        # passed to initialize_session below) are forwarded unchanged; their
        # meaning is defined by the C extension — confirm there.
        session = C.InferenceSession(options, model_path, True, True)
        provider_names = ['TestExecutionProvider']
        per_provider_options = [{'device_id': '0'}]
        session.initialize_session(provider_names, per_provider_options,
                                   set())
        print(
            "Created session with customize execution provider successfully!")
Ejemplo n.º 3
0
    def test_import_custom_eps(self):
        """Check creation and pooling of custom execution-provider instances.

        Registers the test EP library, then binds ORT devices to
        'TestExecutionProvider' with various ``device_id`` options. Stdout is
        captured during each binding, and the test asserts on the
        "My EP provider created" line printed by the test provider:

        * a new ``device_id`` produces the creation message;
        * repeating an already-seen configuration on another device does not
          print it again (the instance is reused);
        * after ``torch_ort.clear_training_ep_instances()`` the same
          configuration prints the creation message again.
        """
        torch_ort.set_device(0, "CPUExecutionProvider", {})

        torch_ort._register_provider_lib(
            "TestExecutionProvider", self.get_test_execution_provider_path(),
            {})

        created_template = ("My EP provider created, with device id: {0}, "
                            "some_option: val")

        def _bind_test_device(index, device_id):
            # Bind ORT device `index` to the test EP and return everything
            # written to stdout while configuring and resolving it.
            with OutputGrabber() as out:
                torch_ort.set_device(index, "TestExecutionProvider", {
                    "device_id": device_id,
                    "some_config": "val"
                })
                # Resolve the device that was just configured (the original
                # always resolved device 1). This keeps the check valid
                # whether the EP is created in set_device or on first lookup.
                torch_ort.device(index)
            return out.capturedtext

        assert created_template.format("0") in _bind_test_device(1, "0")
        assert created_template.format("1") in _bind_test_device(2, "1")
        # test the reusing EP instance: same options as device 1
        assert created_template.format("0") not in _bind_test_device(3, "0")
        # test clear training ep instance pool
        torch_ort.clear_training_ep_instances()
        assert created_template.format("0") in _bind_test_device(3, "0")
Ejemplo n.º 4
0
    def test_import_custom_eps(self):
        """Check creation and pooling of custom execution-provider instances.

        Registers the test EP library, then binds ORT devices to
        'TestExecutionProvider' with various ``device_id`` options. Stdout is
        captured during each binding, and the test asserts on the
        "My EP provider created" line printed by the test provider:

        * a new ``device_id`` produces the creation message;
        * repeating an already-seen configuration on another device does not
          print it again (the instance is reused);
        * after ``torch_ort.clear_training_ep_instances()`` the same
          configuration prints the creation message again.
        """
        torch_ort.set_device(0, 'CPUExecutionProvider', {})

        torch_ort._register_provider_lib(
            'TestExecutionProvider', self.get_test_execution_provider_path(),
            {})

        created_template = ('My EP provider created, with device id: {0}, '
                            'some_option: val')

        def _bind_test_device(index, device_id):
            # Bind ORT device `index` to the test EP and return everything
            # written to stdout while configuring and resolving it.
            with OutputGrabber() as out:
                torch_ort.set_device(index, 'TestExecutionProvider', {
                    'device_id': device_id,
                    'some_config': 'val'
                })
                # Resolve the device that was just configured (the original
                # always resolved device 1). This keeps the check valid
                # whether the EP is created in set_device or on first lookup.
                torch_ort.device(index)
            return out.capturedtext

        assert created_template.format('0') in _bind_test_device(1, '0')
        assert created_template.format('1') in _bind_test_device(2, '1')
        # test the reusing EP instance: same options as device 1
        assert created_template.format('0') not in _bind_test_device(3, '0')
        # test clear training ep instance pool
        torch_ort.clear_training_ep_instances()
        assert created_template.format('0') in _bind_test_device(3, '0')