def connect(self):
    """
    Create an embedder for the currently selected model.

    When the selected embedder is not a local one and the server cannot
    be reached, the ``switched_local_embedder`` warning is shown, the
    selection is moved to the "squeezenet" entry of ``self.embedders``
    and an embedder for that model is returned instead.

    Returns
    -------
    ImageEmbedder
        Embedder for the selected model, or for "squeezenet" after the
        fallback.
    """
    self.Warning.switched_local_embedder.clear()

    def make_embedder():
        # Always built from the current selection, so the fallback below
        # only needs to change ``cb_embedder_current_id``.
        return ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate'
        )

    # try to connect to current embedder
    embedder = make_embedder()

    if not embedder.is_local_embedder() and \
            not embedder.is_connected_to_server(use_hyper=False):
        # there is a problem with connecting to the server
        # switching to local embedder
        self.Warning.switched_local_embedder()
        # Release the unusable embedder before building its replacement.
        del embedder
        self.cb_embedder_current_id = self.embedders.index("squeezenet")
        embedder = make_embedder()

    return embedder
class OWImageEmbedding(OWWidget):
    """
    Widget that embeds the images referenced by an input table.

    A table arrives on the ``Images`` input; the user picks the image
    attribute and the embedder. The selected ``ImageEmbedder`` is run on a
    single worker thread and the results are sent to the ``Embeddings``
    and ``Skipped Images`` outputs.
    """
    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    class Inputs:
        images = Input('Images', Table)

    class Outputs:
        embeddings = Output('Embeddings', Table, default=True)
        skipped_images = Output('Skipped Images', Table)

    # Indices of the selected entries in the two combo boxes.
    cb_image_attr_current_id = Setting(default=0)
    cb_embedder_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."
    # URL schemes that mark a path (or the attribute's origin) as remote.
    _REMOTE_SCHEMES = ("http", "https", "ftp", "data")

    def __init__(self):
        """Build the GUI and schedule the first embedder creation."""
        super().__init__()
        # Embedder keys ordered by their 'order' entry in EMBEDDERS_INFO.
        self.embedders = sorted(EMBEDDERS_INFO,
                                key=lambda k: EMBEDDERS_INFO[k]['order'])
        self._image_attributes = None
        self._input_data = None
        self._log = logging.getLogger(__name__)
        self._task = None
        self._setup_layout()
        self._image_embedder = None
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Stay blocked until _init_server_connection has run.
        self.setBlocking(True)
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Create the info labels, combo boxes and Apply/Cancel controls."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.cb_embedder = comboBox(
            widget=widget_box,
            master=self,
            value='cb_embedder_current_id',
            label='Embedder:',
            orientation=Qt.Horizontal,
            callback=self._cb_embedder_changed
        )
        names = [EMBEDDERS_INFO[e]['name'] +
                 (" (local)" if EMBEDDERS_INFO[e].get("is_local") else "")
                 for e in self.embedders]
        self.cb_embedder.setModel(VariableListModel(names))
        # A stored index may be out of range for the current embedder list.
        if not self.cb_embedder_current_id < len(self.embedders):
            self.cb_embedder_current_id = 0
        self.cb_embedder.setCurrentIndex(self.cb_embedder_current_id)

        current_embedder = self.embedders[self.cb_embedder_current_id]
        self.embedder_info = widgetLabel(
            widget_box,
            EMBEDDERS_INFO[current_embedder]['description']
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            commit=self.commit
        )

        self.cancel_button = QPushButton(
            'Cancel',
            icon=self.style().standardIcon(QStyle.SP_DialogCancelButton),
        )
        self.cancel_button.clicked.connect(self.cancel)
        hbox = hBox(self.controlArea)
        hbox.layout().addWidget(self.cancel_button)
        self.cancel_button.setDisabled(True)

    def _init_server_connection(self):
        """Unblock the widget, create the embedder and show its status."""
        self.setBlocking(False)
        self._image_embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate'
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    @Inputs.images
    def set_data(self, data):
        """
        Handle a new table on the ``Images`` input.

        Falsy input clears both outputs. A table without image attributes
        only updates the info label. Otherwise the image-attribute combo
        box is refreshed and the embedding is (re)started.
        """
        if not data:
            self._input_data = None
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            return

        # A stored index may be out of range for the new attribute list.
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()

    def _cb_image_attr_changed(self):
        """Re-run the embedding for the newly selected image attribute."""
        self.commit()

    def _cb_embedder_changed(self):
        """Replace the embedder with the selected one and refresh the UI."""
        current_embedder = self.embedders[self.cb_embedder_current_id]
        self._image_embedder = ImageEmbedder(
            model=current_embedder,
            layer='penultimate'
        )
        self.embedder_info.setText(
            EMBEDDERS_INFO[current_embedder]['description'])
        if self._input_data:
            self.input_data_info.setText(
                "Data with {:d} instances.".format(len(self._input_data)))
            self.commit()
        else:
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
        self._set_server_info(self._image_embedder.is_connected_to_server())

    def commit(self):
        """
        Start embedding the images of the current input in the background.

        A running task is cancelled first. Without an embedder only the
        connection info is updated; without usable input both outputs are
        cleared. Otherwise the image paths are resolved against the
        attribute's "origin", the task is submitted to the executor and
        ``__set_results`` is invoked once the future is done.
        """
        if self._task is not None:
            self.cancel()

        if self._image_embedder is None:
            self._set_server_info(connected=False)
            return

        if not self._image_attributes or self._input_data is None:
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        self._set_server_info(connected=True)
        self.cancel_button.setDisabled(False)
        self.cb_image_attr.setDisabled(True)
        self.cb_embedder.setDisabled(True)

        file_paths_attr = self._image_attributes[self.cb_image_attr_current_id]
        file_paths = self._input_data[:, file_paths_attr].metas.flatten()
        origin = file_paths_attr.attributes.get("origin", "")
        # Appending a slash does not change the scheme, so the check is
        # done once here and reused for every path below.
        origin_is_remote = urlparse(origin).scheme in self._REMOTE_SCHEMES
        if origin_is_remote and not origin.endswith("/"):
            origin += "/"

        assert file_paths_attr.is_string
        assert file_paths.dtype == np.dtype('O')

        file_paths_mask = file_paths == file_paths_attr.Unknown
        file_paths_valid = file_paths[~file_paths_mask]
        for i, a in enumerate(file_paths_valid):
            # Paths that already carry a remote scheme are used as they are;
            # everything else is resolved relative to the origin.
            if urlparse(a).scheme not in self._REMOTE_SCHEMES:
                if origin_is_remote:
                    file_paths_valid[i] = urljoin(origin, a)
                else:
                    file_paths_valid[i] = os.path.join(origin, a)

        ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))
        set_progress = qconcurrent.methodinvoke(
            self, "__progress_set", (float,))

        def advance(success=True):
            if success:
                set_progress(next(ticks))

        def cancel():
            # ``task`` is bound further down, before this can be called.
            task.future.cancel()
            task.cancelled = True
            task.embedder.set_canceled(True)

        embedder = self._image_embedder

        def run_embedding(paths):
            return embedder(
                file_paths=paths, image_processed_callback=advance)

        self.auto_commit_widget.setDisabled(True)
        self.progressBarInit(processEvents=None)
        self.progressBarSet(0.0, processEvents=None)
        self.setBlocking(True)

        f = self._executor.submit(run_embedding, file_paths_valid)
        f.add_done_callback(
            qconcurrent.methodinvoke(self, "__set_results", (object,)))

        task = self._task = namespace(
            file_paths_mask=file_paths_mask,
            file_paths_valid=file_paths_valid,
            file_paths=file_paths,
            embedder=embedder,
            cancelled=False,
            cancel=cancel,
            future=f,
        )
        self._log.debug("Starting embedding task for %i images",
                        file_paths.size)
        return

    @Slot(float)
    def __progress_set(self, value):
        """Set the progress bar to ``value`` while a task is active."""
        assert self.thread() is QThread.currentThread()
        if self._task is not None:
            self.progressBarSet(value)

    @Slot(object)
    def __set_results(self, f):
        """
        Collect the result of the finished future ``f`` and send outputs.

        Futures that do not belong to the current task are ignored. On any
        error both outputs are cleared; a ``ConnectionError`` additionally
        marks the server as not connected, other exceptions are shown as a
        widget error.
        """
        assert self.thread() is QThread.currentThread()
        if self._task is None or self._task.future is not f:
            self._log.info("Reaping stale task")
            return

        assert f.done()

        task, self._task = self._task, None
        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.setDisabled(True)
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)

        try:
            embeddings = f.result()
        except ConnectionError:
            self._log.exception("Error", exc_info=True)
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self._set_server_info(connected=False)
            return
        except Exception as err:
            self._log.exception("Error", exc_info=True)
            self.error(
                "\n".join(traceback.format_exception_only(type(err), err)))
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        assert self._input_data is not None
        assert len(self._input_data) == len(task.file_paths_mask)

        # Missing paths/urls were filtered out. Restore the full embeddings
        # array from information stored in task.file_paths_mask ...
        embeddings_all = [None] * len(task.file_paths_mask)
        for i, embedding in zip(np.flatnonzero(~task.file_paths_mask),
                                embeddings):
            embeddings_all[i] = embedding
        embeddings_all = np.array(embeddings_all)
        self._send_output_signals(embeddings_all)

    def _send_output_signals(self, embeddings):
        """Send both outputs and report skipped images in the info label."""
        embedded_images, skipped_images, num_skipped =\
            ImageEmbedder.prepare_output_data(self._input_data, embeddings)
        self.Outputs.embeddings.send(embedded_images)
        self.Outputs.skipped_images.send(skipped_images)
        # Compare by value; identity checks against int literals are
        # implementation-dependent.
        if num_skipped != 0:
            self.input_data_info.setText(
                "Data with {:d} instances, {:d} images skipped.".format(
                    len(self._input_data), num_skipped))

    def _set_server_info(self, connected):
        """Clear widget messages and update the connection label."""
        self.clear_messages()
        if self._image_embedder is None:
            return

        if connected:
            self.connection_info.setText("Connected to server.")
        elif self._image_embedder.is_local_embedder():
            self.connection_info.setText("Using local embedder.")
        else:
            self.connection_info.setText("Not connected to server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Cancel any running task and close the embedder."""
        self.cancel()
        super().onDeleteWidget()
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)

    def cancel(self):
        """
        Cancel the running task, if any, and restore the idle UI state.

        Waits for the task's future to finish, clears the embedder's
        cancel flag and reconnects to the server.
        """
        if self._task is not None:
            task, self._task = self._task, None
            task.cancel()
            # wait until done
            try:
                task.future.exception()
            except qconcurrent.CancelledError:
                pass

            self.auto_commit_widget.setDisabled(False)
            self.cancel_button.setDisabled(True)
            self.progressBarFinished(processEvents=None)
            self.setBlocking(False)
            self.cb_image_attr.setDisabled(False)
            self.cb_embedder.setDisabled(False)
            self._image_embedder.set_canceled(False)
            # reset the connection.
            connected = self._image_embedder.reconnect_to_server()
            self._set_server_info(connected=connected)
# Example #3
# 0
class ImageEmbedderTest(unittest.TestCase):
    """Tests of ``ImageEmbedder`` run against a dummy HTTP/2 connection."""

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def setUp(self):
        """Silence logging and create an embedder with an empty cache."""
        logging.disable(logging.CRITICAL)
        self.embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.embedder.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_JPG]

    def tearDown(self):
        """Clear the embedder cache and re-enable logging."""
        self.embedder.clear_cache()
        logging.disable(logging.NOTSET)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connected_to_server(self, ConnectionMock):
        """The connection status follows the underlying connection."""
        ConnectionMock._discover_server.assert_not_called()
        self.assertEqual(self.embedder.is_connected_to_server(), True)
        # server closes the connection
        self.embedder._server_connection.close()
        self.assertEqual(self.embedder.is_connected_to_server(), False)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connection_errors(self, ConnectionMock):
        """Failures to (re)connect surface as ``ConnectionError``."""
        res = self.embedder(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.embedder.clear_cache()

        self.embedder._server_connection.close()
        ConnectionMock.side_effect = ConnectionRefusedError
        with self.assertRaises(ConnectionError):
            self.embedder(self.single_example)

        ConnectionMock.side_effect = BrokenPipeError
        with self.assertRaises(ConnectionError):
            self.embedder(self.single_example)

    @patch.object(DummyHttp2Connection, 'get_response')
    def test_on_stream_reset_by_server(self, ConnectionMock):
        """A reset stream yields ``None`` and nothing is cached."""
        ConnectionMock.side_effect = StreamResetError
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_disconnect_reconnect(self):
        """Explicit disconnect and reconnect toggle the status."""
        self.assertEqual(self.embedder.is_connected_to_server(), True)
        self.embedder.disconnect_from_server()
        self.assertEqual(self.embedder.is_connected_to_server(), False)
        self.embedder.reconnect_to_server()
        self.assertEqual(self.embedder.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_auto_reconnect(self):
        """Calling a disconnected embedder reconnects it."""
        self.assertEqual(self.embedder.is_connected_to_server(), True)
        self.embedder.disconnect_from_server()
        self.assertEqual(self.embedder.is_connected_to_server(), False)
        self.embedder(self.single_example)
        self.assertEqual(self.embedder.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_with_non_existing_image(self, ConnectionMock):
        """A missing file yields ``None`` without contacting the server."""
        self.single_example = ['/non_existing_image']

        self.assertEqual(self.embedder(self.single_example), [None])
        ConnectionMock.request.assert_not_called()
        ConnectionMock.get_response.assert_not_called()
        self.assertEqual(self.embedder._cache_dict, {})

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_on_successful_response(self):
        """A successful response is returned and cached."""
        res = self.embedder(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(b''))
    def test_on_non_json_response(self):
        """An empty (non-JSON) response yields ``None`` and no caching."""
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(json.dumps({'wrong_key': None}).encode()))
    def test_on_json_wrong_key_response(self):
        """JSON without the expected key yields ``None`` and no caching."""
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_persistent_caching(self):
        """A persisted cache is visible to new embedders until cleared."""
        self.assertEqual(len(self.embedder._cache_dict), 0)
        self.embedder(self.single_example)
        self.assertEqual(len(self.embedder._cache_dict), 1)

        self.embedder.persist_cache()
        self.embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(self.embedder._cache_dict), 1)

        self.embedder.clear_cache()
        self.embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_different_models_caches(self):
        """Each model keeps its own persisted cache."""
        embedder = ImageEmbedder(
            model='painters',
            layer='penultimate',
            server_url='example.com',
        )
        embedder.clear_cache()
        self.assertEqual(len(embedder._cache_dict), 0)
        embedder(self.single_example)
        self.assertEqual(len(embedder._cache_dict), 1)
        embedder.persist_cache()

        self.embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(self.embedder._cache_dict), 0)
        self.embedder.persist_cache()

        embedder = ImageEmbedder(
            model='painters',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(embedder._cache_dict), 1)
        embedder.clear_cache()

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_statement(self):
        """Leaving the ``with`` block disconnects and keeps the cache."""
        with self.embedder as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder.is_connected_to_server(), False)
        self.embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_max_concurrent_streams_setting(self):
        """The embedder reports 128 concurrent streams with the dummy."""
        self.assertEqual(self.embedder._max_concurrent_streams, 128)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_too_many_examples_for_one_batch(self):
        """200 images are all embedded in a single call."""
        too_many_examples = [_EXAMPLE_IMAGE_JPG for _ in range(200)]
        true_res = [np.array([0, 1], dtype=np.float16) for _ in range(200)]
        true_res = np.array(true_res)

        res = self.embedder(too_many_examples)
        assert_array_equal(res, true_res)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_successful_result_shape(self):
        """Five images give a result of shape (5, 2)."""
        more_examples = [_EXAMPLE_IMAGE_JPG for _ in range(5)]
        res = self.embedder(more_examples)
        self.assertEqual(res.shape, (5, 2))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_model(self):
        """An unknown model name raises ``ValueError``."""
        with self.assertRaises(ValueError):
            self.embedder = ImageEmbedder(
                model='invalid_model',
                layer='penultimate',
                server_url='example.com',
            )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_layer(self):
        """An unknown layer name raises ``ValueError``."""
        with self.assertRaises(ValueError):
            self.embedder = ImageEmbedder(
                model='inception-v3',
                layer='first',
                server_url='example.com',
            )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_grayscale_image(self):
        """A grayscale image is embedded and cached."""
        res = self.embedder([_EXAMPLE_IMAGE_GRAYSCALE])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_tiff_image(self):
        """A TIFF image is embedded and cached."""
        res = self.embedder([_EXAMPLE_IMAGE_TIFF])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_server_url_env_var(self):
        """``ORANGE_EMBEDDING_API_URL`` overrides the server URL."""
        url_value = 'url:1234'
        self.assertNotEqual(self.embedder._server_url, url_value)

        # patch.dict restores the environment even when an assertion fails,
        # so the override cannot leak into other tests.
        with patch.dict(environ, {'ORANGE_EMBEDDING_API_URL': url_value}):
            self.embedder = ImageEmbedder(
                model='inception-v3',
                layer='penultimate',
                server_url='example.com',
            )
            self.assertEqual(self.embedder._server_url, url_value)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_embedding_cancelled(self):
        """Calling an embedder flagged as cancelled raises."""
        self.assertFalse(self.embedder.cancelled)
        self.embedder.cancelled = True
        with self.assertRaises(Exception):
            self.embedder(self.single_example)

    def test_version(self):
        """
        Check whether a new version of the hyper library was published.

        When this test starts to fail, remove the temporary fix in
        http2_client marked with TODO.
        """
        import hyper
        self.assertEqual(hyper.__version__, "0.7.0")
# Example #4
# 0
class OWImageEmbedding(OWWidget):
    """
    Widget that embeds the images referenced by an input table.

    A table arrives on the ``Images`` input; the user picks the image
    attribute and the embedder. The selected ``ImageEmbedder`` is run
    through a single-thread executor and the results are sent to the
    ``Embeddings`` and ``Skipped Images`` outputs.
    """
    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    class Inputs:
        images = Input('Images', Table)

    class Outputs:
        embeddings = Output('Embeddings', Table, default=True)
        skipped_images = Output('Skipped Images', Table)

    # Indices of the selected entries in the two combo boxes.
    cb_image_attr_current_id = Setting(default=0)
    cb_embedder_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."
    # URL schemes that mark a path (or the attribute's origin) as remote.
    _REMOTE_SCHEMES = ("http", "https", "ftp", "data")

    def __init__(self):
        """Build the GUI and schedule the first embedder creation."""
        super().__init__()
        # Embedder keys ordered by their 'order' entry in EMBEDDERS_INFO.
        self.embedders = sorted(EMBEDDERS_INFO,
                                key=lambda k: EMBEDDERS_INFO[k]['order'])
        self._image_attributes = None
        self._input_data = None
        self._log = logging.getLogger(__name__)
        self._task = None
        self._setup_layout()
        self._image_embedder = None
        self._executor = qconcurrent.ThreadExecutor(
            self, threadPool=QThreadPool(maxThreadCount=1)
        )
        # Stay blocked until _init_server_connection has run.
        self.setBlocking(True)
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Create the info labels, combo boxes and Apply/Cancel controls."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.cb_embedder = comboBox(
            widget=widget_box,
            master=self,
            value='cb_embedder_current_id',
            label='Embedder:',
            orientation=Qt.Horizontal,
            callback=self._cb_embedder_changed
        )
        self.cb_embedder.setModel(VariableListModel(
            [EMBEDDERS_INFO[e]['name'] for e in self.embedders]))
        # A stored index may be out of range for the current embedder list.
        if not self.cb_embedder_current_id < len(self.embedders):
            self.cb_embedder_current_id = 0
        self.cb_embedder.setCurrentIndex(self.cb_embedder_current_id)

        current_embedder = self.embedders[self.cb_embedder_current_id]
        self.embedder_info = widgetLabel(
            widget_box,
            EMBEDDERS_INFO[current_embedder]['description']
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            commit=self.commit
        )

        self.cancel_button = QPushButton(
            'Cancel',
            icon=self.style().standardIcon(QStyle.SP_DialogCancelButton),
        )
        self.cancel_button.clicked.connect(self.cancel)
        hbox = hBox(self.controlArea)
        hbox.layout().addWidget(self.cancel_button)
        self.cancel_button.setDisabled(True)

    def _init_server_connection(self):
        """Unblock the widget, create the embedder and show its status."""
        self.setBlocking(False)
        self._image_embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate'
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    @Inputs.images
    def set_data(self, data):
        """
        Handle a new table on the ``Images`` input.

        Falsy input clears both outputs. A table without image attributes
        only updates the info label. Otherwise the image-attribute combo
        box is refreshed and the embedding is (re)started.
        """
        if not data:
            self._input_data = None
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            return

        # A stored index may be out of range for the new attribute list.
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()

    def _cb_image_attr_changed(self):
        """Re-run the embedding for the newly selected image attribute."""
        self.commit()

    def _cb_embedder_changed(self):
        """Replace the embedder with the selected one and refresh the UI."""
        current_embedder = self.embedders[self.cb_embedder_current_id]
        self._image_embedder = ImageEmbedder(
            model=current_embedder,
            layer='penultimate'
        )
        self.embedder_info.setText(
            EMBEDDERS_INFO[current_embedder]['description'])
        if self._input_data:
            self.input_data_info.setText(
                "Data with {:d} instances.".format(len(self._input_data)))
            self.commit()
        else:
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)

    def commit(self):
        """
        Start embedding the images of the current input in the background.

        A running task is cancelled first. Without an embedder only the
        connection info is updated; without usable input both outputs are
        cleared. Otherwise the image paths are resolved against the
        attribute's "origin", the task is submitted to the executor and
        ``__set_results`` is invoked once the future is done.
        """
        if self._task is not None:
            self.cancel()

        if self._image_embedder is None:
            self._set_server_info(connected=False)
            return

        if not self._image_attributes or self._input_data is None:
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        self._set_server_info(connected=True)
        self.cancel_button.setDisabled(False)
        self.cb_image_attr.setDisabled(True)
        self.cb_embedder.setDisabled(True)

        file_paths_attr = self._image_attributes[self.cb_image_attr_current_id]
        file_paths = self._input_data[:, file_paths_attr].metas.flatten()
        origin = file_paths_attr.attributes.get("origin", "")
        # Appending a slash does not change the scheme, so the check is
        # done once here and reused for every path below.
        origin_is_remote = urlparse(origin).scheme in self._REMOTE_SCHEMES
        if origin_is_remote and not origin.endswith("/"):
            origin += "/"

        assert file_paths_attr.is_string
        assert file_paths.dtype == np.dtype('O')

        file_paths_mask = file_paths == file_paths_attr.Unknown
        file_paths_valid = file_paths[~file_paths_mask]
        for i, a in enumerate(file_paths_valid):
            # Paths that already carry a remote scheme are used as they are;
            # everything else is resolved relative to the origin.
            if urlparse(a).scheme not in self._REMOTE_SCHEMES:
                if origin_is_remote:
                    file_paths_valid[i] = urljoin(origin, a)
                else:
                    file_paths_valid[i] = os.path.join(origin, a)

        ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))
        set_progress = qconcurrent.methodinvoke(
            self, "__progress_set", (float,))

        def advance(success=True):
            if success:
                set_progress(next(ticks))

        def cancel():
            # ``task`` is bound further down, before this can be called.
            task.future.cancel()
            task.cancelled = True
            task.embedder.cancelled = True

        embedder = self._image_embedder

        def run_embedding(paths):
            return embedder(
                file_paths=paths, image_processed_callback=advance)

        self.auto_commit_widget.setDisabled(True)
        self.progressBarInit(processEvents=None)
        self.progressBarSet(0.0, processEvents=None)
        self.setBlocking(True)

        f = self._executor.submit(run_embedding, file_paths_valid)
        f.add_done_callback(
            qconcurrent.methodinvoke(self, "__set_results", (object,)))

        task = self._task = namespace(
            file_paths_mask=file_paths_mask,
            file_paths_valid=file_paths_valid,
            file_paths=file_paths,
            embedder=embedder,
            cancelled=False,
            cancel=cancel,
            future=f,
        )
        self._log.debug("Starting embedding task for %i images",
                        file_paths.size)
        return

    @Slot(float)
    def __progress_set(self, value):
        """Set the progress bar to ``value`` while a task is active."""
        assert self.thread() is QThread.currentThread()
        if self._task is not None:
            self.progressBarSet(value)

    @Slot(object)
    def __set_results(self, f):
        """
        Collect the result of the finished future ``f`` and send outputs.

        Futures that do not belong to the current task are ignored. On any
        error both outputs are cleared; a ``ConnectionError`` additionally
        marks the server as not connected, other exceptions are shown as a
        widget error.
        """
        assert self.thread() is QThread.currentThread()
        if self._task is None or self._task.future is not f:
            self._log.info("Reaping stale task")
            return

        assert f.done()

        task, self._task = self._task, None
        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.setDisabled(True)
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)

        try:
            embeddings = f.result()
        except ConnectionError:
            self._log.exception("Error", exc_info=True)
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self._set_server_info(connected=False)
            return
        except Exception as err:
            self._log.exception("Error", exc_info=True)
            self.error(
                "\n".join(traceback.format_exception_only(type(err), err)))
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        assert self._input_data is not None
        assert len(self._input_data) == len(task.file_paths_mask)

        # Missing paths/urls were filtered out. Restore the full embeddings
        # array from information stored in task.file_paths_mask ...
        embeddings_all = [None] * len(task.file_paths_mask)
        for i, embedding in zip(np.flatnonzero(~task.file_paths_mask),
                                embeddings):
            embeddings_all[i] = embedding
        embeddings_all = np.array(embeddings_all)
        self._send_output_signals(embeddings_all)

    def _send_output_signals(self, embeddings):
        """Send both outputs and report skipped images in the info label."""
        embedded_images, skipped_images, num_skipped =\
            ImageEmbedder.prepare_output_data(self._input_data, embeddings)
        self.Outputs.embeddings.send(embedded_images)
        self.Outputs.skipped_images.send(skipped_images)
        # Compare by value; identity checks against int literals are
        # implementation-dependent.
        if num_skipped != 0:
            self.input_data_info.setText(
                "Data with {:d} instances, {:d} images skipped.".format(
                    len(self._input_data), num_skipped))

    def _set_server_info(self, connected):
        """Clear widget messages and update the connection label."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Cancel any running task and close the embedder."""
        self.cancel()
        super().onDeleteWidget()
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)

    def cancel(self):
        """
        Cancel the running task, if any, and restore the idle UI state.

        Waits for the task's future to finish, clears the embedder's
        ``cancelled`` flag and reconnects to the server.
        """
        if self._task is not None:
            task, self._task = self._task, None
            task.cancel()
            # wait until done
            try:
                task.future.exception()
            except qconcurrent.CancelledError:
                pass

            self.auto_commit_widget.setDisabled(False)
            self.cancel_button.setDisabled(True)
            self.progressBarFinished(processEvents=None)
            self.setBlocking(False)
            self.cb_image_attr.setDisabled(False)
            self.cb_embedder.setDisabled(False)
            self._image_embedder.cancelled = False
            # reset the connection.
            connected = self._image_embedder.reconnect_to_server()
            self._set_server_info(connected=connected)
class ImageEmbedderTest(unittest.TestCase):
    """
    Tests for the server based ``ImageEmbedder``.

    The tests run against patched connection objects; see the ``patch``
    decorator of each individual test.
    """

    @staticmethod
    def _make_embedder(model='inception-v3', layer='penultimate'):
        """Create an embedder that uses the test server example.com:80."""
        return ImageEmbedder(
            model=model,
            layer=layer,
            server_url='example.com',
            server_port=80
        )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def setUp(self):
        """Create a fresh embedder with an empty cache for every test."""
        self.embedder = self._make_embedder()
        self.embedder.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_0]

    def tearDown(self):
        """Remove whatever the test left in the cache."""
        self.embedder.clear_cache()

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connected_to_server(self, ConnectionMock):
        ConnectionMock._discover_server.assert_not_called()
        embedder = self.embedder
        self.assertEqual(embedder.is_connected_to_server(), True)
        # server closes the connection
        embedder._server_connection.close()
        self.assertEqual(embedder.is_connected_to_server(), False)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connection_errors(self, ConnectionMock):
        expected = np.array([np.array([0, 1], dtype=np.float16)])
        assert_array_equal(self.embedder(self.single_example), expected)
        self.embedder.clear_cache()

        # with the connection closed and new connections refused the call
        # is expected to raise ConnectionError
        self.embedder._server_connection.close()
        ConnectionMock.side_effect = ConnectionRefusedError
        with self.assertRaises(ConnectionError):
            self.embedder(self.single_example)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_disconnect_reconnect(self):
        embedder = self.embedder
        self.assertEqual(embedder.is_connected_to_server(), True)
        embedder.disconnect_from_server()
        self.assertEqual(embedder.is_connected_to_server(), False)
        embedder.reconnect_to_server()
        self.assertEqual(embedder.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_auto_reconnect(self):
        embedder = self.embedder
        self.assertEqual(embedder.is_connected_to_server(), True)
        embedder.disconnect_from_server()
        self.assertEqual(embedder.is_connected_to_server(), False)
        # embedding is expected to re-establish the connection
        embedder(self.single_example)
        self.assertEqual(embedder.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_with_non_existing_image(self, ConnectionMock):
        self.single_example = ['/non_existing_image']

        result = self.embedder(self.single_example)
        self.assertEqual(result, [None])
        ConnectionMock.request.assert_not_called()
        ConnectionMock.get_response.assert_not_called()
        self.assertEqual(self.embedder._cache_dict, {})

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_on_successful_response(self):
        expected = np.array([np.array([0, 1], dtype=np.float16)])
        assert_array_equal(self.embedder(self.single_example), expected)
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(b''))
    def test_on_non_json_response(self):
        result = self.embedder(self.single_example)
        self.assertEqual(result, [None])
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(json.dumps({'wrong_key': None}).encode()))
    def test_on_json_wrong_key_response(self):
        result = self.embedder(self.single_example)
        self.assertEqual(result, [None])
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_persistent_caching(self):
        self.assertEqual(len(self.embedder._cache_dict), 0)
        self.embedder(self.single_example)
        self.assertEqual(len(self.embedder._cache_dict), 1)

        # a persisted cache is visible to a newly created embedder
        self.embedder.persist_cache()
        self.embedder = self._make_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 1)

        # a cleared cache is empty for a newly created embedder as well
        self.embedder.clear_cache()
        self.embedder = self._make_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_statement(self):
        with self.embedder as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder.is_connected_to_server(), False)
        self.embedder = self._make_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_max_concurrent_streams_setting(self):
        self.assertEqual(self.embedder._max_concurrent_streams, 128)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_too_many_examples_for_one_batch(self):
        count = 200
        too_many_examples = [_EXAMPLE_IMAGE_0] * count
        true_res = np.array(
            [np.array([0, 1], dtype=np.float16) for _ in range(count)])

        res = self.embedder(too_many_examples)
        assert_array_equal(res, true_res)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_successful_result_shape(self):
        more_examples = [_EXAMPLE_IMAGE_0] * 5
        res = self.embedder(more_examples)
        self.assertEqual(res.shape, (5, 2))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_model(self):
        with self.assertRaises(ValueError):
            self.embedder = self._make_embedder(model='invalid_model')

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_layer(self):
        with self.assertRaises(ValueError):
            self.embedder = self._make_embedder(layer='first')
Пример #6
0
class ImageEmbedderTest(unittest.TestCase):
    """
    Tests for ``ImageEmbedder`` with a server model ('inception-v3',
    ``self.embedder_server``) and the local model ('squeezenet',
    ``self.embedder_local``).

    Most tests run against patched connection objects; see the ``patch``
    decorator of each individual test.  Comments are used instead of
    docstrings on the test methods so that the test runner keeps reporting
    the method names.
    """

    @staticmethod
    def _new_embedder(model='inception-v3', layer='penultimate'):
        """Create an embedder for ``model`` that uses server example.com."""
        return ImageEmbedder(
            model=model,
            layer=layer,
            server_url='example.com',
        )

    @staticmethod
    def _cache_len(embedder):
        """Return the number of entries in the cache of ``embedder``."""
        return len(embedder._embedder._cache._cache_dict)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def setUp(self):
        """Silence logging and create both embedders with empty caches."""
        logging.disable(logging.CRITICAL)
        self.embedder_server = self._new_embedder()
        self.embedder_server.clear_cache()
        self.embedder_local = self._new_embedder(model='squeezenet')
        self.embedder_local.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_JPG]

    def tearDown(self):
        """
        Clear the caches of both embedders and re-enable logging.

        The local embedder's cache is cleared too, so that no test leaves
        cached embeddings behind, and logging is restored even when
        clearing a cache fails.
        """
        try:
            self.embedder_server.clear_cache()
            self.embedder_local.clear_cache()
        finally:
            logging.disable(logging.NOTSET)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connected_to_server(self, ConnectionMock):
        ConnectionMock._discover_server.assert_not_called()
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        # server closes the connection
        self.embedder_server._embedder._server_connection.close()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connection_errors(self, ConnectionMock):
        res = self.embedder_server(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.embedder_server.clear_cache()

        self.embedder_server._embedder._server_connection.close()
        # both low level errors are expected to surface as ConnectionError
        ConnectionMock.side_effect = ConnectionRefusedError
        with self.assertRaises(ConnectionError):
            self.embedder_server(self.single_example)

        ConnectionMock.side_effect = BrokenPipeError
        with self.assertRaises(ConnectionError):
            self.embedder_server(self.single_example)

    @patch.object(DummyHttp2Connection, 'get_response')
    def test_on_stream_reset_by_server(self, ConnectionMock):
        ConnectionMock.side_effect = StreamResetError
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(self._cache_len(self.embedder_server), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_disconnect_reconnect(self):
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        self.embedder_server._embedder.disconnect_from_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server.reconnect_to_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_auto_reconnect(self):
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        self.embedder_server._embedder.disconnect_from_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server(self.single_example)
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_with_non_existing_image(self, ConnectionMock):
        self.single_example = ['/non_existing_image']

        self.assertEqual(self.embedder_server(self.single_example), [None])
        ConnectionMock.request.assert_not_called()
        ConnectionMock.get_response.assert_not_called()
        self.assertEqual(self.embedder_server._embedder._cache._cache_dict, {})

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_on_successful_response(self):
        res = self.embedder_server(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(self._cache_len(self.embedder_server), 1)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(b''))
    def test_on_non_json_response(self):
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(self._cache_len(self.embedder_server), 0)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(json.dumps({'wrong_key': None}).encode()))
    def test_on_json_wrong_key_response(self):
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(self._cache_len(self.embedder_server), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_persistent_caching(self):
        self.assertEqual(self._cache_len(self.embedder_server), 0)
        self.embedder_server(self.single_example)
        self.assertEqual(self._cache_len(self.embedder_server), 1)

        # a persisted cache is visible to a newly created embedder
        self.embedder_server._embedder._cache.persist_cache()
        self.embedder_server = self._new_embedder()
        self.assertEqual(self._cache_len(self.embedder_server), 1)

        # a cleared cache is empty for a newly created embedder as well
        self.embedder_server.clear_cache()
        self.embedder_server = self._new_embedder()
        self.assertEqual(self._cache_len(self.embedder_server), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_different_models_caches(self):
        embedder = self._new_embedder(model='painters')
        embedder.clear_cache()
        self.assertEqual(self._cache_len(embedder), 0)
        embedder(self.single_example)
        self.assertEqual(self._cache_len(embedder), 1)
        embedder._embedder._cache.persist_cache()

        # the entry cached for 'painters' must not show up for another model
        self.embedder_server = self._new_embedder()
        self.assertEqual(self._cache_len(self.embedder_server), 0)
        self.embedder_server._embedder._cache.persist_cache()

        embedder = self._new_embedder(model='painters')
        self.assertEqual(self._cache_len(embedder), 1)
        embedder.clear_cache()

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_statement(self):
        # server embedder
        with self.embedder_server as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server = self._new_embedder()
        self.assertEqual(self._cache_len(self.embedder_server), 1)

        # local embedder
        with self.embedder_local as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder_local.is_connected_to_server(), False)
        self.embedder_local = self._new_embedder(model='squeezenet')
        self.assertEqual(self._cache_len(self.embedder_local), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_max_concurrent_streams_setting(self):
        self.assertEqual(
            self.embedder_server._embedder._max_concurrent_streams, 128)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_too_many_examples_for_one_batch(self):
        too_many_examples = [_EXAMPLE_IMAGE_JPG for _ in range(200)]
        true_res = [np.array([0, 1], dtype=np.float16) for _ in range(200)]
        true_res = np.array(true_res)

        res = self.embedder_server(too_many_examples)
        assert_array_equal(res, true_res)
        # no need to test it on local embedder since it does not work
        # in batches

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_successful_result_shape(self):
        # global embedder
        more_examples = [_EXAMPLE_IMAGE_JPG for _ in range(5)]
        res = self.embedder_server(more_examples)
        self.assertEqual(res.shape, (5, 2))

        # local embedder
        more_examples = [_EXAMPLE_IMAGE_JPG for _ in range(5)]
        res = self.embedder_local(more_examples)
        self.assertEqual(res.shape, (5, 1000))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_model(self):
        with self.assertRaises(ValueError):
            self.embedder_server = self._new_embedder(model='invalid_model')

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_layer(self):
        # test server embedder
        with self.assertRaises(ValueError):
            self.embedder_server = self._new_embedder(layer='first')

        # test local embedder
        with self.assertRaises(ValueError):
            self.embedder_server = self._new_embedder(
                model='squeezenet', layer='first')

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_grayscale_image(self):
        # test server embedder
        res = self.embedder_server([_EXAMPLE_IMAGE_GRAYSCALE])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(self._cache_len(self.embedder_server), 1)

        # test local embedder
        self.embedder_local([_EXAMPLE_IMAGE_GRAYSCALE])
        self.assertEqual(self._cache_len(self.embedder_local), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_tiff_image(self):
        # test server embedder
        res = self.embedder_server([_EXAMPLE_IMAGE_TIFF])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(self._cache_len(self.embedder_server), 1)

        # test local embedder
        self.embedder_local([_EXAMPLE_IMAGE_TIFF])
        self.assertEqual(self._cache_len(self.embedder_local), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_server_url_env_var(self):
        url_value = 'url:1234'
        self.assertNotEqual(
            self.embedder_server._embedder._server_url, url_value)

        environ['ORANGE_EMBEDDING_API_URL'] = url_value
        # remove the variable even when an assertion fails, otherwise it
        # leaks into every test that runs afterwards
        try:
            self.embedder_server = self._new_embedder()
            self.assertEqual(
                self.embedder_server._embedder._server_url, url_value)
        finally:
            del environ['ORANGE_EMBEDDING_API_URL']

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_embedding_cancelled(self):
        # test for the server embedders
        self.assertFalse(self.embedder_server._embedder.cancelled)
        self.embedder_server._embedder.cancelled = True
        with self.assertRaises(Exception):
            self.embedder_server(self.single_example)

        # test for the local embedder
        self.assertFalse(self.embedder_local._embedder.cancelled)
        self.embedder_local._embedder.cancelled = True
        with self.assertRaises(Exception):
            self.embedder_local(self.single_example)

    def test_table_online_data(self):
        data = Table("https://datasets.biolab.si/core/bone-healing.xlsx")
        emb, skipped, num_skipped = self.embedder_local(data, col="Image")

        self.assertIsNone(skipped)
        self.assertEqual(0, num_skipped)
        self.assertEqual(len(data), len(emb))
        self.assertTupleEqual((len(data), 1000), emb.X.shape)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_table_server_embedder(self):
        data = Table("https://datasets.biolab.si/core/bone-healing.xlsx")
        emb, skipped, num_skipped = self.embedder_server(data, col="Image")

        self.assertIsNone(skipped)
        self.assertEqual(0, num_skipped)
        self.assertEqual(len(data), len(emb))
        self.assertTupleEqual((len(data), 2), emb.X.shape)

    def test_table_local_data(self):
        str_var = StringVariable("Image")
        # the image paths in the metas are relative to this test's directory
        str_var.attributes["origin"] = path.dirname(
            path.abspath(__file__))
        data = Table(
            Domain([], [], metas=[str_var]),
            np.empty((3, 0)), np.empty((3, 0)),
            metas=[[_EXAMPLE_IMAGE_JPG],
                   [_EXAMPLE_IMAGE_TIFF],
                   [_EXAMPLE_IMAGE_GRAYSCALE]])

        emb, skipped, num_skipped = self.embedder_local(data, col="Image")

        self.assertIsNone(skipped)
        self.assertEqual(0, num_skipped)
        self.assertEqual(len(data), len(emb))
        self.assertTupleEqual((len(data), 1000), emb.X.shape)

    def test_table_skip(self):
        data = Table("https://datasets.biolab.si/core/bone-healing.xlsx")
        # overwrite one meta value with a string that is not an image path
        data.metas[0, 1] = "tralala"
        emb, skipped, num_skipped = self.embedder_local(data, col="Image")

        self.assertIsNotNone(skipped)
        self.assertEqual(1, num_skipped)
        self.assertEqual(len(data) - 1, len(emb))
        self.assertTupleEqual((len(data) - 1, 1000), emb.X.shape)
Пример #7
0
class OWImageEmbedding(OWWidget):
    """
    Widget that embeds images with the 'inception-v3' model.

    Image attributes are the metas of the input table whose
    ``attributes['type']`` equals ``'image'``; the values of the selected
    attribute are passed to ``ImageEmbedder`` as file paths.  Rows that got
    an embedding are sent on the embeddings output with the embedding
    appended as continuous attributes ``n0``, ``n1``, ...; rows for which
    the embedder returned ``None`` are sent on the skipped images output.
    """
    # todo: implement embedding in a non-blocking manner
    # todo: implement stop running task action
    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    inputs = [(_Input.IMAGES, Table, 'set_data')]
    outputs = [(_Output.EMBEDDINGS, Table, Default),
               (_Output.SKIPPED_IMAGES, Table)]

    cb_image_attr_current_id = Setting(default=0)
    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        self._image_attributes = None
        self._input_data = None

        self._setup_layout()
        # The embedder is created by a zero-delay timer, i.e. once control
        # returns to the event loop; until then the attribute is None.
        self._image_embedder = None
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Build the info box, the image attribute combo and the Apply box."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(widget=widget_box,
                                      master=self,
                                      value='cb_image_attr_current_id',
                                      label='Image attribute:',
                                      orientation=Qt.Horizontal,
                                      callback=self._cb_image_attr_changed)

        self.auto_commit_widget = auto_commit(widget=self.controlArea,
                                              master=self,
                                              value='_auto_apply',
                                              label='Apply',
                                              checkbox_label='Auto Apply',
                                              commit=self.commit)

    def _init_server_connection(self):
        """
        Create the embedder if it does not exist yet and show whether it is
        connected to the server.
        """
        # commit() may already have created the embedder when data arrived
        # before the timer fired; do not replace (and leak) that instance
        if self._image_embedder is None:
            self._image_embedder = ImageEmbedder(
                model='inception-v3',
                layer='penultimate',
            )
        self._set_server_info(self._image_embedder.is_connected_to_server())

    def _clear_outputs(self):
        """Send ``None`` on both outputs."""
        self.send(_Output.EMBEDDINGS, None)
        self.send(_Output.SKIPPED_IMAGES, None)

    def set_data(self, data):
        """
        Input signal handler.

        ``None`` and tables without image attributes clear the stored
        input and both outputs.  Otherwise the table is stored, the combo
        box is filled with the image attributes and, with auto apply on,
        the embedding is computed.
        """
        if data is None:
            # forget the previous table as well, otherwise a later Apply
            # would embed data that is no longer on the input
            self._input_data = None
            self._image_attributes = None
            self._clear_outputs()
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = self._filter_image_attributes(data)
        if not self._image_attributes:
            self._input_data = None
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            # nothing can be embedded; do not keep results of the previous
            # input on the outputs
            self._clear_outputs()
            return

        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        input_data_info_text = "Data with {:d} instances.".format(len(data))
        self.input_data_info.setText(input_data_info_text)

        self._cb_image_attr_changed()

    @staticmethod
    def _filter_image_attributes(data):
        """Return the metas of ``data`` whose 'type' attribute is 'image'."""
        metas = data.domain.metas
        return [m for m in metas if m.attributes.get('type') == 'image']

    def _cb_image_attr_changed(self):
        """Recompute the embeddings when auto apply is on."""
        if self._auto_apply:
            self.commit()

    def commit(self):
        """
        Embed the images of the selected attribute and send the outputs.

        The Apply control is disabled while the embedder runs.  On
        ``ConnectionError`` both outputs are cleared and the info box
        reports the missing connection.
        """
        if not self._image_attributes or not self._input_data:
            self._clear_outputs()
            return

        # data can arrive before the timer started in __init__ has fired
        if self._image_embedder is None:
            self._init_server_connection()

        self._set_server_info(connected=True)
        self.auto_commit_widget.setDisabled(True)
        # the finally clause re-enables Apply on every exit path, including
        # exceptions other than ConnectionError
        try:
            file_paths_attr = self._image_attributes[
                self.cb_image_attr_current_id]
            file_paths = self._input_data[:, file_paths_attr].metas.flatten()

            with self.progressBar(len(file_paths)) as progress:
                try:
                    embeddings = self._image_embedder(
                        file_paths=file_paths,
                        image_processed_callback=lambda: progress.advance())
                except ConnectionError:
                    self._clear_outputs()
                    self._set_server_info(connected=False)
                    return

            self._send_output_signals(embeddings)
        finally:
            self.auto_commit_widget.setDisabled(False)

    def _send_output_signals(self, embeddings):
        """
        Split the input rows by whether their embedding is ``None`` and
        send the two resulting tables (or ``None`` when a part is empty).
        """
        skipped_images_bool = np.array([x is None for x in embeddings])

        if np.any(skipped_images_bool):
            skipped_images = self._input_data[skipped_images_bool]
            skipped_images = Table(skipped_images)
            skipped_images.ids = self._input_data.ids[skipped_images_bool]
            self.send(_Output.SKIPPED_IMAGES, skipped_images)
        else:
            self.send(_Output.SKIPPED_IMAGES, None)

        embedded_images_bool = np.logical_not(skipped_images_bool)

        if np.any(embedded_images_bool):
            embedded_images = self._input_data[embedded_images_bool]

            embeddings = embeddings[embedded_images_bool]
            embeddings = np.stack(embeddings)

            embedded_images = self._construct_output_data_table(
                embedded_images, embeddings)
            embedded_images.ids = self._input_data.ids[embedded_images_bool]
            self.send(_Output.EMBEDDINGS, embedded_images)
        else:
            self.send(_Output.EMBEDDINGS, None)

    @staticmethod
    def _construct_output_data_table(embedded_images, embeddings):
        """
        Return a table with the columns of ``embeddings`` appended to the
        attributes of ``embedded_images`` as continuous variables named
        ``n0``, ``n1``, ...; class variables and metas are kept.
        """
        X = np.hstack((embedded_images.X, embeddings))
        Y = embedded_images.Y

        dimensions = range(embeddings.shape[1])
        attributes = [
            ContinuousVariable('n{:d}'.format(d)) for d in dimensions
        ]
        attributes = list(embedded_images.domain.attributes) + attributes

        domain = Domain(attributes=attributes,
                        class_vars=embedded_images.domain.class_vars,
                        metas=embedded_images.domain.metas)

        return Table(domain, X, Y, embedded_images.metas)

    def _set_server_info(self, connected):
        """Clear the widget messages and show the connection state."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Close the embedder, if it was created, when the widget goes away."""
        super().onDeleteWidget()
        # the widget can be deleted before the embedder has been created
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)
class OWImageEmbedding(OWWidget):
    """
    Widget that embeds images with the 'inception-v3' model.

    Image attributes are the metas of the input table whose
    ``attributes['type']`` equals ``'image'``; the values of the selected
    attribute are passed to ``ImageEmbedder`` as file paths.  Rows that got
    an embedding are sent on the embeddings output with the embedding
    appended as continuous attributes ``n0``, ``n1``, ...; rows for which
    the embedder returned ``None`` are sent on the skipped images output.
    """
    # todo: implement embedding in a non-blocking manner
    # todo: implement stop running task action
    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    inputs = [(_Input.IMAGES, Table, 'set_data')]
    outputs = [
        (_Output.EMBEDDINGS, Table, Default),
        (_Output.SKIPPED_IMAGES, Table)
    ]

    cb_image_attr_current_id = Setting(default=0)
    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        self._image_attributes = None
        self._input_data = None

        self._setup_layout()

        self._image_embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    def _setup_layout(self):
        """Build the info box, the image attribute combo and the Apply box."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            checkbox_label='Auto Apply',
            commit=self.commit
        )

    def _clear_outputs(self):
        """Send ``None`` on both outputs."""
        self.send(_Output.EMBEDDINGS, None)
        self.send(_Output.SKIPPED_IMAGES, None)

    def set_data(self, data):
        """
        Input signal handler.

        ``None`` and tables without image attributes clear the stored
        input and both outputs.  Otherwise the table is stored, the combo
        box is filled with the image attributes and, with auto apply on,
        the embedding is computed.
        """
        if data is None:
            # forget the previous table as well, otherwise a later Apply
            # would embed data that is no longer on the input
            self._input_data = None
            self._image_attributes = None
            self._clear_outputs()
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = self._filter_image_attributes(data)
        if not self._image_attributes:
            self._input_data = None
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            # nothing can be embedded; do not keep results of the previous
            # input on the outputs
            self._clear_outputs()
            return

        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        input_data_info_text = "Data with {:d} instances.".format(len(data))
        self.input_data_info.setText(input_data_info_text)

        self._cb_image_attr_changed()

    @staticmethod
    def _filter_image_attributes(data):
        """Return the metas of ``data`` whose 'type' attribute is 'image'."""
        metas = data.domain.metas
        return [m for m in metas if m.attributes.get('type') == 'image']

    def _cb_image_attr_changed(self):
        """Recompute the embeddings when auto apply is on."""
        if self._auto_apply:
            self.commit()

    def commit(self):
        """
        Embed the images of the selected attribute and send the outputs.

        The Apply control is disabled while the embedder runs.  On
        ``ConnectionError`` both outputs are cleared and the info box
        reports the missing connection.
        """
        if not self._image_attributes or not self._input_data:
            self._clear_outputs()
            return

        self._set_server_info(connected=True)
        self.auto_commit_widget.setDisabled(True)
        # the finally clause re-enables Apply on every exit path, including
        # exceptions other than ConnectionError
        try:
            file_paths_attr = self._image_attributes[
                self.cb_image_attr_current_id]
            file_paths = self._input_data[:, file_paths_attr].metas.flatten()

            with self.progressBar(len(file_paths)) as progress:
                try:
                    embeddings = self._image_embedder(
                        file_paths=file_paths,
                        image_processed_callback=lambda: progress.advance()
                    )
                except ConnectionError:
                    self._clear_outputs()
                    self._set_server_info(connected=False)
                    return

            self._send_output_signals(embeddings)
        finally:
            self.auto_commit_widget.setDisabled(False)

    def _send_output_signals(self, embeddings):
        """
        Split the input rows by whether their embedding is ``None`` and
        send the two resulting tables (or ``None`` when a part is empty).
        """
        skipped_images_bool = np.array([x is None for x in embeddings])

        if np.any(skipped_images_bool):
            skipped_images = self._input_data[skipped_images_bool]
            skipped_images = Table(skipped_images)
            skipped_images.ids = self._input_data.ids[skipped_images_bool]
            self.send(_Output.SKIPPED_IMAGES, skipped_images)
        else:
            self.send(_Output.SKIPPED_IMAGES, None)

        embedded_images_bool = np.logical_not(skipped_images_bool)

        if np.any(embedded_images_bool):
            embedded_images = self._input_data[embedded_images_bool]

            embeddings = embeddings[embedded_images_bool]
            embeddings = np.stack(embeddings)

            embedded_images = self._construct_output_data_table(
                embedded_images,
                embeddings
            )
            embedded_images.ids = self._input_data.ids[embedded_images_bool]
            self.send(_Output.EMBEDDINGS, embedded_images)
        else:
            self.send(_Output.EMBEDDINGS, None)

    @staticmethod
    def _construct_output_data_table(embedded_images, embeddings):
        """
        Return a table with the columns of ``embeddings`` appended to the
        attributes of ``embedded_images`` as continuous variables named
        ``n0``, ``n1``, ...; class variables and metas are kept.
        """
        X = np.hstack((embedded_images.X, embeddings))
        Y = embedded_images.Y

        dimensions = range(embeddings.shape[1])
        attributes = [ContinuousVariable('n{:d}'.format(d)) for d in dimensions]
        attributes = list(embedded_images.domain.attributes) + attributes

        domain = Domain(
            attributes=attributes,
            class_vars=embedded_images.domain.class_vars,
            metas=embedded_images.domain.metas
        )

        return Table(domain, X, Y, embedded_images.metas)

    def _set_server_info(self, connected):
        """Clear the widget messages and show the connection state."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Close the embedder when the widget is deleted."""
        super().onDeleteWidget()
        self._image_embedder.__exit__(None, None, None)
Пример #9
0
class ImageEmbedderTest(unittest.TestCase):
    """Exercise ImageEmbedder with the HTTP/2 connection class patched out."""

    @staticmethod
    def _new_embedder(model='inception-v3', layer='penultimate'):
        """Return an ImageEmbedder configured for ``example.com`` port 80."""
        return ImageEmbedder(model=model,
                             layer=layer,
                             server_url='example.com',
                             server_port=80)

    @staticmethod
    def _expected_embeddings(count=1):
        """Return ``count`` stacked copies of the float16 vector ``[0, 1]``."""
        rows = [np.array([0, 1], dtype=np.float16) for _ in range(count)]
        return np.array(rows)

    def _cache_size(self):
        """Return the number of entries in the current embedder's cache."""
        return len(self.embedder._cache_dict)

    def _assert_connected(self, expected):
        """Assert that the embedder reports ``expected`` as its server state."""
        self.assertEqual(self.embedder.is_connected_to_server(), expected)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def setUp(self):
        # silence logging while the tests run; tearDown re-enables it
        logging.disable(logging.CRITICAL)
        self.embedder = self._new_embedder()
        self.embedder.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_JPG]

    def tearDown(self):
        self.embedder.clear_cache()
        logging.disable(logging.NOTSET)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connected_to_server(self, ConnectionMock):
        ConnectionMock._discover_server.assert_not_called()
        self._assert_connected(True)
        # server closes the connection
        self.embedder._server_connection.close()
        self._assert_connected(False)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connection_errors(self, ConnectionMock):
        embedded = self.embedder(self.single_example)
        assert_array_equal(embedded, self._expected_embeddings())
        self.embedder.clear_cache()

        # with the connection closed and reconnecting refused, embedding
        # must surface a ConnectionError
        self.embedder._server_connection.close()
        ConnectionMock.side_effect = ConnectionRefusedError
        with self.assertRaises(ConnectionError):
            self.embedder(self.single_example)

    @patch.object(DummyHttp2Connection, 'get_response')
    def test_on_stream_reset_by_server(self, ConnectionMock):
        ConnectionMock.side_effect = StreamResetError
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(self._cache_size(), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_disconnect_reconnect(self):
        self._assert_connected(True)
        self.embedder.disconnect_from_server()
        self._assert_connected(False)
        self.embedder.reconnect_to_server()
        self._assert_connected(True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_auto_reconnect(self):
        self._assert_connected(True)
        self.embedder.disconnect_from_server()
        self._assert_connected(False)
        # calling the embedder is expected to bring the connection back
        self.embedder(self.single_example)
        self._assert_connected(True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_with_non_existing_image(self, ConnectionMock):
        self.single_example = ['/non_existing_image']

        self.assertEqual(self.embedder(self.single_example), [None])
        ConnectionMock.request.assert_not_called()
        ConnectionMock.get_response.assert_not_called()
        self.assertEqual(self.embedder._cache_dict, {})

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_on_successful_response(self):
        embedded = self.embedder(self.single_example)
        assert_array_equal(embedded, self._expected_embeddings())
        self.assertEqual(self._cache_size(), 1)

    @patch.object(DummyHttp2Connection, 'get_response',
                  lambda self, _: BytesIO(b''))
    def test_on_non_json_response(self):
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(self._cache_size(), 0)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(json.dumps({
            'wrong_key': None
        }).encode()))
    def test_on_json_wrong_key_response(self):
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(self._cache_size(), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_persistent_caching(self):
        self.assertEqual(self._cache_size(), 0)
        self.embedder(self.single_example)
        self.assertEqual(self._cache_size(), 1)

        # a persisted cache must be visible to a freshly built embedder
        self.embedder.persist_cache()
        self.embedder = self._new_embedder()
        self.assertEqual(self._cache_size(), 1)

        # and a cleared cache must stay empty for the next one
        self.embedder.clear_cache()
        self.embedder = self._new_embedder()
        self.assertEqual(self._cache_size(), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_statement(self):
        with self.embedder as active_embedder:
            active_embedder(self.single_example)

        self._assert_connected(False)
        self.embedder = self._new_embedder()
        self.assertEqual(self._cache_size(), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_max_concurrent_streams_setting(self):
        self.assertEqual(self.embedder._max_concurrent_streams, 128)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_too_many_examples_for_one_batch(self):
        example_count = 200
        too_many_examples = [_EXAMPLE_IMAGE_JPG] * example_count

        embedded = self.embedder(too_many_examples)
        assert_array_equal(embedded, self._expected_embeddings(example_count))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_successful_result_shape(self):
        more_examples = [_EXAMPLE_IMAGE_JPG] * 5
        embedded = self.embedder(more_examples)
        self.assertEqual(embedded.shape, (5, 2))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_model(self):
        with self.assertRaises(ValueError):
            self.embedder = self._new_embedder(model='invalid_model')

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_layer(self):
        with self.assertRaises(ValueError):
            self.embedder = self._new_embedder(layer='first')

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_grayscale_image(self):
        embedded = self.embedder([_EXAMPLE_IMAGE_GRAYSCALE])
        assert_array_equal(embedded, self._expected_embeddings())
        self.assertEqual(self._cache_size(), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_tiff_image(self):
        embedded = self.embedder([_EXAMPLE_IMAGE_TIFF])
        assert_array_equal(embedded, self._expected_embeddings())
        self.assertEqual(self._cache_size(), 1)
class OWImageEmbedding(OWWidget):
    """Widget that embeds images referenced by a meta attribute of the input.

    The embedding itself runs on a single-threaded executor; progress and
    results are delivered back to the widget's thread through queued slot
    invocations. Rows that could not be embedded are sent on a separate
    output.
    """
    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    inputs = [(_Input.IMAGES, Table, 'set_data')]
    outputs = [(_Output.EMBEDDINGS, Table, Default),
               (_Output.SKIPPED_IMAGES, Table)]

    cb_image_attr_current_id = Setting(default=0)
    cb_embedder_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        self.embedders = sorted(EMBEDDERS_INFO)
        self._image_attributes = None
        self._input_data = None
        self._log = logging.getLogger(__name__)
        self._task = None
        self._setup_layout()
        # created later by _init_server_connection; None until then
        self._image_embedder = None
        self._executor = qconcurrent.ThreadExecutor(
            self, threadPool=QThreadPool(maxThreadCount=1))
        # stay blocked until the deferred connection setup has run
        self.setBlocking(True)
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Build the info box, the settings box and the apply/cancel row."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(widget=widget_box,
                                      master=self,
                                      value='cb_image_attr_current_id',
                                      label='Image attribute:',
                                      orientation=Qt.Horizontal,
                                      callback=self._cb_image_attr_changed)

        self.cb_embedder = comboBox(widget=widget_box,
                                    master=self,
                                    value='cb_embedder_current_id',
                                    label='Embedder:',
                                    orientation=Qt.Horizontal,
                                    callback=self._cb_embedder_changed)
        self.cb_embedder.setModel(
            VariableListModel(
                [EMBEDDERS_INFO[e]['name'] for e in self.embedders]))
        # a stored index may be out of range for the current embedder list
        if not self.cb_embedder_current_id < len(self.embedders):
            self.cb_embedder_current_id = 0
        self.cb_embedder.setCurrentIndex(self.cb_embedder_current_id)

        current_embedder = self.embedders[self.cb_embedder_current_id]
        self.embedder_info = widgetLabel(
            widget_box, EMBEDDERS_INFO[current_embedder]['description'])

        self.auto_commit_widget = auto_commit(widget=self.controlArea,
                                              master=self,
                                              value='_auto_apply',
                                              label='Apply',
                                              commit=self.commit)

        self.cancel_button = QPushButton(
            'Cancel',
            icon=self.style().standardIcon(QStyle.SP_DialogCancelButton),
        )
        self.cancel_button.clicked.connect(self.cancel)
        hbox = hBox(self.controlArea)
        hbox.layout().addWidget(self.cancel_button)
        # only shown while an embedding task is running
        self.cancel_button.hide()

    def _init_server_connection(self):
        """Create the embedder for the selected model and report its state."""
        self.setBlocking(False)
        self._image_embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate')
        self._set_server_info(self._image_embedder.is_connected_to_server())

    def _clear_outputs(self):
        """Send ``None`` on both outputs."""
        self.send(_Output.EMBEDDINGS, None)
        self.send(_Output.SKIPPED_IMAGES, None)

    def _close_embedder(self):
        """Exit the current embedder's context, if an embedder exists."""
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)

    def _reset_task_controls(self):
        """Return the controls to their idle state after a task has ended."""
        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.hide()
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)

    def set_data(self, data):
        """Accept a new input table and, if it is usable, start embedding.

        Falsy input, or input without any image meta attribute, clears both
        outputs and only updates the info label.
        """
        if not data:
            self._input_data = None
            self._clear_outputs()
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = self._filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            # nothing can be embedded: do not keep results of a previous
            # input on the outputs
            self._clear_outputs()
            return

        # a stored index may be out of range for this table's attributes
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        input_data_info_text = "Data with {:d} instances.".format(len(data))
        self.input_data_info.setText(input_data_info_text)

        self._cb_image_attr_changed()

    @staticmethod
    def _filter_image_attributes(data):
        """Return the meta variables of ``data`` tagged with type 'image'."""
        metas = data.domain.metas
        return [m for m in metas if m.attributes.get('type') == 'image']

    def _cb_image_attr_changed(self):
        self.commit()

    def _cb_embedder_changed(self):
        """Switch to the embedder selected in the combo box and recompute."""
        current_embedder = self.embedders[self.cb_embedder_current_id]
        # exit the embedder being replaced before creating its successor,
        # so it is not simply dropped with its context still open
        self._close_embedder()
        self._image_embedder = ImageEmbedder(model=current_embedder,
                                             layer='penultimate')
        self.embedder_info.setText(
            EMBEDDERS_INFO[current_embedder]['description'])
        self.commit()

    def commit(self):
        """Start a background embedding task for the current input.

        A running task is cancelled first. Without an embedder the widget
        reports a missing connection; without usable input both outputs are
        cleared. Otherwise the paths from the selected image attribute
        (unknown values excluded) are submitted to the executor and the
        result is handled in ``__set_results``.
        """
        if self._task is not None:
            self.cancel()

        if self._image_embedder is None:
            self._set_server_info(connected=False)
            return

        if not self._image_attributes or self._input_data is None:
            self._clear_outputs()
            return

        self._set_server_info(connected=True)
        self.cancel_button.show()
        self.cb_image_attr.setDisabled(True)
        self.cb_embedder.setDisabled(True)

        file_paths_attr = self._image_attributes[self.cb_image_attr_current_id]
        file_paths = self._input_data[:, file_paths_attr].metas.flatten()
        assert file_paths_attr.is_string
        assert file_paths.dtype == np.dtype('O')

        # True where the path is missing; those rows are never submitted
        file_paths_mask = file_paths == file_paths_attr.Unknown
        file_paths_valid = file_paths[~file_paths_mask]

        # one progress value per submitted image, evenly spread over 0..100
        ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))
        set_progress = qconcurrent.methodinvoke(self, "__progress_set",
                                                (float, ))

        def advance():
            set_progress(next(ticks))

        def cancel():
            # ``task`` is bound below, before this closure can be called
            task.future.cancel()
            task.cancelled = True
            task.embedder.cancelled = True

        embedder = self._image_embedder

        def run_embedding(paths):
            return embedder(file_paths=paths, image_processed_callback=advance)

        self.auto_commit_widget.setDisabled(True)
        self.progressBarInit(processEvents=None)
        self.progressBarSet(0.0, processEvents=None)
        self.setBlocking(True)

        f = self._executor.submit(run_embedding, file_paths_valid)
        f.add_done_callback(
            qconcurrent.methodinvoke(self, "__set_results", (object, )))

        task = self._task = namespace(
            file_paths_mask=file_paths_mask,
            file_paths_valid=file_paths_valid,
            file_paths=file_paths,
            embedder=embedder,
            cancelled=False,
            cancel=cancel,
            future=f,
        )
        self._log.debug("Starting embedding task for %i images",
                        file_paths.size)

    @Slot(float)
    def __progress_set(self, value):
        assert self.thread() is QThread.currentThread()
        # ignore progress reported after the task was cancelled or finished
        if self._task is not None:
            self.progressBarSet(value)

    @Slot(object)
    def __set_results(self, f):
        """Handle the finished future ``f`` of the current embedding task."""
        assert self.thread() is QThread.currentThread()
        if self._task is None or self._task.future is not f:
            self._log.info("Reaping stale task")
            return

        assert f.done()

        task, self._task = self._task, None
        self._reset_task_controls()

        try:
            embeddings = f.result()
        except ConnectionError:
            self._log.exception("Error", exc_info=True)
            self._clear_outputs()
            self._set_server_info(connected=False)
            return
        except Exception as err:
            self._log.exception("Error", exc_info=True)
            self.error("\n".join(
                traceback.format_exception_only(type(err), err)))
            self._clear_outputs()
            return

        assert self._input_data is not None
        assert len(self._input_data) == len(task.file_paths_mask)

        # Missing paths/urls were filtered out. Restore the full embeddings
        # array from information stored in task.file_paths_mask. An explicit
        # object array is used because converting a list that mixes None
        # with vectors through np.array() is rejected by current NumPy.
        embeddings_all = np.full(len(task.file_paths_mask), None, dtype=object)
        for i, embedding in zip(np.flatnonzero(~task.file_paths_mask),
                                embeddings):
            embeddings_all[i] = embedding
        self._send_output_signals(embeddings_all)

    def _send_output_signals(self, embeddings):
        """Split the input by embedding success and send both outputs.

        ``embeddings`` holds one entry per input row; ``None`` marks a row
        that was skipped. Row ids of the input are preserved on both tables.
        """
        skipped_images_bool = np.array([x is None for x in embeddings])

        if np.any(skipped_images_bool):
            skipped_images = self._input_data[skipped_images_bool]
            skipped_images = Table(skipped_images)
            skipped_images.ids = self._input_data.ids[skipped_images_bool]
            self.send(_Output.SKIPPED_IMAGES, skipped_images)
        else:
            self.send(_Output.SKIPPED_IMAGES, None)

        embedded_images_bool = np.logical_not(skipped_images_bool)

        if np.any(embedded_images_bool):
            embedded_images = self._input_data[embedded_images_bool]

            embeddings = embeddings[embedded_images_bool]
            embeddings = np.stack(embeddings)

            embedded_images = self._construct_output_data_table(
                embedded_images, embeddings)
            embedded_images.ids = self._input_data.ids[embedded_images_bool]
            self.send(_Output.EMBEDDINGS, embedded_images)
        else:
            self.send(_Output.EMBEDDINGS, None)

    @staticmethod
    def _construct_output_data_table(embedded_images, embeddings):
        """Return a table with the embedding columns appended as attributes.

        One continuous variable named ``n<index>`` is made per embedding
        column; class variables and metas are carried over unchanged.
        """
        X = np.hstack((embedded_images.X, embeddings))
        Y = embedded_images.Y

        attributes = [
            ContinuousVariable.make('n{:d}'.format(d))
            for d in range(embeddings.shape[1])
        ]
        attributes = list(embedded_images.domain.attributes) + attributes

        domain = Domain(attributes=attributes,
                        class_vars=embedded_images.domain.class_vars,
                        metas=embedded_images.domain.metas)

        return Table(domain, X, Y, embedded_images.metas)

    def _set_server_info(self, connected):
        """Clear widget messages and show the connection state."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Cancel any running task and exit the embedder's context."""
        self.cancel()
        super().onDeleteWidget()
        # the embedder is created by a deferred call and may not exist yet
        self._close_embedder()

    def cancel(self):
        """Cancel the running task, wait for it to end and reset the UI."""
        if self._task is not None:
            task, self._task = self._task, None
            task.cancel()
            # wait until done
            try:
                task.future.exception()
            except qconcurrent.CancelledError:
                pass

            self._reset_task_controls()
            self._image_embedder.cancelled = False
            # reset the connection.
            connected = self._image_embedder.reconnect_to_server()
            self._set_server_info(connected=connected)
class ImageEmbedderTest(unittest.TestCase):
    """Tests for ImageEmbedder with the HTTP/2 connection class patched out.

    Two embedders are built per test: ``embedder_server`` (model
    'inception-v3') and ``embedder_local`` (model 'squeezenet').
    """

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def setUp(self):
        # silence logging while the tests run; tearDown re-enables it
        logging.disable(logging.CRITICAL)
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.embedder_server.clear_cache()
        self.embedder_local = ImageEmbedder(
            model='squeezenet',
            layer='penultimate',
            server_url='example.com',
        )
        self.embedder_local.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_JPG]

    def tearDown(self):
        # clear both caches so nothing cached by a test outlives it
        self.embedder_server.clear_cache()
        self.embedder_local.clear_cache()
        logging.disable(logging.NOTSET)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connected_to_server(self, ConnectionMock):
        ConnectionMock._discover_server.assert_not_called()
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        # server closes the connection
        self.embedder_server._embedder._server_connection.close()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connection_errors(self, ConnectionMock):
        res = self.embedder_server(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.embedder_server.clear_cache()

        self.embedder_server._embedder._server_connection.close()
        ConnectionMock.side_effect = ConnectionRefusedError
        with self.assertRaises(ConnectionError):
            self.embedder_server(self.single_example)

        ConnectionMock.side_effect = BrokenPipeError
        with self.assertRaises(ConnectionError):
            self.embedder_server(self.single_example)

    @patch.object(DummyHttp2Connection, 'get_response')
    def test_on_stream_reset_by_server(self, ConnectionMock):
        ConnectionMock.side_effect = StreamResetError
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_disconnect_reconnect(self):
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        self.embedder_server._embedder.disconnect_from_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server.reconnect_to_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_auto_reconnect(self):
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        self.embedder_server._embedder.disconnect_from_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        # calling the embedder is expected to bring the connection back
        self.embedder_server(self.single_example)
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_with_non_existing_image(self, ConnectionMock):
        self.single_example = ['/non_existing_image']

        self.assertEqual(self.embedder_server(self.single_example), [None])
        ConnectionMock.request.assert_not_called()
        ConnectionMock.get_response.assert_not_called()
        self.assertEqual(self.embedder_server._embedder._cache._cache_dict, {})

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_on_successful_response(self):
        res = self.embedder_server(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(b''))
    def test_on_non_json_response(self):
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(json.dumps({'wrong_key': None}).encode()))
    def test_on_json_wrong_key_response(self):
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_persistent_caching(self):
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 0)
        self.embedder_server(self.single_example)
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

        # a persisted cache must be visible to a freshly built embedder
        self.embedder_server._embedder._cache.persist_cache()
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

        # and a cleared cache must stay empty for the next one
        self.embedder_server.clear_cache()
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_different_models_caches(self):
        embedder = ImageEmbedder(
            model='painters',
            layer='penultimate',
            server_url='example.com',
        )
        embedder.clear_cache()
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
        embedder(self.single_example)
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder._embedder._cache.persist_cache()

        # an entry cached for 'painters' must not show up for another model
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 0)
        self.embedder_server._embedder._cache.persist_cache()

        embedder = ImageEmbedder(
            model='painters',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder.clear_cache()

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_statement(self):
        # server embedder
        with self.embedder_server as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

        # local embedder
        with self.embedder_local as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder_local.is_connected_to_server(), False)
        self.embedder_local = ImageEmbedder(
            model='squeezenet',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(
            len(self.embedder_local._embedder._cache._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_max_concurrent_streams_setting(self):
        self.assertEqual(
            self.embedder_server._embedder._max_concurrent_streams, 128)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_too_many_examples_for_one_batch(self):
        too_many_examples = [_EXAMPLE_IMAGE_JPG for _ in range(200)]
        true_res = [np.array([0, 1], dtype=np.float16) for _ in range(200)]
        true_res = np.array(true_res)

        res = self.embedder_server(too_many_examples)
        assert_array_equal(res, true_res)
        # no need to test it on local embedder since it does not work
        # in batches

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_successful_result_shape(self):
        # global embedder
        more_examples = [_EXAMPLE_IMAGE_JPG for _ in range(5)]
        res = self.embedder_server(more_examples)
        self.assertEqual(res.shape, (5, 2))

        # local embedder
        more_examples = [_EXAMPLE_IMAGE_JPG for _ in range(5)]
        res = self.embedder_local(more_examples)
        self.assertEqual(res.shape, (5, 1000))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_model(self):
        with self.assertRaises(ValueError):
            ImageEmbedder(
                model='invalid_model',
                layer='penultimate',
                server_url='example.com',
            )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_layer(self):
        # test server embedder
        with self.assertRaises(ValueError):
            ImageEmbedder(
                model='inception-v3',
                layer='first',
                server_url='example.com',
            )

        # test local embedder
        with self.assertRaises(ValueError):
            ImageEmbedder(
                model='squeezenet',
                layer='first',
                server_url='example.com',
            )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_grayscale_image(self):
        # test server embedder
        res = self.embedder_server([_EXAMPLE_IMAGE_GRAYSCALE])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

        # test local embedder; only the cache entry is checked
        self.embedder_local([_EXAMPLE_IMAGE_GRAYSCALE])
        self.assertEqual(
            len(self.embedder_local._embedder._cache._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_tiff_image(self):
        # test server embedder
        res = self.embedder_server([_EXAMPLE_IMAGE_TIFF])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

        # test local embedder; only the cache entry is checked
        self.embedder_local([_EXAMPLE_IMAGE_TIFF])
        self.assertEqual(
            len(self.embedder_local._embedder._cache._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_server_url_env_var(self):
        url_value = 'url:1234'
        self.assertNotEqual(
            self.embedder_server._embedder._server_url, url_value)

        # patch.dict restores the environment even when the constructor or
        # an assertion fails, so the variable cannot leak into other tests
        with patch.dict(environ, {'ORANGE_EMBEDDING_API_URL': url_value}):
            self.embedder_server = ImageEmbedder(
                model='inception-v3',
                layer='penultimate',
                server_url='example.com',
            )
            self.assertEqual(
                self.embedder_server._embedder._server_url, url_value)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_embedding_cancelled(self):
        # test for the server embedders
        self.assertFalse(self.embedder_server._embedder.cancelled)
        self.embedder_server._embedder.cancelled = True
        with self.assertRaises(Exception):
            self.embedder_server(self.single_example)

        # test for the local embedder
        self.assertFalse(self.embedder_local._embedder.cancelled)
        self.embedder_local._embedder.cancelled = True
        with self.assertRaises(Exception):
            self.embedder_local(self.single_example)