Пример #1
0
    def test_import_custom_eps(self):
        """Register the test execution provider library and bind it to device 1.

        Device 0 is first bound to 'CPUExecutionProvider' with empty options.
        The provider library located by ``get_test_execution_provider_path()``
        is then registered under the name 'TestExecutionProvider', and that
        provider is bound to device index 1 with a small option dict.

        As written here the test contains no assertions: it passes when none
        of the ``torch_ort`` calls raise.
        """
        provider_name = 'TestExecutionProvider'
        provider_options = {'device_id': '0', 'some_config': 'val'}

        torch_ort.set_device(0, 'CPUExecutionProvider', {})

        library_path = self.get_test_execution_provider_path()
        # Third argument is presumably the library's entry-point symbol --
        # TODO confirm against the torch_ort binding.
        torch_ort._register_provider_lib(
            provider_name, library_path, 'ProviderEntryPoint')

        torch_ort.set_device(1, provider_name, provider_options)
        # The returned device object is not used further in this test.
        device_handle = torch_ort.device(1)
Пример #2
0
    def test_import_custom_eps(self):
        """Check creation, reuse and clearing of custom EP instances.

        The test provider library is registered as "TestExecutionProvider",
        then bound to several device indices while stdout is captured with
        ``OutputGrabber``. The captured text is searched for the provider's
        "created" message:

        * binding index 1 (device_id "0") must print the message for id 0;
        * binding index 2 (device_id "1") must print the message for id 1;
        * binding index 3 with the same options as the first binding must
          NOT print the id-0 message (the instance is expected to be reused);
        * after ``clear_training_ep_instances()``, binding index 3 again
          must print the id-0 message once more.
        """
        provider_name = "TestExecutionProvider"
        created_with_id_0 = "My EP provider created, with device id: 0, some_option: val"
        created_with_id_1 = "My EP provider created, with device id: 1, some_option: val"

        def bind_and_capture(device_index, device_id):
            """Bind the test provider to ``device_index``; return captured output."""
            provider_options = {"device_id": device_id, "some_config": "val"}
            with OutputGrabber() as grabbed:
                torch_ort.set_device(device_index, provider_name, provider_options)
                # NOTE(review): index 1 is queried after every binding, even
                # when index 2 or 3 was just set -- kept as in the original;
                # confirm whether that is intended.
                device_handle = torch_ort.device(1)
            return grabbed.capturedtext

        torch_ort.set_device(0, "CPUExecutionProvider", {})

        library_path = self.get_test_execution_provider_path()
        torch_ort._register_provider_lib(provider_name, library_path, {})

        # Fresh option sets: each binding should announce a newly created EP.
        first_output = bind_and_capture(1, "0")
        assert created_with_id_0 in first_output

        second_output = bind_and_capture(2, "1")
        assert created_with_id_1 in second_output

        # Same options as the first binding: no new "created" message expected.
        reuse_output = bind_and_capture(3, "0")
        assert created_with_id_0 not in reuse_output

        # Once the training EP instance pool is cleared, creation happens again.
        torch_ort.clear_training_ep_instances()
        recreated_output = bind_and_capture(3, "0")
        assert created_with_id_0 in recreated_output
Пример #3
0
    def test_import_custom_eps(self):
        """Exercise the custom execution provider instance pool.

        Steps, in order:

        1. Bind device 0 to 'CPUExecutionProvider' and register the test
           provider library under 'TestExecutionProvider'.
        2. Bind indices 1 and 2 with device_id '0' and '1'; each binding is
           expected to print the provider's "created" message to stdout.
        3. Bind index 3 with device_id '0' again; the "created" message for
           id 0 is expected to be absent (existing instance reused).
        4. Call ``clear_training_ep_instances()`` and bind index 3 once more;
           the "created" message for id 0 is expected to appear again.

        Stdout is captured with ``OutputGrabber`` around each binding.
        """
        ep = 'TestExecutionProvider'

        def stdout_while_binding(index, options):
            """Return the text captured while binding ``ep`` to ``index``."""
            with OutputGrabber() as capture:
                torch_ort.set_device(index, ep, options)
                # NOTE(review): the original queries device index 1 here on
                # every call, regardless of ``index`` -- preserved unchanged.
                queried_device = torch_ort.device(1)
            return capture.capturedtext

        def options_for(device_id):
            """Build the provider option dict for the given device id string."""
            return {
                'device_id': device_id,
                'some_config': 'val',
            }

        torch_ort.set_device(0, 'CPUExecutionProvider', {})
        torch_ort._register_provider_lib(
            ep,
            self.get_test_execution_provider_path(),
            {},
        )

        # Step 2: two distinct option sets, two newly created providers.
        printed = stdout_while_binding(1, options_for('0'))
        assert 'My EP provider created, with device id: 0, some_option: val' in printed

        printed = stdout_while_binding(2, options_for('1'))
        assert 'My EP provider created, with device id: 1, some_option: val' in printed

        # Step 3: options match the first binding, so nothing new is created.
        printed = stdout_while_binding(3, options_for('0'))
        assert 'My EP provider created, with device id: 0, some_option: val' not in printed

        # Step 4: clearing the pool forces the provider to be created again.
        torch_ort.clear_training_ep_instances()
        printed = stdout_while_binding(3, options_for('0'))
        assert 'My EP provider created, with device id: 0, some_option: val' in printed