class OWImageEmbedding(OWWidget):
    """Widget that embeds images referenced by a table's image attribute.

    The embedding runs on a single-worker executor; progress and results
    are delivered back through the ``__progress_set`` and ``__set_results``
    slots.  A running task can be cancelled via the Cancel button.
    """

    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    class Inputs:
        images = Input('Images', Table)

    class Outputs:
        embeddings = Output('Embeddings', Table, default=True)
        skipped_images = Output('Skipped Images', Table)

    # Persisted indices of the selected image attribute and embedder.
    cb_image_attr_current_id = Setting(default=0)
    cb_embedder_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        # Embedder keys, ordered by the 'order' entry of their info record.
        self.embedders = sorted(
            EMBEDDERS_INFO, key=lambda k: EMBEDDERS_INFO[k]['order'])
        self._image_attributes = None
        self._input_data = None
        self._log = logging.getLogger(__name__)
        self._task = None
        self._setup_layout()

        # The embedder is created later, in _init_server_connection, which
        # is scheduled through a zero-delay timer; the widget is marked as
        # blocking until that callback runs.
        self._image_embedder = None
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.setBlocking(True)
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Build the Info box, the Settings box and the Apply/Cancel row."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.cb_embedder = comboBox(
            widget=widget_box,
            master=self,
            value='cb_embedder_current_id',
            label='Embedder:',
            orientation=Qt.Horizontal,
            callback=self._cb_embedder_changed
        )
        # Embedders flagged "is_local" get a " (local)" suffix in the list.
        names = [
            EMBEDDERS_INFO[e]['name']
            + (" (local)" if EMBEDDERS_INFO[e].get("is_local") else "")
            for e in self.embedders
        ]
        self.cb_embedder.setModel(VariableListModel(names))
        # A stored index may be out of range for the current embedder list.
        if not self.cb_embedder_current_id < len(self.embedders):
            self.cb_embedder_current_id = 0
        self.cb_embedder.setCurrentIndex(self.cb_embedder_current_id)

        current_embedder = self.embedders[self.cb_embedder_current_id]
        self.embedder_info = widgetLabel(
            widget_box,
            EMBEDDERS_INFO[current_embedder]['description']
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            commit=self.commit
        )

        self.cancel_button = QPushButton(
            'Cancel',
            icon=self.style().standardIcon(QStyle.SP_DialogCancelButton),
        )
        self.cancel_button.clicked.connect(self.cancel)
        hbox = hBox(self.controlArea)
        hbox.layout().addWidget(self.cancel_button)
        # Cancel is only enabled while a task is running (see commit).
        self.cancel_button.setDisabled(True)

    def _init_server_connection(self):
        """Create the embedder for the selected model and show its status."""
        self.setBlocking(False)
        self._image_embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate'
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    @Inputs.images
    def set_data(self, data):
        """Handle new input data and trigger embedding when it is usable."""
        if not data:
            self._input_data = None
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            return

        # A stored index may be out of range for this data's attributes.
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()

    def _cb_image_attr_changed(self):
        """Re-run the embedding after the image attribute changed."""
        self.commit()

    def _cb_embedder_changed(self):
        """Switch to the newly selected embedder and re-embed if possible."""
        current_embedder = self.embedders[self.cb_embedder_current_id]
        self._image_embedder = ImageEmbedder(
            model=current_embedder,
            layer='penultimate'
        )
        self.embedder_info.setText(
            EMBEDDERS_INFO[current_embedder]['description'])
        if self._input_data:
            self.input_data_info.setText(
                "Data with {:d} instances.".format(len(self._input_data)))
            self.commit()
        else:
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
        # NOTE(review): the source this was reconstructed from had lost its
        # indentation; this call is assumed to run on both branches -- confirm.
        self._set_server_info(self._image_embedder.is_connected_to_server())

    def commit(self):
        """Start an embedding task for the selected image attribute.

        Any running task is cancelled first.  Nothing is started when the
        embedder has not been created yet or when there is no usable input.
        """
        if self._task is not None:
            self.cancel()

        if self._image_embedder is None:
            self._set_server_info(connected=False)
            return

        if not self._image_attributes or self._input_data is None:
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        self._set_server_info(connected=True)
        self.cancel_button.setDisabled(False)
        self.cb_image_attr.setDisabled(True)
        self.cb_embedder.setDisabled(True)

        file_paths_attr = self._image_attributes[self.cb_image_attr_current_id]
        file_paths = self._input_data[:, file_paths_attr].metas.flatten()

        url_schemes = ("http", "https", "ftp", "data")
        origin = file_paths_attr.attributes.get("origin", "")
        # Appending "/" does not change the scheme, so this is computed once
        # and reused for every path below.
        origin_is_url = urlparse(origin).scheme in url_schemes
        if origin_is_url and not origin.endswith("/"):
            origin += "/"

        assert file_paths_attr.is_string
        assert file_paths.dtype == np.dtype('O')

        # Rows whose path is missing are left out of the embedding request;
        # the mask is kept on the task to restore row alignment afterwards.
        file_paths_mask = file_paths == file_paths_attr.Unknown
        file_paths_valid = file_paths[~file_paths_mask]
        for i, path in enumerate(file_paths_valid):
            if urlparse(path).scheme not in url_schemes:
                # Not a URL on its own: resolve it against the origin.
                if origin_is_url:
                    file_paths_valid[i] = urljoin(origin, path)
                else:
                    file_paths_valid[i] = os.path.join(origin, path)

        ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))
        set_progress = qconcurrent.methodinvoke(
            self, "__progress_set", (float,))

        def advance(success=True):
            if success:
                set_progress(next(ticks))

        def cancel():
            # `task` is bound further down, before this can be called.
            task.future.cancel()
            task.cancelled = True
            task.embedder.set_canceled(True)

        embedder = self._image_embedder

        def run_embedding(paths):
            return embedder(
                file_paths=paths, image_processed_callback=advance)

        self.auto_commit_widget.setDisabled(True)
        self.progressBarInit(processEvents=None)
        self.progressBarSet(0.0, processEvents=None)
        self.setBlocking(True)

        f = self._executor.submit(run_embedding, file_paths_valid)
        f.add_done_callback(
            qconcurrent.methodinvoke(self, "__set_results", (object,)))

        task = self._task = namespace(
            file_paths_mask=file_paths_mask,
            file_paths_valid=file_paths_valid,
            file_paths=file_paths,
            embedder=embedder,
            cancelled=False,
            cancel=cancel,
            future=f,
        )
        self._log.debug("Starting embedding task for %i images",
                        file_paths.size)
        return

    @Slot(float)
    def __progress_set(self, value):
        """Update the progress bar; ignored when no task is active."""
        assert self.thread() is QThread.currentThread()
        if self._task is not None:
            self.progressBarSet(value)

    @Slot(object)
    def __set_results(self, f):
        """Handle a finished future and send the resulting outputs."""
        assert self.thread() is QThread.currentThread()
        if self._task is None or self._task.future is not f:
            # The future belongs to a task that is no longer current.
            self._log.info("Reaping stale task")
            return

        assert f.done()

        task, self._task = self._task, None
        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.setDisabled(True)
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)

        try:
            embeddings = f.result()
        except ConnectionError:
            self._log.exception("Error", exc_info=True)
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self._set_server_info(connected=False)
            return
        except Exception as err:
            self._log.exception("Error", exc_info=True)
            self.error(
                "\n".join(traceback.format_exception_only(type(err), err)))
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        assert self._input_data is not None
        assert len(self._input_data) == len(task.file_paths_mask)

        # Missing paths/urls were filtered out. Restore the full embeddings
        # array from the information stored in task.file_paths_mask.
        embeddings_all = [None] * len(task.file_paths_mask)
        for i, embedding in zip(np.flatnonzero(~task.file_paths_mask),
                                embeddings):
            embeddings_all[i] = embedding
        embeddings_all = np.array(embeddings_all)
        self._send_output_signals(embeddings_all)

    def _send_output_signals(self, embeddings):
        """Send the embedded and skipped tables and report skipped images."""
        embedded_images, skipped_images, num_skipped =\
            ImageEmbedder.prepare_output_data(self._input_data, embeddings)
        self.Outputs.embeddings.send(embedded_images)
        self.Outputs.skipped_images.send(skipped_images)
        # Compare by value: an identity test against the literal 0 is not a
        # reliable integer comparison.
        if num_skipped != 0:
            self.input_data_info.setText(
                "Data with {:d} instances, {:d} images skipped.".format(
                    len(self._input_data), num_skipped))

    def _set_server_info(self, connected):
        """Clear widget messages and show the connection status."""
        self.clear_messages()
        if self._image_embedder is None:
            return

        if connected:
            self.connection_info.setText("Connected to server.")
        elif self._image_embedder.is_local_embedder():
            self.connection_info.setText("Using local embedder.")
        else:
            self.connection_info.setText("Not connected to server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Cancel any running task and release the embedder."""
        self.cancel()
        super().onDeleteWidget()
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)

    def cancel(self):
        """Cancel the running task, if any, and restore the idle UI state."""
        if self._task is not None:
            task, self._task = self._task, None
            task.cancel()
            # wait until done
            try:
                task.future.exception()
            except qconcurrent.CancelledError:
                pass

            self.auto_commit_widget.setDisabled(False)
            self.cancel_button.setDisabled(True)
            self.progressBarFinished(processEvents=None)
            self.setBlocking(False)
            self.cb_image_attr.setDisabled(False)
            self.cb_embedder.setDisabled(False)
            self._image_embedder.set_canceled(False)
            # reset the connection.
            connected = self._image_embedder.reconnect_to_server()
            self._set_server_info(connected=connected)
class OWImageEmbedding(OWWidget):
    """Widget that embeds images referenced by a table's image attribute.

    The embedding runs on a single-thread ``qconcurrent.ThreadExecutor``;
    progress and results are delivered back through the ``__progress_set``
    and ``__set_results`` slots.  A running task can be cancelled via the
    Cancel button.
    """

    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    class Inputs:
        images = Input('Images', Table)

    class Outputs:
        embeddings = Output('Embeddings', Table, default=True)
        skipped_images = Output('Skipped Images', Table)

    # Persisted indices of the selected image attribute and embedder.
    cb_image_attr_current_id = Setting(default=0)
    cb_embedder_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        # Embedder keys, ordered by the 'order' entry of their info record.
        self.embedders = sorted(
            EMBEDDERS_INFO, key=lambda k: EMBEDDERS_INFO[k]['order'])
        self._image_attributes = None
        self._input_data = None
        self._log = logging.getLogger(__name__)
        self._task = None
        self._setup_layout()

        # The embedder is created later, in _init_server_connection, which
        # is scheduled through a zero-delay timer; the widget is marked as
        # blocking until that callback runs.
        self._image_embedder = None
        self._executor = qconcurrent.ThreadExecutor(
            self, threadPool=QThreadPool(maxThreadCount=1)
        )
        self.setBlocking(True)
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Build the Info box, the Settings box and the Apply/Cancel row."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.cb_embedder = comboBox(
            widget=widget_box,
            master=self,
            value='cb_embedder_current_id',
            label='Embedder:',
            orientation=Qt.Horizontal,
            callback=self._cb_embedder_changed
        )
        self.cb_embedder.setModel(VariableListModel(
            [EMBEDDERS_INFO[e]['name'] for e in self.embedders]))
        # A stored index may be out of range for the current embedder list.
        if not self.cb_embedder_current_id < len(self.embedders):
            self.cb_embedder_current_id = 0
        self.cb_embedder.setCurrentIndex(self.cb_embedder_current_id)

        current_embedder = self.embedders[self.cb_embedder_current_id]
        self.embedder_info = widgetLabel(
            widget_box,
            EMBEDDERS_INFO[current_embedder]['description']
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            commit=self.commit
        )

        self.cancel_button = QPushButton(
            'Cancel',
            icon=self.style().standardIcon(QStyle.SP_DialogCancelButton),
        )
        self.cancel_button.clicked.connect(self.cancel)
        hbox = hBox(self.controlArea)
        hbox.layout().addWidget(self.cancel_button)
        # Cancel is only enabled while a task is running (see commit).
        self.cancel_button.setDisabled(True)

    def _init_server_connection(self):
        """Create the embedder for the selected model and show its status."""
        self.setBlocking(False)
        self._image_embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate'
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    @Inputs.images
    def set_data(self, data):
        """Handle new input data and trigger embedding when it is usable."""
        if not data:
            self._input_data = None
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = ImageEmbedder.filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            return

        # A stored index may be out of range for this data's attributes.
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()

    def _cb_image_attr_changed(self):
        """Re-run the embedding after the image attribute changed."""
        self.commit()

    def _cb_embedder_changed(self):
        """Switch to the newly selected embedder and re-embed if possible."""
        current_embedder = self.embedders[self.cb_embedder_current_id]
        self._image_embedder = ImageEmbedder(
            model=current_embedder,
            layer='penultimate'
        )
        self.embedder_info.setText(
            EMBEDDERS_INFO[current_embedder]['description'])
        if self._input_data:
            self.input_data_info.setText(
                "Data with {:d} instances.".format(len(self._input_data)))
            self.commit()
        else:
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)

    def commit(self):
        """Start an embedding task for the selected image attribute.

        Any running task is cancelled first.  Nothing is started when the
        embedder has not been created yet or when there is no usable input.
        """
        if self._task is not None:
            self.cancel()

        if self._image_embedder is None:
            self._set_server_info(connected=False)
            return

        if not self._image_attributes or self._input_data is None:
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        self._set_server_info(connected=True)
        self.cancel_button.setDisabled(False)
        self.cb_image_attr.setDisabled(True)
        self.cb_embedder.setDisabled(True)

        file_paths_attr = self._image_attributes[self.cb_image_attr_current_id]
        file_paths = self._input_data[:, file_paths_attr].metas.flatten()

        url_schemes = ("http", "https", "ftp", "data")
        origin = file_paths_attr.attributes.get("origin", "")
        # Appending "/" does not change the scheme, so this is computed once
        # and reused for every path below.
        origin_is_url = urlparse(origin).scheme in url_schemes
        if origin_is_url and not origin.endswith("/"):
            origin += "/"

        assert file_paths_attr.is_string
        assert file_paths.dtype == np.dtype('O')

        # Rows whose path is missing are left out of the embedding request;
        # the mask is kept on the task to restore row alignment afterwards.
        file_paths_mask = file_paths == file_paths_attr.Unknown
        file_paths_valid = file_paths[~file_paths_mask]
        for i, path in enumerate(file_paths_valid):
            if urlparse(path).scheme not in url_schemes:
                # Not a URL on its own: resolve it against the origin.
                if origin_is_url:
                    file_paths_valid[i] = urljoin(origin, path)
                else:
                    file_paths_valid[i] = os.path.join(origin, path)

        ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))
        set_progress = qconcurrent.methodinvoke(
            self, "__progress_set", (float,))

        def advance(success=True):
            if success:
                set_progress(next(ticks))

        def cancel():
            # `task` is bound further down, before this can be called.
            task.future.cancel()
            task.cancelled = True
            task.embedder.cancelled = True

        embedder = self._image_embedder

        def run_embedding(paths):
            return embedder(
                file_paths=paths, image_processed_callback=advance)

        self.auto_commit_widget.setDisabled(True)
        self.progressBarInit(processEvents=None)
        self.progressBarSet(0.0, processEvents=None)
        self.setBlocking(True)

        f = self._executor.submit(run_embedding, file_paths_valid)
        f.add_done_callback(
            qconcurrent.methodinvoke(self, "__set_results", (object,)))

        task = self._task = namespace(
            file_paths_mask=file_paths_mask,
            file_paths_valid=file_paths_valid,
            file_paths=file_paths,
            embedder=embedder,
            cancelled=False,
            cancel=cancel,
            future=f,
        )
        self._log.debug("Starting embedding task for %i images",
                        file_paths.size)
        return

    @Slot(float)
    def __progress_set(self, value):
        """Update the progress bar; ignored when no task is active."""
        assert self.thread() is QThread.currentThread()
        if self._task is not None:
            self.progressBarSet(value)

    @Slot(object)
    def __set_results(self, f):
        """Handle a finished future and send the resulting outputs."""
        assert self.thread() is QThread.currentThread()
        if self._task is None or self._task.future is not f:
            # The future belongs to a task that is no longer current.
            self._log.info("Reaping stale task")
            return

        assert f.done()

        task, self._task = self._task, None
        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.setDisabled(True)
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)

        try:
            embeddings = f.result()
        except ConnectionError:
            self._log.exception("Error", exc_info=True)
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            self._set_server_info(connected=False)
            return
        except Exception as err:
            self._log.exception("Error", exc_info=True)
            self.error(
                "\n".join(traceback.format_exception_only(type(err), err)))
            self.Outputs.embeddings.send(None)
            self.Outputs.skipped_images.send(None)
            return

        assert self._input_data is not None
        assert len(self._input_data) == len(task.file_paths_mask)

        # Missing paths/urls were filtered out. Restore the full embeddings
        # array from the information stored in task.file_paths_mask.
        embeddings_all = [None] * len(task.file_paths_mask)
        for i, embedding in zip(np.flatnonzero(~task.file_paths_mask),
                                embeddings):
            embeddings_all[i] = embedding
        embeddings_all = np.array(embeddings_all)
        self._send_output_signals(embeddings_all)

    def _send_output_signals(self, embeddings):
        """Send the embedded and skipped tables and report skipped images."""
        embedded_images, skipped_images, num_skipped =\
            ImageEmbedder.prepare_output_data(self._input_data, embeddings)
        self.Outputs.embeddings.send(embedded_images)
        self.Outputs.skipped_images.send(skipped_images)
        # Compare by value: an identity test against the literal 0 is not a
        # reliable integer comparison.
        if num_skipped != 0:
            self.input_data_info.setText(
                "Data with {:d} instances, {:d} images skipped.".format(
                    len(self._input_data), num_skipped))

    def _set_server_info(self, connected):
        """Clear widget messages and show the connection status."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Cancel any running task and release the embedder."""
        self.cancel()
        super().onDeleteWidget()
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)

    def cancel(self):
        """Cancel the running task, if any, and restore the idle UI state."""
        if self._task is not None:
            task, self._task = self._task, None
            task.cancel()
            # wait until done
            try:
                task.future.exception()
            except qconcurrent.CancelledError:
                pass

            self.auto_commit_widget.setDisabled(False)
            self.cancel_button.setDisabled(True)
            self.progressBarFinished(processEvents=None)
            self.setBlocking(False)
            self.cb_image_attr.setDisabled(False)
            self.cb_embedder.setDisabled(False)
            self._image_embedder.cancelled = False
            # reset the connection.
            connected = self._image_embedder.reconnect_to_server()
            self._set_server_info(connected=connected)
class OWImageEmbedding(OWWidget):
    """Widget that embeds images referenced by a table's image attribute.

    The embedding is computed synchronously inside ``commit``.
    """
    # TODO: implement embedding in a non-blocking manner
    # TODO: implement stop running task action

    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    inputs = [(_Input.IMAGES, Table, 'set_data')]
    outputs = [
        (_Output.EMBEDDINGS, Table, Default),
        (_Output.SKIPPED_IMAGES, Table)
    ]

    # Persisted index of the selected image attribute.
    cb_image_attr_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        self._image_attributes = None
        self._input_data = None
        self._setup_layout()

        self._image_embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
        )
        self._set_server_info(
            self._image_embedder.is_connected_to_server()
        )

    def _setup_layout(self):
        """Build the Info box, the Settings box and the Apply control."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed
        )

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            checkbox_label='Auto Apply',
            commit=self.commit
        )

    def set_data(self, data):
        """Handle new input data and trigger embedding when it is usable."""
        if data is None:
            # Drop the previous table so a later Apply cannot re-embed data
            # that is no longer on the input.
            self._input_data = None
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = self._filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            return

        # A stored index may be out of range for this data's attributes.
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()

    @staticmethod
    def _filter_image_attributes(data):
        """Return the meta variables of `data` whose 'type' is 'image'."""
        metas = data.domain.metas
        return [m for m in metas if m.attributes.get('type') == 'image']

    def _cb_image_attr_changed(self):
        """Re-run the embedding when auto-apply is enabled."""
        if self._auto_apply:
            self.commit()

    def commit(self):
        """Embed the images of the selected attribute and send the outputs.

        Sends ``None`` on both outputs when there is no usable input or
        when the embedder raises ``ConnectionError``.
        """
        if not self._image_attributes or not self._input_data:
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            return

        self._set_server_info(connected=True)
        self.auto_commit_widget.setDisabled(True)
        try:
            file_paths_attr = \
                self._image_attributes[self.cb_image_attr_current_id]
            file_paths = self._input_data[:, file_paths_attr].metas.flatten()

            with self.progressBar(len(file_paths)) as progress:
                try:
                    embeddings = self._image_embedder(
                        file_paths=file_paths,
                        image_processed_callback=lambda: progress.advance()
                    )
                except ConnectionError:
                    self.send(_Output.EMBEDDINGS, None)
                    self.send(_Output.SKIPPED_IMAGES, None)
                    self._set_server_info(connected=False)
                    return

            self._send_output_signals(embeddings)
        finally:
            # Re-enable Apply on every exit path, including unexpected
            # exceptions, so the control is never left disabled.
            self.auto_commit_widget.setDisabled(False)

    def _send_output_signals(self, embeddings):
        """Split the input rows into embedded and skipped and send both.

        Rows whose embedding is ``None`` go to the skipped output; the
        remaining rows are extended with their embedding columns.
        """
        skipped_images_bool = np.array([x is None for x in embeddings])

        if np.any(skipped_images_bool):
            skipped_images = self._input_data[skipped_images_bool]
            skipped_images = Table(skipped_images)
            skipped_images.ids = self._input_data.ids[skipped_images_bool]
            self.send(_Output.SKIPPED_IMAGES, skipped_images)
        else:
            self.send(_Output.SKIPPED_IMAGES, None)

        embedded_images_bool = np.logical_not(skipped_images_bool)

        if np.any(embedded_images_bool):
            embedded_images = self._input_data[embedded_images_bool]
            embeddings = embeddings[embedded_images_bool]
            embeddings = np.stack(embeddings)
            embedded_images = self._construct_output_data_table(
                embedded_images,
                embeddings
            )
            embedded_images.ids = self._input_data.ids[embedded_images_bool]
            self.send(_Output.EMBEDDINGS, embedded_images)
        else:
            self.send(_Output.EMBEDDINGS, None)

    @staticmethod
    def _construct_output_data_table(embedded_images, embeddings):
        """Return a table with the embedding columns appended as features.

        The new columns are continuous variables named ``n0``, ``n1``, ...;
        class variables and metas are carried over from `embedded_images`.
        """
        X = np.hstack((embedded_images.X, embeddings))
        Y = embedded_images.Y

        dimensions = range(embeddings.shape[1])
        attributes = [ContinuousVariable('n{:d}'.format(d))
                      for d in dimensions]
        attributes = list(embedded_images.domain.attributes) + attributes

        domain = Domain(
            attributes=attributes,
            class_vars=embedded_images.domain.class_vars,
            metas=embedded_images.domain.metas
        )

        return Table(domain, X, Y, embedded_images.metas)

    def _set_server_info(self, connected):
        """Clear widget messages and show the connection status."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Release the embedder when the widget is deleted."""
        super().onDeleteWidget()
        self._image_embedder.__exit__(None, None, None)
class OWImageEmbedding(OWWidget):
    """Widget that embeds images referenced by a table's image attribute.

    The embedding is computed synchronously inside ``commit``.  The
    embedder itself is created from a zero-delay timer callback, so it is
    ``None`` until ``_init_server_connection`` has run.
    """
    # TODO: implement embedding in a non-blocking manner
    # TODO: implement stop running task action

    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    inputs = [(_Input.IMAGES, Table, 'set_data')]
    outputs = [(_Output.EMBEDDINGS, Table, Default),
               (_Output.SKIPPED_IMAGES, Table)]

    # Persisted index of the selected image attribute.
    cb_image_attr_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        self._image_attributes = None
        self._input_data = None
        self._setup_layout()

        self._image_embedder = None
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Build the Info box, the Settings box and the Apply control."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(
            widget=widget_box,
            master=self,
            value='cb_image_attr_current_id',
            label='Image attribute:',
            orientation=Qt.Horizontal,
            callback=self._cb_image_attr_changed)

        self.auto_commit_widget = auto_commit(
            widget=self.controlArea,
            master=self,
            value='_auto_apply',
            label='Apply',
            checkbox_label='Auto Apply',
            commit=self.commit)

    def _init_server_connection(self):
        """Create the embedder and show its connection status."""
        self._image_embedder = ImageEmbedder(
            model='inception-v3',
            layer='penultimate',
        )
        self._set_server_info(self._image_embedder.is_connected_to_server())

    def set_data(self, data):
        """Handle new input data and trigger embedding when it is usable."""
        if data is None:
            # Drop the previous table so a later Apply cannot re-embed data
            # that is no longer on the input.
            self._input_data = None
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = self._filter_image_attributes(data)
        if not self._image_attributes:
            self.input_data_info.setText(
                "Data with {:d} instances, but without image attributes."
                .format(len(data)))
            self._input_data = None
            return

        # A stored index may be out of range for this data's attributes.
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        self.input_data_info.setText(
            "Data with {:d} instances.".format(len(data)))

        self._cb_image_attr_changed()

    @staticmethod
    def _filter_image_attributes(data):
        """Return the meta variables of `data` whose 'type' is 'image'."""
        metas = data.domain.metas
        return [m for m in metas if m.attributes.get('type') == 'image']

    def _cb_image_attr_changed(self):
        """Re-run the embedding when auto-apply is enabled."""
        if self._auto_apply:
            self.commit()

    def commit(self):
        """Embed the images of the selected attribute and send the outputs.

        Sends ``None`` on both outputs when there is no usable input or
        when the embedder raises ``ConnectionError``.  Does nothing beyond
        reporting "not connected" when the embedder does not exist yet.
        """
        if not self._image_attributes or not self._input_data:
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            return

        if self._image_embedder is None:
            # Data arrived before the deferred embedder initialisation ran;
            # calling the embedder here would fail on None.
            self._set_server_info(connected=False)
            return

        self._set_server_info(connected=True)
        self.auto_commit_widget.setDisabled(True)
        try:
            file_paths_attr = \
                self._image_attributes[self.cb_image_attr_current_id]
            file_paths = self._input_data[:, file_paths_attr].metas.flatten()

            with self.progressBar(len(file_paths)) as progress:
                try:
                    embeddings = self._image_embedder(
                        file_paths=file_paths,
                        image_processed_callback=lambda: progress.advance())
                except ConnectionError:
                    self.send(_Output.EMBEDDINGS, None)
                    self.send(_Output.SKIPPED_IMAGES, None)
                    self._set_server_info(connected=False)
                    return

            self._send_output_signals(embeddings)
        finally:
            # Re-enable Apply on every exit path, including unexpected
            # exceptions, so the control is never left disabled.
            self.auto_commit_widget.setDisabled(False)

    def _send_output_signals(self, embeddings):
        """Split the input rows into embedded and skipped and send both.

        Rows whose embedding is ``None`` go to the skipped output; the
        remaining rows are extended with their embedding columns.
        """
        skipped_images_bool = np.array([x is None for x in embeddings])

        if np.any(skipped_images_bool):
            skipped_images = self._input_data[skipped_images_bool]
            skipped_images = Table(skipped_images)
            skipped_images.ids = self._input_data.ids[skipped_images_bool]
            self.send(_Output.SKIPPED_IMAGES, skipped_images)
        else:
            self.send(_Output.SKIPPED_IMAGES, None)

        embedded_images_bool = np.logical_not(skipped_images_bool)

        if np.any(embedded_images_bool):
            embedded_images = self._input_data[embedded_images_bool]
            embeddings = embeddings[embedded_images_bool]
            embeddings = np.stack(embeddings)
            embedded_images = self._construct_output_data_table(
                embedded_images, embeddings)
            embedded_images.ids = self._input_data.ids[embedded_images_bool]
            self.send(_Output.EMBEDDINGS, embedded_images)
        else:
            self.send(_Output.EMBEDDINGS, None)

    @staticmethod
    def _construct_output_data_table(embedded_images, embeddings):
        """Return a table with the embedding columns appended as features.

        The new columns are continuous variables named ``n0``, ``n1``, ...;
        class variables and metas are carried over from `embedded_images`.
        """
        X = np.hstack((embedded_images.X, embeddings))
        Y = embedded_images.Y

        dimensions = range(embeddings.shape[1])
        attributes = [
            ContinuousVariable('n{:d}'.format(d)) for d in dimensions
        ]
        attributes = list(embedded_images.domain.attributes) + attributes

        domain = Domain(attributes=attributes,
                        class_vars=embedded_images.domain.class_vars,
                        metas=embedded_images.domain.metas)

        return Table(domain, X, Y, embedded_images.metas)

    def _set_server_info(self, connected):
        """Clear widget messages and show the connection status."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Release the embedder, if one was created, on widget deletion."""
        super().onDeleteWidget()
        # The embedder is still None when the widget is deleted before the
        # deferred initialisation has run.
        if self._image_embedder is not None:
            self._image_embedder.__exit__(None, None, None)
class OWImageEmbedding(OWWidget):
    """Widget that computes embeddings for the images referenced by a table.

    The embedding is delegated to an ``ImageEmbedder`` and runs on a
    single-threaded executor; the finished future is handed back to the
    widget through the ``__set_results`` slot.  Rows without an embedding are
    sent on the skipped-images output.
    """

    name = "Image Embedding"
    description = "Image embedding through deep neural networks."
    icon = "icons/ImageEmbedding.svg"
    priority = 150

    want_main_area = False
    _auto_apply = Setting(default=True)

    inputs = [(_Input.IMAGES, Table, 'set_data')]
    outputs = [(_Output.EMBEDDINGS, Table, Default),
               (_Output.SKIPPED_IMAGES, Table)]

    cb_image_attr_current_id = Setting(default=0)
    cb_embedder_current_id = Setting(default=0)

    _NO_DATA_INFO_TEXT = "No data on input."

    def __init__(self):
        super().__init__()
        self.embedders = sorted(EMBEDDERS_INFO)
        self._image_attributes = None
        self._input_data = None
        self._log = logging.getLogger(__name__)
        # Bookkeeping namespace of the running embedding task, or None.
        self._task = None
        self._setup_layout()

        # Created later by _init_server_connection, which is scheduled
        # through the timer below; until then it stays None.
        self._image_embedder = None
        self._executor = qconcurrent.ThreadExecutor(
            self, threadPool=QThreadPool(maxThreadCount=1))
        self.setBlocking(True)
        QTimer.singleShot(0, self._init_server_connection)

    def _setup_layout(self):
        """Build the info labels, the settings combo boxes and the buttons."""
        self.controlArea.setMinimumWidth(self.controlArea.sizeHint().width())
        self.layout().setSizeConstraint(QLayout.SetFixedSize)

        widget_box = widgetBox(self.controlArea, 'Info')
        self.input_data_info = widgetLabel(widget_box, self._NO_DATA_INFO_TEXT)
        self.connection_info = widgetLabel(widget_box, "")

        widget_box = widgetBox(self.controlArea, 'Settings')
        self.cb_image_attr = comboBox(widget=widget_box,
                                      master=self,
                                      value='cb_image_attr_current_id',
                                      label='Image attribute:',
                                      orientation=Qt.Horizontal,
                                      callback=self._cb_image_attr_changed)

        self.cb_embedder = comboBox(widget=widget_box,
                                    master=self,
                                    value='cb_embedder_current_id',
                                    label='Embedder:',
                                    orientation=Qt.Horizontal,
                                    callback=self._cb_embedder_changed)
        self.cb_embedder.setModel(
            VariableListModel(
                [EMBEDDERS_INFO[e]['name'] for e in self.embedders]))
        # A stored index may be out of range for the current embedder list.
        if not self.cb_embedder_current_id < len(self.embedders):
            self.cb_embedder_current_id = 0
        self.cb_embedder.setCurrentIndex(self.cb_embedder_current_id)

        current_embedder = self.embedders[self.cb_embedder_current_id]
        self.embedder_info = widgetLabel(
            widget_box,
            EMBEDDERS_INFO[current_embedder]['description'])

        self.auto_commit_widget = auto_commit(widget=self.controlArea,
                                              master=self,
                                              value='_auto_apply',
                                              label='Apply',
                                              commit=self.commit)

        self.cancel_button = QPushButton(
            'Cancel',
            icon=self.style().standardIcon(QStyle.SP_DialogCancelButton),
        )
        self.cancel_button.clicked.connect(self.cancel)
        hbox = hBox(self.controlArea)
        hbox.layout().addWidget(self.cancel_button)
        # Only shown while an embedding task is running (see commit).
        self.cancel_button.hide()

    def _init_server_connection(self):
        """Create the embedder for the selected model and report its state."""
        self.setBlocking(False)
        self._image_embedder = ImageEmbedder(
            model=self.embedders[self.cb_embedder_current_id],
            layer='penultimate')
        self._set_server_info(self._image_embedder.is_connected_to_server())

    def set_data(self, data):
        """Input handler: store the new data and start embedding it.

        Both outputs are cleared when ``data`` is falsy or has no image
        attributes.
        """
        if not data:
            self._input_data = None
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            self.input_data_info.setText(self._NO_DATA_INFO_TEXT)
            return

        self._image_attributes = self._filter_image_attributes(data)
        if not self._image_attributes:
            input_data_info_text = (
                "Data with {:d} instances, but without image attributes.".
                format(len(data)))
            self.input_data_info.setText(input_data_info_text)
            self._input_data = None
            # Results computed for a previous input must not stay on the
            # outputs once the widget has nothing to embed.
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            return

        # Fall back to the first image attribute when the stored index does
        # not fit the new data.
        if not self.cb_image_attr_current_id < len(self._image_attributes):
            self.cb_image_attr_current_id = 0

        self.cb_image_attr.setModel(VariableListModel(self._image_attributes))
        self.cb_image_attr.setCurrentIndex(self.cb_image_attr_current_id)

        self._input_data = data
        input_data_info_text = "Data with {:d} instances.".format(len(data))
        self.input_data_info.setText(input_data_info_text)

        self._cb_image_attr_changed()

    @staticmethod
    def _filter_image_attributes(data):
        """Return the meta variables of ``data`` whose 'type' is 'image'."""
        metas = data.domain.metas
        return [m for m in metas if m.attributes.get('type') == 'image']

    def _cb_image_attr_changed(self):
        """Re-run the embedding for the newly selected image attribute."""
        self.commit()

    def _close_embedder(self):
        """Release the current embedder, if one has been created."""
        embedder, self._image_embedder = self._image_embedder, None
        if embedder is not None:
            embedder.__exit__(None, None, None)

    def _cb_embedder_changed(self):
        """Switch to the selected embedder and re-run the embedding."""
        current_embedder = self.embedders[self.cb_embedder_current_id]
        # Release the embedder being replaced in the same way onDeleteWidget
        # releases the last one, so it is not left open.
        self._close_embedder()
        self._image_embedder = ImageEmbedder(model=current_embedder,
                                             layer='penultimate')
        self.embedder_info.setText(
            EMBEDDERS_INFO[current_embedder]['description'])
        self.commit()

    def commit(self):
        """Start an embedding task for the selected image attribute.

        A task that is still running is cancelled first.  The embedder is
        submitted to the executor; its progress and result are routed back
        to the ``__progress_set`` and ``__set_results`` slots.
        """
        if self._task is not None:
            self.cancel()

        if self._image_embedder is None:
            self._set_server_info(connected=False)
            return

        if not self._image_attributes or self._input_data is None:
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            return

        self._set_server_info(connected=True)
        self.cancel_button.show()
        self.cb_image_attr.setDisabled(True)
        self.cb_embedder.setDisabled(True)

        file_paths_attr = self._image_attributes[self.cb_image_attr_current_id]
        file_paths = self._input_data[:, file_paths_attr].metas.flatten()
        assert file_paths_attr.is_string
        assert file_paths.dtype == np.dtype('O')

        # Only rows with a known path are handed to the embedder; the mask is
        # kept so __set_results can put the results back in row order.
        file_paths_mask = file_paths == file_paths_attr.Unknown
        file_paths_valid = file_paths[~file_paths_mask]

        ticks = iter(np.linspace(0.0, 100.0, file_paths_valid.size))
        set_progress = qconcurrent.methodinvoke(
            self, "__progress_set", (float, ))

        def advance():
            set_progress(next(ticks))

        def cancel():
            # `task` is bound further down, before this can be called.
            task.future.cancel()
            task.cancelled = True
            task.embedder.cancelled = True

        embedder = self._image_embedder

        def run_embedding(paths):
            return embedder(file_paths=paths,
                            image_processed_callback=advance)

        self.auto_commit_widget.setDisabled(True)
        self.progressBarInit(processEvents=None)
        self.progressBarSet(0.0, processEvents=None)
        self.setBlocking(True)

        f = self._executor.submit(run_embedding, file_paths_valid)
        f.add_done_callback(
            qconcurrent.methodinvoke(self, "__set_results", (object, )))

        task = self._task = namespace(
            file_paths_mask=file_paths_mask,
            file_paths_valid=file_paths_valid,
            file_paths=file_paths,
            embedder=embedder,
            cancelled=False,
            cancel=cancel,
            future=f,
        )
        self._log.debug("Starting embedding task for %i images",
                        file_paths.size)
        return

    @Slot(float)
    def __progress_set(self, value):
        """Update the progress bar while a task is active."""
        assert self.thread() is QThread.currentThread()
        if self._task is not None:
            self.progressBarSet(value)

    @Slot(object)
    def __set_results(self, f):
        """Handle the finished future ``f`` and send the outputs.

        Futures that do not belong to the current task are ignored.  If the
        embedder raised, both outputs are cleared; a ``ConnectionError``
        additionally updates the connection info, any other exception is
        shown as a widget error.
        """
        assert self.thread() is QThread.currentThread()
        if self._task is None or self._task.future is not f:
            self._log.info("Reaping stale task")
            return

        assert f.done()

        task, self._task = self._task, None
        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.hide()
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)

        try:
            embeddings = f.result()
        except ConnectionError:
            self._log.exception("Error", exc_info=True)
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            self._set_server_info(connected=False)
            return
        except Exception as err:
            self._log.exception("Error", exc_info=True)
            self.error("\n".join(
                traceback.format_exception_only(type(err), err)))
            self.send(_Output.EMBEDDINGS, None)
            self.send(_Output.SKIPPED_IMAGES, None)
            return

        assert self._input_data is not None
        assert len(self._input_data) == len(task.file_paths_mask)

        # Missing paths/urls were filtered out. Restore the full embeddings
        # array from the information stored in task.file_paths_mask.
        # An explicit 1-D object array is used because np.array() on a list
        # mixing None and arrays is a ragged construction, and because the
        # result keeps the same layout whether or not anything was skipped.
        embeddings_all = np.empty(len(task.file_paths_mask), dtype=object)
        for i, embedding in zip(np.flatnonzero(~task.file_paths_mask),
                                embeddings):
            embeddings_all[i] = embedding
        self._send_output_signals(embeddings_all)

    def _send_output_signals(self, embeddings):
        """Split the input rows by embedding availability and send them.

        Rows whose entry in ``embeddings`` is None go to the skipped-images
        output; the remaining rows, extended with their embedding columns,
        go to the embeddings output.  An output with no rows is sent as None.
        """
        skipped_images_bool = np.array([x is None for x in embeddings])

        if np.any(skipped_images_bool):
            skipped_images = self._input_data[skipped_images_bool]
            skipped_images = Table(skipped_images)
            skipped_images.ids = self._input_data.ids[skipped_images_bool]
            self.send(_Output.SKIPPED_IMAGES, skipped_images)
        else:
            self.send(_Output.SKIPPED_IMAGES, None)

        embedded_images_bool = np.logical_not(skipped_images_bool)

        if np.any(embedded_images_bool):
            embedded_images = self._input_data[embedded_images_bool]
            embeddings = embeddings[embedded_images_bool]
            embeddings = np.stack(embeddings)
            embedded_images = self._construct_output_data_table(
                embedded_images, embeddings)
            embedded_images.ids = self._input_data.ids[embedded_images_bool]
            self.send(_Output.EMBEDDINGS, embedded_images)
        else:
            self.send(_Output.EMBEDDINGS, None)

    @staticmethod
    def _construct_output_data_table(embedded_images, embeddings):
        """Return a table with one 'n<i>' column per embedding dimension.

        The new columns are appended after the existing attributes of
        ``embedded_images``; class variables and metas are carried over.
        """
        X = np.hstack((embedded_images.X, embeddings))
        Y = embedded_images.Y

        attributes = [
            ContinuousVariable.make('n{:d}'.format(d))
            for d in range(embeddings.shape[1])
        ]
        attributes = list(embedded_images.domain.attributes) + attributes

        domain = Domain(attributes=attributes,
                        class_vars=embedded_images.domain.class_vars,
                        metas=embedded_images.domain.metas)

        return Table(domain, X, Y, embedded_images.metas)

    def _set_server_info(self, connected):
        """Show the connection state; warn when there is no connection."""
        self.clear_messages()
        if connected:
            self.connection_info.setText("Connected to server.")
        else:
            self.connection_info.setText("No connection with server.")
            self.warning("Click Apply to try again.")

    def onDeleteWidget(self):
        """Cancel any running task and release the embedder."""
        self.cancel()
        super().onDeleteWidget()
        # The embedder is created by a timer callback, so the widget can be
        # deleted while it is still None; _close_embedder handles that.
        self._close_embedder()

    def cancel(self):
        """Cancel the running task, if any, and restore the controls."""
        if self._task is None:
            return

        task, self._task = self._task, None
        task.cancel()
        # wait until done
        try:
            task.future.exception()
        except qconcurrent.CancelledError:
            pass

        self.auto_commit_widget.setDisabled(False)
        self.cancel_button.hide()
        self.progressBarFinished(processEvents=None)
        self.setBlocking(False)
        self.cb_image_attr.setDisabled(False)
        self.cb_embedder.setDisabled(False)
        self._image_embedder.cancelled = False
        # reset the connection.
        connected = self._image_embedder.reconnect_to_server()
        self._set_server_info(connected=connected)