def test_with_statement(self):
    """Using an embedder as a context manager disconnects it on exit.

    Checked for the server (inception-v3) and the local (squeezenet)
    embedder: after the ``with`` block the embedder reports no server
    connection, and a freshly constructed embedder for the same model sees
    exactly one cached entry.
    """
    # --- server embedder ---
    with self.embedder_server as server:
        server(self.single_example)

    self.assertEqual(self.embedder_server.is_connected_to_server(), False)
    self.embedder_server = ImageEmbedder(
        model='inception-v3', layer='penultimate', server_url='example.com')
    server_cache = self.embedder_server._embedder._cache._cache_dict
    self.assertEqual(len(server_cache), 1)

    # --- local embedder ---
    with self.embedder_local as local:
        local(self.single_example)

    self.assertEqual(self.embedder_local.is_connected_to_server(), False)
    self.embedder_local = ImageEmbedder(
        model='squeezenet', layer='penultimate', server_url='example.com')
    local_cache = self.embedder_local._embedder._cache._cache_dict
    self.assertEqual(len(local_cache), 1)
Esempio n. 2
0
 def test_invalid_layer(self):
     """An unknown layer name must be rejected with ValueError."""
     unknown_layer = 'first'
     with self.assertRaises(ValueError):
         self.embedder = ImageEmbedder(model='inception-v3',
                                       layer=unknown_layer,
                                       server_url='example.com')
Esempio n. 3
0
 def setUp(self):
     """Silence logging and create a server and a local embedder with empty caches."""
     logging.disable(logging.CRITICAL)
     for attr_name, model_name in (('embedder_server', 'inception-v3'),
                                   ('embedder_local', 'squeezenet')):
         setattr(self, attr_name, ImageEmbedder(model=model_name))
         getattr(self, attr_name).clear_cache()
     self.single_example = [_EXAMPLE_IMAGE_JPG]
Esempio n. 4
0
    def test_with_statement(self):
        """Context-manager use disconnects the embedder and keeps its cache entry.

        Runs the same scenario for the server (inception-v3) embedder and then
        the local (squeezenet) one.
        """
        scenarios = (('embedder_server', 'inception-v3'),
                     ('embedder_local', 'squeezenet'))
        for attr_name, model_name in scenarios:
            with getattr(self, attr_name) as active:
                active(self.single_example)

            self.assertEqual(
                getattr(self, attr_name).is_connected_to_server(), False)
            setattr(self, attr_name, ImageEmbedder(
                model=model_name,
                layer='penultimate',
                server_url='example.com',
            ))
            fresh = getattr(self, attr_name)
            self.assertEqual(len(fresh._embedder._cache._cache_dict), 1)
 def _cb_embedder_changed(self):
     """Rebuild the embedder for the newly selected model, show its description and recommit."""
     chosen = self.embedders[self.cb_embedder_current_id]
     self._image_embedder = ImageEmbedder(
         model=chosen,
         layer='penultimate',
     )
     description = EMBEDDERS_INFO[chosen]['description']
     self.embedder_info.setText(description)
     self.commit()
Esempio n. 6
0
    def test_different_models_caches(self):
        """Each model keeps a separate persistent cache.

        An entry cached and persisted for ``painters`` is not visible to a new
        ``inception-v3`` embedder, but is visible to a new ``painters`` one.
        """
        def build(model_name):
            return ImageEmbedder(
                model=model_name,
                layer='penultimate',
                server_url='example.com',
            )

        def cache_of(emb):
            return emb._embedder._cache

        painters = build('painters')
        painters.clear_cache()
        self.assertEqual(len(cache_of(painters)._cache_dict), 0)
        painters(self.single_example)
        self.assertEqual(len(cache_of(painters)._cache_dict), 1)
        cache_of(painters).persist_cache()

        self.embedder_server = build('inception-v3')
        self.assertEqual(len(cache_of(self.embedder_server)._cache_dict), 0)
        cache_of(self.embedder_server).persist_cache()

        painters = build('painters')
        self.assertEqual(len(cache_of(painters)._cache_dict), 1)
        painters.clear_cache()
Esempio n. 7
0
 def test_invalid_model(self):
     """An unknown model name must be rejected with ValueError."""
     bad_model = 'invalid_model'
     with self.assertRaises(ValueError):
         self.embedder = ImageEmbedder(model=bad_model,
                                       layer='penultimate',
                                       server_url='example.com')
Esempio n. 8
0
 def setUp(self):
     """Silence logging and create an inception-v3 embedder with an empty cache."""
     logging.disable(logging.CRITICAL)
     options = dict(model='inception-v3', layer='penultimate',
                    server_url='example.com', server_port=80)
     self.embedder = ImageEmbedder(**options)
     self.embedder.clear_cache()
     self.single_example = [_EXAMPLE_IMAGE_JPG]
Esempio n. 9
0
    def test_server_url_env_var(self):
        """ORANGE_EMBEDDING_API_URL sets the server URL of a new embedder.

        The environment variable is restored in ``finally`` so that a failing
        assertion cannot leak the override into later tests.
        """
        env_key = 'ORANGE_EMBEDDING_API_URL'
        url_value = 'url:1234'
        self.assertNotEqual(self.embedder_server._embedder.server_url,
                            url_value)

        previous = environ.get(env_key)
        environ[env_key] = url_value
        try:
            self.embedder_server = ImageEmbedder(model='inception-v3')
            self.assertEqual(self.embedder_server._embedder.server_url,
                             url_value)
        finally:
            # restore whatever was there before the test
            if previous is None:
                environ.pop(env_key, None)
            else:
                environ[env_key] = previous
Esempio n. 10
0
 def _init_server_connection(self):
     """Leave blocking mode, build the selected embedder and show whether it reaches the server."""
     self.setBlocking(False)
     selected_model = self.embedders[self.cb_embedder_current_id]
     self._image_embedder = ImageEmbedder(model=selected_model,
                                          layer='penultimate')
     connected = self._image_embedder.is_connected_to_server()
     self._set_server_info(connected)
Esempio n. 11
0
def run_embedding(
    images: Table,
    file_paths_attr: Variable,
    embedder_name: str,
    state: TaskState,
) -> Result:
    """
    Run the embedding process

    If the selected embedder raises ``EmbeddingConnectionError``, the partial
    result ``"squeezenet"`` is reported through *state* and the images are
    embedded again with the ``squeezenet`` embedder. Progress then continues
    from the current tick up to 100.

    Parameters
    ----------
    images
        Data table with images to embed.
    file_paths_attr
        The column of the table with images.
    embedder_name
        The name of selected embedder.
    state
        State object used for controlling and progress.

    Returns
    -------
    The object that holds embedded images, skipped images, and number
    of skipped images.
    """
    embedder = ImageEmbedder(model=embedder_name)

    file_paths = images[:, file_paths_attr].metas.flatten()

    file_paths_mask = file_paths == file_paths_attr.Unknown
    file_paths_valid = file_paths[~file_paths_mask]

    # init progress bar and function
    ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))

    def advance(success=True):
        # `embedder` and `ticks` are looked up at call time, so the fallback
        # below is picked up automatically
        if state.is_interruption_requested():
            embedder.set_canceled()
        if success:
            # default avoids StopIteration if callbacks outnumber ticks
            state.set_progress_value(next(ticks, 100.0))

    try:
        emb, skip, n_skip = embedder(images,
                                     col=file_paths_attr,
                                     callback=advance)
    except EmbeddingConnectionError:
        # recompute ticks to go from current state to 100; the default keeps
        # this safe when the failure happens after the last tick was consumed
        start = next(ticks, 100.0)
        ticks = iter(np.linspace(start, 100.0, file_paths_valid.size))

        state.set_partial_result("squeezenet")
        embedder = ImageEmbedder(model="squeezenet")
        emb, skip, n_skip = embedder(images,
                                     col=file_paths_attr,
                                     callback=advance)

    return Result(embedding=emb, skip_images=skip, num_skipped=n_skip)
Esempio n. 12
0
    def test_with_statement(self):
        """After a ``with`` block the embedder is disconnected and its cache entry persists."""
        with self.embedder as emb:
            emb(self.single_example)

        connected = self.embedder.is_connected_to_server()
        self.assertEqual(connected, False)
        self.embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
            server_port=80,
        )
        cached = self.embedder._cache_dict
        self.assertEqual(len(cached), 1)
    def test_server_url_env_var(self):
        """ORANGE_EMBEDDING_API_URL sets the server URL of a new embedder.

        The server embedder is initialised lazily, so an image is embedded
        before the URL is inspected. The environment variable is restored in
        ``finally`` so that a failing assertion cannot leak into later tests.
        """
        env_key = "ORANGE_EMBEDDING_API_URL"
        url_value = "http://example.com"
        self.embedder_server([_EXAMPLE_IMAGE_JPG])  # to init server embedder
        self.assertNotEqual(self.embedder_server._embedder.server_url,
                            url_value)

        previous = environ.get(env_key)
        environ[env_key] = url_value
        try:
            self.embedder_server = ImageEmbedder(model="inception-v3")
            self.embedder_server([_EXAMPLE_IMAGE_JPG])  # to init server embedder
            self.assertEqual(self.embedder_server._embedder.server_url,
                             url_value)
        finally:
            # restore whatever was there before the test
            if previous is None:
                environ.pop(env_key, None)
            else:
                environ[env_key] = previous
 def _cb_embedder_changed(self):
     """Switch to the newly selected embedder, refresh the info labels and recommit.

     Before any data has arrived ``_input_data`` is None; the no-data text is
     shown instead of calling ``len(None)``, which would raise TypeError.
     """
     current_embedder = self.embedders[self.cb_embedder_current_id]
     self._image_embedder = ImageEmbedder(
         model=current_embedder,
         layer='penultimate'
     )
     self.embedder_info.setText(
         EMBEDDERS_INFO[current_embedder]['description'])
     if self._input_data is None:
         self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
     else:
         self.input_data_info.setText(
             "Data with {:d} instances.".format(len(self._input_data)))
     self.commit()
Esempio n. 15
0
    def test_server_url_env_var(self):
        """ORANGE_EMBEDDING_API_URL overrides the ``server_url`` argument.

        The environment variable is restored in ``finally`` so that a failing
        assertion cannot leak the override into later tests.
        """
        env_key = 'ORANGE_EMBEDDING_API_URL'
        url_value = 'url:1234'
        self.assertNotEqual(self.embedder._server_url, url_value)

        previous = environ.get(env_key)
        environ[env_key] = url_value
        try:
            self.embedder = ImageEmbedder(
                model='inception-v3',
                layer='penultimate',
                server_url='example.com',
            )
            self.assertEqual(self.embedder._server_url, url_value)
        finally:
            # restore whatever was there before the test
            if previous is None:
                environ.pop(env_key, None)
            else:
                environ[env_key] = previous
 def _cb_embedder_changed(self):
     """Switch to the newly selected embedder and refresh the info labels.

     Recommits only when input data is present; otherwise the no-data text
     is shown. Finally reports whether the new embedder reaches the server.
     """
     model_name = self.embedders[self.cb_embedder_current_id]
     self._image_embedder = ImageEmbedder(model=model_name,
                                          layer='penultimate')
     description = EMBEDDERS_INFO[model_name]['description']
     self.embedder_info.setText(description)
     if not self._input_data:
         self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
     else:
         n_instances = len(self._input_data)
         self.input_data_info.setText(
             "Data with {:d} instances.".format(n_instances))
         self.commit()
     self._set_server_info(self._image_embedder.is_connected_to_server())
 def setUp(self):
     """Silence logging and create server and local embedders with empty caches."""
     logging.disable(logging.CRITICAL)
     shared = {'layer': 'penultimate', 'server_url': 'example.com'}
     self.embedder_server = ImageEmbedder(model='inception-v3', **shared)
     self.embedder_server.clear_cache()
     self.embedder_local = ImageEmbedder(model='squeezenet', **shared)
     self.embedder_local.clear_cache()
     self.single_example = [_EXAMPLE_IMAGE_JPG]
Esempio n. 18
0
    def test_persistent_caching(self):
        """A persisted cache is reloaded by a new embedder; clear_cache empties it."""
        def cached_entries():
            return len(self.embedder_server._embedder._cache._cache_dict)

        self.assertEqual(cached_entries(), 0)
        self.embedder_server(self.single_example)
        self.assertEqual(cached_entries(), 1)

        # write the cache out, then check that a fresh embedder sees it
        self.embedder_server._embedder._cache.persist_cache()
        self.embedder_server = ImageEmbedder(model='inception-v3')
        self.assertEqual(cached_entries(), 1)

        # after clearing, a fresh embedder starts empty
        self.embedder_server.clear_cache()
        self.embedder_server = ImageEmbedder(model='inception-v3')
        self.assertEqual(cached_entries(), 0)
    def set_data(self, data):
        """Receive new input data and prepare the image-attribute selector.

        With no data, or with data that has no image attributes, both outputs
        are reset to ``None`` and the info label says why. Otherwise the
        image-attribute combo box is repopulated (falling back to the first
        attribute if the saved index is out of range) and
        ``_cb_image_attr_changed`` is triggered.
        """
        if not data:
            self._input_data = None
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            # do not leave embeddings of a previous data set on the outputs
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        if self.cb_image_attr_current_id >= len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()
Esempio n. 20
0
    def set_data(self, data):
        """Receive new input data and prepare the image-attribute selector.

        With no data, or with data that has no image attributes, both outputs
        are reset to ``None`` and the info label says why. Otherwise the
        image-attribute combo box is repopulated (falling back to the first
        attribute if the saved index is out of range) and
        ``_cb_image_attr_changed`` is triggered.
        """
        if not data:
            self._input_data = None
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            # do not leave embeddings of a previous data set on the outputs
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        if self.cb_image_attr_current_id >= len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()
    def test_invalid_layer(self):
        """An unknown layer name is rejected for both the server and the local model."""
        # inception-v3 exercises the server embedder, squeezenet the local one
        for model_name in ('inception-v3', 'squeezenet'):
            with self.assertRaises(ValueError):
                self.embedder_server = ImageEmbedder(
                    model=model_name,
                    layer='first',
                    server_url='example.com',
                )
 def test_invalid_model(self):
     """Constructing an embedder for an unknown model raises ValueError."""
     kwargs = dict(model='invalid_model', layer='penultimate',
                   server_url='example.com')
     with self.assertRaises(ValueError):
         self.embedder_server = ImageEmbedder(**kwargs)
Esempio n. 23
0
    def test_invalid_layer(self):
        """Layer 'first' is not accepted by either the server or the local embedder."""
        bad = {'layer': 'first', 'server_url': 'example.com'}

        # server embedder
        with self.assertRaises(ValueError):
            self.embedder_server = ImageEmbedder(model='inception-v3', **bad)

        # local embedder
        with self.assertRaises(ValueError):
            self.embedder_server = ImageEmbedder(model='squeezenet', **bad)
Esempio n. 24
0
def main(argv=None):
    """Show an OWImageGrid for the embedded images of a directory.

    The image directory is taken from ``argv[1]``; when it is missing a
    ValueError is raised before any Qt application is created. Returns the
    exit code of the Qt event loop.
    """
    import sys
    from orangecontrib.imageanalytics.import_images import ImportImages
    from orangecontrib.imageanalytics.image_embedder import ImageEmbedder

    if argv is None:
        argv = sys.argv

    argv = list(argv)

    # validate arguments before spinning up the Qt application
    if len(argv) < 2:
        raise ValueError("Provide the image directory as the first argument.")
    image_dir = argv[1]

    app = QApplication(argv)

    import_images = ImportImages()
    images, _ = import_images(image_dir)

    # the context manager releases the embedder once embeddings are computed
    with ImageEmbedder() as image_embedder:
        embeddings, _, _ = image_embedder(images)

    ow = OWImageGrid()
    ow.show()
    ow.raise_()
    ow.set_data(Orange.data.Table(embeddings))
    rval = app.exec()

    ow.saveSettings()
    ow.onDeleteWidget()

    return rval
Esempio n. 25
0
    def set_data(self, data):
        """Receive new input data, select its image attribute and recompute.

        Warnings and outputs are cleared up front. Nothing is committed when
        *data* is empty/``None`` or has no image attributes; in the latter
        case the ``no_image_attribute`` warning is shown.
        """
        self.Warning.clear()
        self.set_input_data_summary(data)
        # outputs are reset once here for every path below
        self.clear_outputs()

        if not data:
            self._input_data = None
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        # fall back to the first attribute if the saved index is out of range
        if self.cb_image_attr_current_id >= len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        if not self._image_attributes:
            self._input_data = None
            self.Warning.no_image_attribute()
            return

        self._input_data = data
        # NOTE(review): presumably remembered so later changes of attribute or
        # embedder can be detected — confirm against the rest of the widget
        self._previous_attr_id = self.cb_image_attr_current_id
        self._previous_embedder_id = self.cb_embedder_current_id

        self.unconditional_commit()
 def test_invalid_layer(self):
     """An unknown layer name raises ValueError even with an explicit server port."""
     settings = {
         'model': 'inception-v3',
         'layer': 'first',
         'server_url': 'example.com',
         'server_port': 80,
     }
     with self.assertRaises(ValueError):
         self.embedder = ImageEmbedder(**settings)
    def test_different_models_caches(self):
        """Caches are per model: painters' persisted entry is invisible to inception-v3."""
        common = dict(layer='penultimate', server_url='example.com')

        first_painters = ImageEmbedder(model='painters', **common)
        first_painters.clear_cache()
        painters_cache = first_painters._embedder._cache
        self.assertEqual(len(painters_cache._cache_dict), 0)
        first_painters(self.single_example)
        self.assertEqual(len(first_painters._embedder._cache._cache_dict), 1)
        first_painters._embedder._cache.persist_cache()

        self.embedder_server = ImageEmbedder(model='inception-v3', **common)
        inception_cache = self.embedder_server._embedder._cache
        self.assertEqual(len(inception_cache._cache_dict), 0)
        inception_cache.persist_cache()

        reloaded_painters = ImageEmbedder(model='painters', **common)
        self.assertEqual(
            len(reloaded_painters._embedder._cache._cache_dict), 1)
        reloaded_painters.clear_cache()
Esempio n. 28
0
 def _send_output_signals(self, embeddings):
     """Send embedded and skipped images to the outputs and report skipped count."""
     embedded_images, skipped_images, num_skipped =\
         ImageEmbedder.prepare_output_data(self._input_data, embeddings)
     self.send(_Output.SKIPPED_IMAGES, skipped_images)
     self.send(_Output.EMBEDDINGS, embedded_images)
     # compare by value: `is not 0` relies on small-int interning, is always
     # true for numpy integers, and triggers a SyntaxWarning on Python 3.8+
     if num_skipped != 0:
         self.input_data_info.setText(
             "Data with {:d} instances, {:d} images skipped.".format(
                 len(self._input_data), num_skipped))
Esempio n. 29
0
    def test_with_statement(self):
        """Embedding inside ``with`` returns the expected result and caches it."""
        # server embedder: result must equal [[0, 1]]
        with self.embedder_server as server:
            server_result = server(self.single_example)
            np.testing.assert_array_equal(server_result, [[0, 1]])

        self.embedder_server = ImageEmbedder(model='inception-v3')
        server_cache = self.embedder_server._embedder._cache._cache_dict
        self.assertEqual(len(server_cache), 1)

        # local embedder: result must have shape (1, 1000)
        with self.embedder_local as local:
            local_result = local(self.single_example)
            self.assertTupleEqual((1, 1000), local_result.shape)

        self.embedder_local = ImageEmbedder(model='squeezenet')
        local_cache = self.embedder_local._embedder._cache._cache_dict
        self.assertEqual(len(local_cache), 1)
Esempio n. 30
0
    def test_persistent_caching(self):
        """A persisted cache is reloaded by a new embedder; clear_cache empties it."""
        def fresh_embedder():
            return ImageEmbedder(model='inception-v3',
                                 layer='penultimate',
                                 server_url='example.com',
                                 server_port=80)

        self.assertEqual(len(self.embedder._cache_dict), 0)
        self.embedder(self.single_example)
        self.assertEqual(len(self.embedder._cache_dict), 1)

        self.embedder.persist_cache()
        self.embedder = fresh_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 1)

        self.embedder.clear_cache()
        self.embedder = fresh_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 0)
 def _send_output_signals(self, embeddings):
     """Send embedded and skipped images to the outputs and report skipped count."""
     embedded_images, skipped_images, num_skipped =\
         ImageEmbedder.prepare_output_data(self._input_data, embeddings)
     self.Outputs.embeddings.send(embedded_images)
     self.Outputs.skipped_images.send(skipped_images)
     # compare by value: `is not 0` relies on small-int interning, is always
     # true for numpy integers, and triggers a SyntaxWarning on Python 3.8+
     if num_skipped != 0:
         self.input_data_info.setText(
             "Data with {:d} instances, {:d} images skipped.".format(
                 len(self._input_data), num_skipped))
 def setUp(self):
     """Create an inception-v3 embedder with an empty cache and one example image."""
     config = {
         'model': 'inception-v3',
         'layer': 'penultimate',
         'server_url': 'example.com',
         'server_port': 80,
     }
     self.embedder = ImageEmbedder(**config)
     self.embedder.clear_cache()
     self.single_example = [_EXAMPLE_IMAGE_0]
Esempio n. 33
0
 def _send_output_signals(self, embeddings):
     """Send embedded and skipped images to the outputs and report skipped count."""
     embedded_images, skipped_images, num_skipped =\
         ImageEmbedder.prepare_output_data(self._input_data, embeddings)
     self.Outputs.embeddings.send(embedded_images)
     self.Outputs.skipped_images.send(skipped_images)
     # compare by value: `is not 0` relies on small-int interning, is always
     # true for numpy integers, and triggers a SyntaxWarning on Python 3.8+
     if num_skipped != 0:
         self.input_data_info.setText(
             "Data with {:d} instances, {:d} images skipped.".format(
                 len(self._input_data), num_skipped))
 def _init_server_connection(self):
     """Unblock the widget, create the chosen embedder and display its server status."""
     self.setBlocking(False)
     model_name = self.embedders[self.cb_embedder_current_id]
     embedder = ImageEmbedder(model=model_name, layer='penultimate')
     self._image_embedder = embedder
     self._set_server_info(embedder.is_connected_to_server())
    def test_persistent_caching(self):
        """A persisted cache survives re-creation; clear_cache removes it."""
        def rebuild():
            self.embedder_server = ImageEmbedder(
                model='inception-v3',
                layer='penultimate',
                server_url='example.com',
            )

        def size():
            return len(self.embedder_server._embedder._cache._cache_dict)

        self.assertEqual(size(), 0)
        self.embedder_server(self.single_example)
        self.assertEqual(size(), 1)

        self.embedder_server._embedder._cache.persist_cache()
        rebuild()
        self.assertEqual(size(), 1)

        self.embedder_server.clear_cache()
        rebuild()
        self.assertEqual(size(), 0)
    def setUp(self):
        """Create fresh embedders and a three-row table of example image paths."""
        logging.disable(logging.CRITICAL)
        self.embedder_server = ImageEmbedder(model="inception-v3")
        self.embedder_server.clear_cache()
        self.embedder_local = ImageEmbedder(model="squeezenet")
        self.embedder_local.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_JPG]

        image_var = StringVariable("Image")
        # the "origin" attribute is set to this test module's directory
        image_var.attributes["origin"] = path.dirname(path.abspath(__file__))
        image_metas = np.array([
            [_EXAMPLE_IMAGE_JPG],
            [_EXAMPLE_IMAGE_TIFF],
            [_EXAMPLE_IMAGE_GRAYSCALE],
        ])
        n_rows = 3
        self.data_table = Table.from_numpy(
            Domain([], [], metas=[image_var]),
            np.empty((n_rows, 0)),
            np.empty((n_rows, 0)),
            metas=image_metas,
        )
    def test_server_url_env_var(self):
        """ORANGE_EMBEDDING_API_URL overrides the ``server_url`` argument.

        The environment variable is restored in ``finally`` so that a failing
        assertion cannot leak the override into later tests.
        """
        env_key = 'ORANGE_EMBEDDING_API_URL'
        url_value = 'url:1234'
        self.assertNotEqual(self.embedder_server._embedder._server_url,
                            url_value)

        previous = environ.get(env_key)
        environ[env_key] = url_value
        try:
            self.embedder_server = ImageEmbedder(
                model='inception-v3',
                layer='penultimate',
                server_url='example.com',
            )
            self.assertEqual(self.embedder_server._embedder._server_url,
                             url_value)
        finally:
            # restore whatever was there before the test
            if previous is None:
                environ.pop(env_key, None)
            else:
                environ[env_key] = previous
Esempio n. 38
0
def get_embeddings(dataset_number, image_file_paths, use_cached=True):
    """Return inception-v3 penultimate-layer embeddings for *image_file_paths*.

    Results are cached in ``../data/saved_embeddings/<name>.npy``, where
    ``<name>`` is ``datasets[dataset_number]``. ``datasets`` is a
    module-level sequence defined elsewhere (assumed). When *use_cached* is
    true and the file exists, it is loaded instead of recomputing.
    """
    file_name = os.path.join('..', 'data', 'saved_embeddings',
                             datasets[dataset_number] + '.npy')

    # read from file if it exists to save time
    if use_cached and os.path.isfile(file_name):
        return np.load(file_name)

    with ImageEmbedder(model='inception-v3',
                       layer='penultimate') as embedder:
        embeddings = embedder(image_file_paths)
    # make sure the cache directory exists so the computed result is not lost
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    np.save(file_name, embeddings)
    return embeddings
    def test_with_statement(self):
        """The embedder disconnects after ``with`` and a new one sees the cached entry."""
        with self.embedder as ctx_embedder:
            ctx_embedder(self.single_example)

        self.assertEqual(self.embedder.is_connected_to_server(), False)
        params = {
            'model': 'inception-v3',
            'layer': 'penultimate',
            'server_url': 'example.com',
            'server_port': 80,
        }
        self.embedder = ImageEmbedder(**params)
        self.assertEqual(len(self.embedder._cache_dict), 1)
    def connect(self):
        """
        Return an embedder for the selected model, falling back to local.

        If the selected embedder is not local and cannot reach the server
        (checked with ``use_hyper=False``), the ``switched_local_embedder``
        warning is shown, the selection is switched to ``squeezenet`` and an
        embedder for that model is returned instead.
        """
        self.Warning.switched_local_embedder.clear()

        # try to connect to current embedder
        embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate'
        )

        if (not embedder.is_local_embedder()
                and not embedder.is_connected_to_server(use_hyper=False)):
            # there is a problem with connecting to the server
            # switching to local embedder
            self.Warning.switched_local_embedder()
            del embedder  # remove current embedder
            self.cb_embedder_current_id = self.embedders.index("squeezenet")
            embedder = ImageEmbedder(
                model=self.embedders[self.cb_embedder_current_id],
                layer='penultimate'
            )

        return embedder
Esempio n. 41
0
def run_embeddings(model, image_dir):
    """Embed every file found under *image_dir* with *model* and print the results.

    Files named ``index.html`` or ``README.md`` are ignored. The embedder's
    cache is cleared before embedding.
    """
    excluded_names = ("index.html", "README.md")
    image_file_paths = []
    for root, _, names in os.walk(image_dir):
        image_file_paths.extend(
            os.path.join(root, name)
            for name in names
            if name not in excluded_names
        )
    print(image_file_paths)

    print("#of pics: {0}".format(len(image_file_paths)))

    with ImageEmbedder(model=model, layer='penultimate') as embedder:
        embedder.clear_cache()
        embeddings = embedder(image_file_paths)
        print(embeddings)
Esempio n. 42
0
 def dataset(self, data):
     """Handle new input data: refresh status, messages and image attributes.

     When ``auto_save`` is enabled and a directory is set, the data is saved
     right away.
     """
     self.Error.clear()
     self.data = data
     self._update_status()
     self._update_messages()
     if data is None:
         self.image_attributes = []
     else:
         self.image_attributes = ImageEmbedder.filter_image_attributes(data)
     self._update_image_attributes()
     should_save = self.auto_save and self.dirname
     if should_save:
         self.save_file()
    def __init__(self):
        """Build the layout and an inception-v3 embedder, then show server status."""
        super().__init__()
        self._image_attributes = None
        self._input_data = None

        self._setup_layout()

        embedder = ImageEmbedder(model='inception-v3', layer='penultimate')
        self._image_embedder = embedder
        self._set_server_info(embedder.is_connected_to_server())
 def _cb_embedder_changed(self):
     """React to a new embedder selection.

     Rebuilds the embedder, shows its description, and either recommits
     (input present) or shows the no-data text. Finally reports the new
     embedder's server connectivity.
     """
     selected = self.embedders[self.cb_embedder_current_id]
     self._image_embedder = ImageEmbedder(model=selected, layer='penultimate')
     self.embedder_info.setText(EMBEDDERS_INFO[selected]['description'])
     has_data = bool(self._input_data)
     if has_data:
         message = "Data with {:d} instances.".format(len(self._input_data))
         self.input_data_info.setText(message)
         self.commit()
     else:
         self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
     self._set_server_info(self._image_embedder.is_connected_to_server())
Esempio n. 45
0
def get_result(image_name):
    """Classify one image with the pickled model stored in ``oil2.pkcls``.

    The image is embedded with the ``vgg16`` embedder (penultimate layer),
    the model predicts class indices, and these are mapped to class labels
    through ``model.domain.class_var.str_val``.

    Returns
    -------
    list
        The predicted class label(s) for the image.
    """
    print(image_name)
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files from a trusted source. `with` ensures the file handle is closed.
    with open('oil2.pkcls', 'rb') as model_file:
        model = pickle.load(model_file)
    print("model loaded")
    image_file_paths = [image_name]
    print(image_file_paths)
    with ImageEmbedder(model='vgg16', layer='penultimate') as embedder:
        embeddings = embedder(image_file_paths)
    print(embeddings.shape)
    print("embedded")
    pred_ind = model(embeddings)
    pred_cls = [model.domain.class_var.str_val(i) for i in pred_ind]
    print(pred_ind)
    print(pred_cls)
    # NOTE(review): `k` is defined elsewhere, presumably the Keras backend
    k.clear_session()
    return pred_cls
class OWImageEmbedding(OWWidget):
    """Widget that embeds images referenced in a data table.

    The selected image attribute (a string meta variable holding file paths
    or URLs) is handed to an ``ImageEmbedder``. Embedding runs on a
    single-worker thread pool so the GUI stays responsive; progress and the
    final result are delivered back to the GUI thread through
    ``qconcurrent.methodinvoke``. Rows that
    ``ImageEmbedder.prepare_output_data`` reports as skipped are sent to the
    "Skipped Images" output.
    """
    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    class Inputs:
        images = Input('Images', Table)

    class Outputs:
        embeddings = Output('Embeddings', Table, default=True)
        skipped_images = Output('Skipped Images', Table)

    cb_image_attr_current_id = Setting(default=0)
    cb_embedder_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        # Embedder keys ordered by their 'order' entry in EMBEDDERS_INFO;
        # cb_embedder_current_id indexes into this list.
        self.embedders = sorted(list(EMBEDDERS_INFO),
                                key=lambda k: EMBEDDERS_INFO[k]['order'])
        self._image_attributes = None
        self._input_data = None
        self._log = logging.getLogger(__name__)
        # Namespace describing the running embedding task, or None.
        self._task = None
        self._setup_layout()
        self._image_embedder = None
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Creating the embedder is deferred to the event loop (it queries the
        # server connection); the widget stays blocked until that is done.
        self.setBlocking(True)
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Build the Info/Settings boxes, the Apply control and Cancel button."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.cb_embedder = comboBox(
            widget=widget_box,
            master=self,
            value='cb_embedder_current_id',
            label='Embedder:',
            orientation=Qt.Horizontal,
            callback=self._cb_embedder_changed
        )
        names = [EMBEDDERS_INFO[e]['name'] +
                 (" (local)" if EMBEDDERS_INFO[e].get("is_local") else "")
                 for e in self.embedders]
        self.cb_embedder.setModel(VariableListModel(names))
        # Guard against a saved index that is out of range for the current
        # embedder list.
        if not self.cb_embedder_current_id < len(self.embedders):
            self.cb_embedder_current_id = 0
        self.cb_embedder.setCurrentIndex(self.cb_embedder_current_id)

        current_embedder = self.embedders[self.cb_embedder_current_id]
        self.embedder_info = widgetLabel(
            widget_box,
            EMBEDDERS_INFO[current_embedder]['description']
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            commit=self.commit
        )

        self.cancel_button = QPushButton(
            'Cancel',
            icon=self.style().standardIcon(QStyle.SP_DialogCancelButton),
        )
        self.cancel_button.clicked.connect(self.cancel)
        hbox = hBox(self.controlArea)
        hbox.layout().addWidget(self.cancel_button)
        # Enabled only while an embedding task is running.
        self.cancel_button.setDisabled(True)

    def _init_server_connection(self):
        """Create the embedder for the selected model and show its status."""
        self.setBlocking(False)
        self._image_embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate'
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    @Inputs.images
    def set_data(self, data):
        """Receive a new input table and start embedding when possible.

        Missing/empty data, or data without image attributes, clears both
        outputs so embeddings of a previous table do not linger downstream.
        """
        if not data:
            self._input_data = None
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            # Outputs computed for an earlier table are no longer valid.
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()

    def _cb_image_attr_changed(self):
        """Re-run embedding when a different image attribute is selected."""
        self.commit()

    def _cb_embedder_changed(self):
        """Switch to the newly selected embedder and re-embed any input data."""
        current_embedder = self.embedders[self.cb_embedder_current_id]
        self._image_embedder = ImageEmbedder(
            model=current_embedder,
            layer='penultimate'
        )
        self.embedder_info.setText(
            EMBEDDERS_INFO[current_embedder]['description'])
        if self._input_data:
            self.input_data_info.setText(
                "Data with {:d} instances.".format(len(self._input_data)))
            self.commit()
        else:
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
        self._set_server_info(self._image_embedder.is_connected_to_server())

    def commit(self):
        """Start embedding the images of the selected attribute.

        Any running task is cancelled first. Paths that are not already
        http/https/ftp/data URLs are resolved against the attribute's
        ``origin`` (URL-joined for remote origins, path-joined otherwise).
        Missing values are not sent to the embedder; their rows receive
        ``None`` embeddings when the result arrives. The work is submitted to
        the single-thread executor and ``__set_results`` is invoked on the
        GUI thread once the future completes.
        """
        if self._task is not None:
            self.cancel()

        if self._image_embedder is None:
            self._set_server_info(connected=False)
            return

        if not self._image_attributes or self._input_data is None:
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        self._set_server_info(connected=True)
        self.cancel_button.setDisabled(False)
        self.cb_image_attr.setDisabled(True)
        self.cb_embedder.setDisabled(True)

        file_paths_attr = self._image_attributes[self.cb_image_attr_current_id]
        file_paths = self._input_data[:, file_paths_attr].metas.flatten()
        origin = file_paths_attr.attributes.get("origin", "")
        # urljoin drops the last path segment unless the base ends in "/".
        if urlparse(origin).scheme in ("http", "https", "ftp", "data") and \
                origin[-1] != "/":
            origin += "/"

        assert file_paths_attr.is_string
        assert file_paths.dtype == np.dtype('O')

        file_paths_mask = file_paths == file_paths_attr.Unknown
        file_paths_valid = file_paths[~file_paths_mask]
        for i, a in enumerate(file_paths_valid):
            urlparts = urlparse(a)
            if urlparts.scheme not in ("http", "https", "ftp", "data"):
                if urlparse(origin).scheme in ("http", "https", "ftp", "data"):
                    file_paths_valid[i] = urljoin(origin, a)
                else:
                    file_paths_valid[i] = os.path.join(origin, a)

        # One progress value per valid image, spaced evenly from 0 to 100.
        ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))
        set_progress = qconcurrent.methodinvoke(
            self, "__progress_set", (float,))

        def advance(success=True):
            # Called from the worker thread; forwarded to the GUI thread.
            if success:
                set_progress(next(ticks))

        def cancel():
            task.future.cancel()
            task.cancelled = True
            task.embedder.set_canceled(True)

        embedder = self._image_embedder

        def run_embedding(paths):
            return embedder(
                file_paths=paths, image_processed_callback=advance)

        self.auto_commit_widget.setDisabled(True)
        self.progressBarInit(processEvents=None)
        self.progressBarSet(0.0, processEvents=None)
        self.setBlocking(True)

        f = self._executor.submit(run_embedding, file_paths_valid)
        f.add_done_callback(
            qconcurrent.methodinvoke(self, "__set_results", (object,)))

        task = self._task = namespace(
            file_paths_mask=file_paths_mask,
            file_paths_valid=file_paths_valid,
            file_paths=file_paths,
            embedder=embedder,
            cancelled=False,
            cancel=cancel,
            future=f,
        )
        self._log.debug("Starting embedding task for %i images",
                        file_paths.size)
        return

    @Slot(float)
    def __progress_set(self, value):
        """Update the progress bar (GUI thread only) while a task runs."""
        assert self.thread() is QThread.currentThread()
        if self._task is not None:
            self.progressBarSet(value)

    @Slot(object)
    def __set_results(self, f):
        """Handle a finished future on the GUI thread and send the outputs.

        Futures that no longer belong to the current task are ignored.
        """
        assert self.thread() is QThread.currentThread()
        if self._task is None or self._task.future is not f:
            self._log.info("Reaping stale task")
            return

        assert f.done()

        task, self._task = self._task, None
        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.setDisabled(True)
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)

        try:
            embeddings = f.result()
        except ConnectionError:
            self._log.exception("Error", exc_info=True)
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self._set_server_info(connected=False)
            return
        except Exception as err:
            self._log.exception("Error", exc_info=True)
            self.error(
                "\n".join(traceback.format_exception_only(type(err), err)))
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        assert self._input_data is not None
        assert len(self._input_data) == len(task.file_paths_mask)

        # Missing paths/urls were filtered out. Restore the full embeddings
        # array from information stored in task.file_path_mask ...
        embeddings_all = [None] * len(task.file_paths_mask)
        for i, embedding in zip(np.flatnonzero(~task.file_paths_mask),
                                embeddings):
            embeddings_all[i] = embedding
        embeddings_all = np.array(embeddings_all)
        self._send_output_signals(embeddings_all)

    def _send_output_signals(self, embeddings):
        """Split the input into embedded/skipped tables and send both."""
        embedded_images, skipped_images, num_skipped =\
            ImageEmbedder.prepare_output_data(self._input_data, embeddings)
        self.Outputs.embeddings.send(embedded_images)
        self.Outputs.skipped_images.send(skipped_images)
        # Compare by value: ``is not 0`` relies on CPython small-int caching
        # and is always true for numpy integers.
        if num_skipped != 0:
            self.input_data_info.setText(
                "Data with {:d} instances, {:d} images skipped.".format(
                    len(self._input_data), num_skipped))
        else:
            # Drop a "...images skipped" note left over from a previous run.
            self.input_data_info.setText(
                "Data with {:d} instances.".format(len(self._input_data)))

    def _set_server_info(self, connected):
        """Show the connection state; warn when the server is unreachable."""
        self.clear_messages()
        if self._image_embedder is None:
            return

        if connected:
            self.connection_info.setText("Connected to server.")
        elif self._image_embedder.is_local_embedder():
            self.connection_info.setText("Using local embedder.")
        else:
            self.connection_info.setText("Not connected to server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Cancel any running task and release the embedder on deletion."""
        self.cancel()
        super().onDeleteWidget()
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)

    def cancel(self):
        """Cancel the running task (if any) and restore the controls.

        Waits for the future to finish, resets the embedder's cancelled flag
        and reconnects it to the server.
        """
        if self._task is not None:
            task, self._task = self._task, None
            task.cancel()
            # wait until done
            try:
                task.future.exception()
            except qconcurrent.CancelledError:
                pass

            self.auto_commit_widget.setDisabled(False)
            self.cancel_button.setDisabled(True)
            self.progressBarFinished(processEvents=None)
            self.setBlocking(False)
            self.cb_image_attr.setDisabled(False)
            self.cb_embedder.setDisabled(False)
            self._image_embedder.set_canceled(False)
            # reset the connection.
            connected = self._image_embedder.reconnect_to_server()
            self._set_server_info(connected=connected)
class ImageEmbedderTest(unittest.TestCase):
    """Tests for ``ImageEmbedder`` against a stubbed HTTP/2 connection.

    ``HTTP20Connection`` in the tested module is patched either with
    ``DummyHttp2Connection`` or with a default ``MagicMock``, so no real
    network traffic takes place.
    """

    @staticmethod
    def _new_embedder(model='inception-v3', layer='penultimate'):
        """Create an embedder aimed at the fake server used by every test."""
        return ImageEmbedder(
            model=model,
            layer=layer,
            server_url='example.com',
            server_port=80,
        )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def setUp(self):
        self.embedder = self._new_embedder()
        self.embedder.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_0]

    def tearDown(self):
        self.embedder.clear_cache()

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connected_to_server(self, ConnectionMock):
        ConnectionMock._discover_server.assert_not_called()
        self.assertEqual(self.embedder.is_connected_to_server(), True)
        # Simulate the server closing the connection.
        self.embedder._server_connection.close()
        self.assertEqual(self.embedder.is_connected_to_server(), False)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connection_errors(self, ConnectionMock):
        expected = np.array([np.array([0, 1], dtype=np.float16)])
        assert_array_equal(self.embedder(self.single_example), expected)
        self.embedder.clear_cache()

        # With the connection gone and reconnects refused, the call fails.
        self.embedder._server_connection.close()
        ConnectionMock.side_effect = ConnectionRefusedError
        with self.assertRaises(ConnectionError):
            self.embedder(self.single_example)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_disconnect_reconnect(self):
        embedder = self.embedder
        self.assertEqual(embedder.is_connected_to_server(), True)
        embedder.disconnect_from_server()
        self.assertEqual(embedder.is_connected_to_server(), False)
        embedder.reconnect_to_server()
        self.assertEqual(embedder.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_auto_reconnect(self):
        embedder = self.embedder
        self.assertEqual(embedder.is_connected_to_server(), True)
        embedder.disconnect_from_server()
        self.assertEqual(embedder.is_connected_to_server(), False)
        # Embedding while disconnected re-establishes the connection.
        embedder(self.single_example)
        self.assertEqual(embedder.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_with_non_existing_image(self, ConnectionMock):
        missing = ['/non_existing_image']
        self.single_example = missing

        self.assertEqual(self.embedder(missing), [None])
        # Nothing should have been sent to the server or cached.
        ConnectionMock.request.assert_not_called()
        ConnectionMock.get_response.assert_not_called()
        self.assertEqual(self.embedder._cache_dict, {})

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_on_successful_response(self):
        expected = np.array([np.array([0, 1], dtype=np.float16)])
        assert_array_equal(self.embedder(self.single_example), expected)
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(b''))
    def test_on_non_json_response(self):
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(json.dumps({'wrong_key': None}).encode()))
    def test_on_json_wrong_key_response(self):
        self.assertEqual(self.embedder(self.single_example), [None])
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_persistent_caching(self):
        self.assertEqual(len(self.embedder._cache_dict), 0)
        self.embedder(self.single_example)
        self.assertEqual(len(self.embedder._cache_dict), 1)

        # A persisted cache is visible to a freshly created embedder...
        self.embedder.persist_cache()
        self.embedder = self._new_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 1)

        # ...while a cleared one is not.
        self.embedder.clear_cache()
        self.embedder = self._new_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_statement(self):
        with self.embedder as embedder:
            embedder(self.single_example)

        # Leaving the context disconnects, yet the cached entry is still
        # visible to a new embedder.
        self.assertEqual(self.embedder.is_connected_to_server(), False)
        self.embedder = self._new_embedder()
        self.assertEqual(len(self.embedder._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_max_concurrent_streams_setting(self):
        self.assertEqual(self.embedder._max_concurrent_streams, 128)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_too_many_examples_for_one_batch(self):
        batch = [_EXAMPLE_IMAGE_0] * 200
        expected = np.array([np.array([0, 1], dtype=np.float16)] * 200)
        assert_array_equal(self.embedder(batch), expected)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_successful_result_shape(self):
        result = self.embedder([_EXAMPLE_IMAGE_0] * 5)
        self.assertEqual(result.shape, (5, 2))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_model(self):
        with self.assertRaises(ValueError):
            self.embedder = self._new_embedder(model='invalid_model')

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_layer(self):
        with self.assertRaises(ValueError):
            self.embedder = self._new_embedder(layer='first')
class OWImageEmbedding(OWWidget):
    """Widget that embeds images referenced in a data table.

    In this version embedding runs synchronously inside ``commit`` (the GUI
    is blocked while the embedder works). Rows whose embedding is ``None``
    go to the skipped-images output; the rest get the embedding appended as
    continuous attributes ``n0``, ``n1``, ...
    """
    # todo: implement embedding in a non-blocking manner
    # todo: implement stop running task action
    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    inputs = [(_Input.IMAGES, Table, 'set_data')]
    outputs = [
        (_Output.EMBEDDINGS, Table, Default),
        (_Output.SKIPPED_IMAGES, Table)
    ]

    cb_image_attr_current_id = Setting(default=0)
    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        self._image_attributes = None
        self._input_data = None

        self._setup_layout()

        self._image_embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    def _setup_layout(self):
        """Build the Info and Settings boxes and the Apply control."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            checkbox_label='Auto Apply',
            commit=self.commit
        )

    def set_data(self, data):
        """Receive a new input table and embed it if auto-apply is on.

        When there is no data, or the data has no image attributes, both
        outputs are cleared and the stored table is forgotten so that a later
        manual Apply cannot re-send results for a table that is gone.
        """
        if data is None:
            # Without this reset, Apply would keep embedding the old table.
            self._input_data = None
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = self._filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            # Outputs computed for an earlier table are no longer valid.
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            return

        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        input_data_info_text = "Data with {:d} instances.".format(len(data))
        self.input_data_info.setText(input_data_info_text)

        self._cb_image_attr_changed()

    @staticmethod
    def _filter_image_attributes(data):
        """Return the meta variables whose ``type`` attribute is ``'image'``."""
        metas = data.domain.metas
        return [m for m in metas if m.attributes.get('type') == 'image']

    def _cb_image_attr_changed(self):
        """Re-embed on attribute change, but only when auto-apply is on."""
        if self._auto_apply:
            self.commit()

    def commit(self):
        """Embed the selected image attribute and send the outputs.

        On ``ConnectionError`` both outputs are cleared and the connection
        status is updated. The Apply control is disabled while embedding and
        is re-enabled afterwards no matter how embedding ends.
        """
        if not self._image_attributes or not self._input_data:
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            return

        self._set_server_info(connected=True)
        self.auto_commit_widget.setDisabled(True)
        # try/finally: an unexpected exception from the embedder must not
        # leave the Apply control disabled for good.
        try:
            file_paths_attr = \
                self._image_attributes[self.cb_image_attr_current_id]
            file_paths = self._input_data[:, file_paths_attr].metas.flatten()

            with self.progressBar(len(file_paths)) as progress:
                try:
                    embeddings = self._image_embedder(
                        file_paths=file_paths,
                        image_processed_callback=lambda: progress.advance()
                    )
                except ConnectionError:
                    self.send(_Output.EMBEDDINGS, None)
                    self.send(_Output.SKIPPED_IMAGES, None)
                    self._set_server_info(connected=False)
                    return

            self._send_output_signals(embeddings)
        finally:
            self.auto_commit_widget.setDisabled(False)

    def _send_output_signals(self, embeddings):
        """Send the embedded rows and the rows that could not be embedded.

        ``embeddings`` is indexed with a boolean mask, so it is assumed to be
        a numpy array with one entry (or ``None``) per input row.
        """
        skipped_images_bool = np.array([x is None for x in embeddings])

        if np.any(skipped_images_bool):
            skipped_images = self._input_data[skipped_images_bool]
            skipped_images = Table(skipped_images)
            skipped_images.ids = self._input_data.ids[skipped_images_bool]
            self.send(_Output.SKIPPED_IMAGES, skipped_images)
        else:
            self.send(_Output.SKIPPED_IMAGES, None)

        embedded_images_bool = np.logical_not(skipped_images_bool)

        if np.any(embedded_images_bool):
            embedded_images = self._input_data[embedded_images_bool]

            embeddings = embeddings[embedded_images_bool]
            embeddings = np.stack(embeddings)

            embedded_images = self._construct_output_data_table(
                embedded_images,
                embeddings
            )
            embedded_images.ids = self._input_data.ids[embedded_images_bool]
            self.send(_Output.EMBEDDINGS, embedded_images)
        else:
            self.send(_Output.EMBEDDINGS, None)

    @staticmethod
    def _construct_output_data_table(embedded_images, embeddings):
        """Append the embedding columns ``n0..nK`` to the table's attributes."""
        X = np.hstack((embedded_images.X, embeddings))
        Y = embedded_images.Y

        dimensions = range(embeddings.shape[1])
        attributes = [ContinuousVariable('n{:d}'.format(d)) for d in dimensions]
        attributes = list(embedded_images.domain.attributes) + attributes

        domain = Domain(
            attributes=attributes,
            class_vars=embedded_images.domain.class_vars,
            metas=embedded_images.domain.metas
        )

        return Table(domain, X, Y, embedded_images.metas)

    def _set_server_info(self, connected):
        """Show the connection state and warn when there is no connection."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Release the embedder when the widget is deleted."""
        super().onDeleteWidget()
        self._image_embedder.__exit__(None, None, None)
class ImageEmbedderTest(unittest.TestCase):
    """Tests for ``ImageEmbedder``.

    Two embedders are exercised: ``embedder_server`` (model ``inception-v3``,
    whose HTTP/2 connection is replaced by ``DummyHttp2Connection``) and
    ``embedder_local`` (model ``squeezenet``).  Both caches are cleared
    before and after every test so tests do not observe each other's
    cached embeddings.
    """

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def setUp(self):
        # Suppress all log output for the duration of each test; re-enabled
        # in tearDown.
        logging.disable(logging.CRITICAL)
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.embedder_server.clear_cache()
        self.embedder_local = ImageEmbedder(
            model='squeezenet',
            layer='penultimate',
            server_url='example.com',
        )
        self.embedder_local.clear_cache()
        self.single_example = [_EXAMPLE_IMAGE_JPG]

    def tearDown(self):
        # Clear both caches: tests also persist entries for the local
        # embedder (see test_with_statement), which would otherwise be left
        # behind after the test run.
        self.embedder_server.clear_cache()
        self.embedder_local.clear_cache()
        logging.disable(logging.NOTSET)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connected_to_server(self, ConnectionMock):
        ConnectionMock._discover_server.assert_not_called()
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        # server closes the connection
        self.embedder_server._embedder._server_connection.close()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_connection_errors(self, ConnectionMock):
        res = self.embedder_server(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.embedder_server.clear_cache()

        # After closing, the next call has to open a new connection, which
        # goes through the mocked HTTP20Connection configured to fail.
        self.embedder_server._embedder._server_connection.close()
        ConnectionMock.side_effect = ConnectionRefusedError
        with self.assertRaises(ConnectionError):
            self.embedder_server(self.single_example)

        ConnectionMock.side_effect = BrokenPipeError
        with self.assertRaises(ConnectionError):
            self.embedder_server(self.single_example)

    @patch.object(DummyHttp2Connection, 'get_response')
    def test_on_stream_reset_by_server(self, get_response_mock):
        get_response_mock.side_effect = StreamResetError
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_disconnect_reconnect(self):
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        self.embedder_server._embedder.disconnect_from_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server.reconnect_to_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_auto_reconnect(self):
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)
        self.embedder_server._embedder.disconnect_from_server()
        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server(self.single_example)
        self.assertEqual(self.embedder_server.is_connected_to_server(), True)

    @patch(_TESTED_MODULE.format('HTTP20Connection'))
    def test_with_non_existing_image(self, ConnectionMock):
        self.single_example = ['/non_existing_image']

        self.assertEqual(self.embedder_server(self.single_example), [None])
        # NOTE(review): the embedder's connection was presumably created in
        # setUp (from DummyHttp2Connection), so these class-level mock
        # assertions may be vacuous -- confirm they can actually fail.
        ConnectionMock.request.assert_not_called()
        ConnectionMock.get_response.assert_not_called()
        self.assertEqual(self.embedder_server._embedder._cache._cache_dict, {})

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_on_successful_response(self):
        res = self.embedder_server(self.single_example)
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 1)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(b''))
    def test_on_non_json_response(self):
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch.object(
        DummyHttp2Connection, 'get_response',
        lambda self, _: BytesIO(json.dumps({'wrong_key': None}).encode()))
    def test_on_json_wrong_key_response(self):
        self.assertEqual(self.embedder_server(self.single_example), [None])
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_persistent_caching(self):
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 0)
        self.embedder_server(self.single_example)
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 1)

        # A persisted cache is visible to a freshly constructed embedder ...
        self.embedder_server._embedder._cache.persist_cache()
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 1)

        # ... and clearing it removes the entries for later instances.
        self.embedder_server.clear_cache()
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 0)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_different_models_caches(self):
        # Each model must keep a separate cache: entries persisted for
        # 'painters' must not appear in the 'inception-v3' cache.
        embedder = ImageEmbedder(
            model='painters',
            layer='penultimate',
            server_url='example.com',
        )
        embedder.clear_cache()
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
        embedder(self.single_example)
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder._embedder._cache.persist_cache()

        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 0)
        self.embedder_server._embedder._cache.persist_cache()

        embedder = ImageEmbedder(
            model='painters',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
        embedder.clear_cache()

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_statement(self):
        # server embedder: leaving the with block disconnects it and the
        # result is visible in the cache of a new instance afterwards.
        with self.embedder_server as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder_server.is_connected_to_server(), False)
        self.embedder_server = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

        # local embedder
        with self.embedder_local as embedder:
            embedder(self.single_example)

        self.assertEqual(self.embedder_local.is_connected_to_server(), False)
        self.embedder_local = ImageEmbedder(
            model='squeezenet',
            layer='penultimate',
            server_url='example.com',
        )
        self.assertEqual(
            len(self.embedder_local._embedder._cache._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_max_concurrent_streams_setting(self):
        self.assertEqual(self.embedder_server._embedder._max_concurrent_streams, 128)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_too_many_examples_for_one_batch(self):
        # 200 exceeds the 128 concurrent streams checked above, so the
        # server embedder has to split the work into several batches.
        too_many_examples = [_EXAMPLE_IMAGE_JPG for _ in range(200)]
        true_res = [np.array([0, 1], dtype=np.float16) for _ in range(200)]
        true_res = np.array(true_res)

        res = self.embedder_server(too_many_examples)
        assert_array_equal(res, true_res)
        # no need to test it on local embedder since it does not work
        # in batches

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_successful_result_shape(self):
        # server embedder
        more_examples = [_EXAMPLE_IMAGE_JPG for _ in range(5)]
        res = self.embedder_server(more_examples)
        self.assertEqual(res.shape, (5, 2))

        # local embedder
        more_examples = [_EXAMPLE_IMAGE_JPG for _ in range(5)]
        res = self.embedder_local(more_examples)
        self.assertEqual(res.shape, (5, 1000))

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_model(self):
        with self.assertRaises(ValueError):
            self.embedder_server = ImageEmbedder(
                model='invalid_model',
                layer='penultimate',
                server_url='example.com',
            )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_invalid_layer(self):
        # test server embedder
        with self.assertRaises(ValueError):
            self.embedder_server = ImageEmbedder(
                model='inception-v3',
                layer='first',
                server_url='example.com',
            )

        # test local embedder
        with self.assertRaises(ValueError):
            self.embedder_local = ImageEmbedder(
                model='squeezenet',
                layer='first',
                server_url='example.com',
            )

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_grayscale_image(self):
        # test server embedder
        res = self.embedder_server([_EXAMPLE_IMAGE_GRAYSCALE])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(len(self.embedder_server._embedder._cache._cache_dict), 1)

        # test local embedder (one image -> one row of the local model's
        # output width, as in test_successful_result_shape)
        res = self.embedder_local([_EXAMPLE_IMAGE_GRAYSCALE])
        self.assertEqual(res.shape, (1, 1000))
        self.assertEqual(
            len(self.embedder_local._embedder._cache._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_with_tiff_image(self):
        # test server embedder
        res = self.embedder_server([_EXAMPLE_IMAGE_TIFF])
        assert_array_equal(res, np.array([np.array([0, 1], dtype=np.float16)]))
        self.assertEqual(
            len(self.embedder_server._embedder._cache._cache_dict), 1)

        # test local embedder
        res = self.embedder_local([_EXAMPLE_IMAGE_TIFF])
        self.assertEqual(res.shape, (1, 1000))
        self.assertEqual(
            len(self.embedder_local._embedder._cache._cache_dict), 1)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_server_url_env_var(self):
        url_value = 'url:1234'
        self.assertNotEqual(self.embedder_server._embedder._server_url, url_value)

        # patch.dict restores the environment when the block exits, including
        # any value the variable had before the test, even if an assertion
        # fails, so the override cannot leak into other tests.
        with patch.dict(environ, {'ORANGE_EMBEDDING_API_URL': url_value}):
            self.embedder_server = ImageEmbedder(
                model='inception-v3',
                layer='penultimate',
                server_url='example.com',
            )
            self.assertEqual(self.embedder_server._embedder._server_url, url_value)

    @patch(_TESTED_MODULE.format('HTTP20Connection'), DummyHttp2Connection)
    def test_embedding_cancelled(self):
        # test for the server embedders
        self.assertFalse(self.embedder_server._embedder.cancelled)
        self.embedder_server._embedder.cancelled = True
        with self.assertRaises(Exception):
            self.embedder_server(self.single_example)

        # test for the local embedder
        self.assertFalse(self.embedder_local._embedder.cancelled)
        self.embedder_local._embedder.cancelled = True
        with self.assertRaises(Exception):
            self.embedder_local(self.single_example)