def __init__(self, parent=None):
        """Create the dataset tab and populate it.

        Embeds a ``DatasetGridWidget`` as the scroll area's widget, routes
        each of the grid's action signals to the corresponding slot on this
        tab, creates the worker pool, loading dialog and DAOs, and finally
        calls ``self.load()``.
        """
        super(DatasetTabWidget, self).__init__(parent)
        self.setCursor(QtCore.Qt.PointingHandCursor)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)

        grid = DatasetGridWidget()
        self.data_grid = grid

        # (signal, slot) pairs -- connected in the original order.
        wiring = (
            (grid.new_dataset_action_signal, self.btn_new_dataset_on_slot),
            (grid.delete_dataset_action_signal,
             self.btn_delete_dataset_on_slot),
            (grid.refresh_dataset_action_signal,
             self.refresh_dataset_action_slot),
            (grid.edit_dataset_action_signal, self.edit_dataset_action_slot),
            (grid.open_dataset_action_signal, self.open_dataset_action_slot),
            (grid.download_anno_action_signal,
             self.download_annot_action_slot),
            (grid.import_anno_action_signal, self.import_annot_action_slot),
        )
        for signal, slot in wiring:
            signal.connect(slot)

        self.setWidget(grid)
        self.setWidgetResizable(True)

        # Background execution helpers and data-access objects.
        self.thread_pool = QThreadPool()
        self.loading_dialog = QLoadingDialog()
        self._ds_dao = DatasetDao()
        self._labels_dao = LabelDao()
        self._annot_dao = AnnotaDao()

        self.load()
Beispiel #2
0
    def __init__(self, parent=None):
        """Set up the image viewer panel.

        Builds the generated UI, places an ``ImageViewer`` plus a floating
        caption label in ``center_layout``, creates the DAOs, thread pool and
        loading dialog, configures the image list and the model / label
        views, and finally builds the action toolbar.
        """
        super(ImageViewerWidget, self).__init__(parent)
        self.setupUi(self)

        viewer = ImageViewer()
        self.viewer = viewer
        viewer.scene().itemAdded.connect(self._scene_item_added)
        self.center_layout.addWidget(viewer, 0, 0)

        # Caption overlay showing the current image's label; hidden until
        # load_image_label() has something to display.
        caption_css = '''
            QLabel{
            font: 12pt;
            border-radius: 25px;
            margin: 10px;
            color: black; 
            background-color: #FFFFDC;
            }
        '''
        caption = QLabel()
        self._label = caption
        caption.setVisible(False)
        caption.setMargin(5)
        caption.setStyleSheet(caption_css)
        drop_shadow = QGraphicsDropShadowEffect(self)
        drop_shadow.setBlurRadius(8)
        drop_shadow.setColor(QtGui.QColor(94, 93, 90).lighter())
        drop_shadow.setOffset(2)
        caption.setGraphicsEffect(drop_shadow)
        top_left = QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft
        self.center_layout.addWidget(caption, 0, 0, top_left)

        self.actions_layout.setAlignment(QtCore.Qt.AlignCenter)

        # Data access and background work.
        self._ds_dao = DatasetDao()
        self._hub_dao = HubDao()
        self._labels_dao = LabelDao()
        self._annot_dao = AnnotaDao()
        self._thread_pool = QThreadPool()
        self._loading_dialog = QLoadingDialog()
        self._source = None
        self._image = None

        # Image list: multi-select, follows the current item, custom menu.
        images = self.images_list_widget
        images.setSelectionMode(QAbstractItemView.ExtendedSelection)
        images.currentItemChanged.connect(self.image_list_sel_changed_slot)
        images.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
        images.customContextMenuRequested.connect(self.image_list_context_menu)

        models_tree = ModelsTreeview()
        self.treeview_models = models_tree
        models_tree.setColumnWidth(0, 300)
        self.tree_view_models_layout.addWidget(models_tree)
        models_tree.action_click.connect(self.trv_models_action_click_slot)

        labels_table = LabelsTableView()
        self.treeview_labels = labels_table
        labels_table.action_click.connect(self.trv_labels_action_click_slot)
        self.tree_view_labels_layout.addWidget(labels_table)
        labels_table.selectionModel().selectionChanged.connect(
            self.default_label_changed_slot)

        self.create_actions_bar()
Beispiel #3
0
class ImageViewerWidget(QWidget, Ui_Image_Viewer_Widget):
    """Annotation panel: image list, viewer, model tree, label table, toolbar.

    NOTE(review): this class body defines ``__init__`` twice. Python keeps
    only the last definition, so the second constructor is the one that
    runs. It creates ``self.image_viewer``, ``self._ann_dao`` and
    ``self._class_label``, and it never sets ``self._source`` or calls
    ``create_actions_bar()``. Most methods below, however, use
    ``self.viewer``, ``self._annot_dao``, ``self._label`` and
    ``self.source``. Confirm which constructor is intended and delete the
    other.
    """

    def __init__(self, parent=None):
        """Set up the viewer panel (shadowed by the later ``__init__``).

        Builds the generated UI, adds an ``ImageViewer`` and a floating
        caption label to ``center_layout``, creates the DAOs, thread pool and
        loading dialog, wires the image list and model/label views, and
        builds the action toolbar.
        """
        super(ImageViewerWidget, self).__init__(parent)
        self.setupUi(self)
        self.viewer = ImageViewer()
        self.viewer.scene().itemAdded.connect(self._scene_item_added)
        self.center_layout.addWidget(self.viewer, 0, 0)

        # Caption overlay for the current image's label; hidden until
        # load_image_label() has something to show.
        self._label = QLabel()
        self._label.setVisible(False)
        self._label.setMargin(5)
        self._label.setStyleSheet('''
            QLabel{
            font: 12pt;
            border-radius: 25px;
            margin: 10px;
            color: black; 
            background-color: #FFFFDC;
            }
        ''')
        shadow = QGraphicsDropShadowEffect(self)
        shadow.setBlurRadius(8)
        shadow.setColor(QtGui.QColor(94, 93, 90).lighter())
        shadow.setOffset(2)
        self._label.setGraphicsEffect(shadow)
        self.center_layout.addWidget(self._label, 0, 0,
                                     QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft)

        self.actions_layout.setAlignment(QtCore.Qt.AlignCenter)
        self._ds_dao = DatasetDao()
        self._hub_dao = HubDao()
        self._labels_dao = LabelDao()
        self._annot_dao = AnnotaDao()
        self._thread_pool = QThreadPool()
        self._loading_dialog = QLoadingDialog()
        self._source = None
        self._image = None
        self.images_list_widget.setSelectionMode(
            QAbstractItemView.ExtendedSelection)
        self.images_list_widget.currentItemChanged.connect(
            self.image_list_sel_changed_slot)
        self.images_list_widget.setContextMenuPolicy(
            QtCore.Qt.CustomContextMenu)
        self.images_list_widget.customContextMenuRequested.connect(
            self.image_list_context_menu)

        self.treeview_models = ModelsTreeview()
        self.treeview_models.setColumnWidth(0, 300)
        self.tree_view_models_layout.addWidget(self.treeview_models)
        self.treeview_models.action_click.connect(
            self.trv_models_action_click_slot)

        self.treeview_labels = LabelsTableView()
        self.treeview_labels.action_click.connect(
            self.trv_labels_action_click_slot)
        self.tree_view_labels_layout.addWidget(self.treeview_labels)
        self.treeview_labels.selectionModel().selectionChanged.connect(
            self.default_label_changed_slot)
        self.create_actions_bar()

    def image_list_context_menu(self, pos: QPoint):
        """Offer a "labels" menu that tags the selected images.

        :param pos: position from ``customContextMenuRequested``. For a
            scroll area such as the list widget, Qt reports it relative to
            the viewport, so it is mapped to global coordinates through
            ``viewport()``.
        """
        if self.source is None:
            # No entry loaded yet, so there is no dataset to fetch labels for.
            return
        labels = self._labels_dao.fetch_all(self.source.dataset)
        if not labels:
            # Nothing to offer; avoid popping up an empty menu.
            return
        menu = QMenu()
        labels_menu = menu.addMenu("labels")
        for vo in labels:
            label_action = labels_menu.addAction(vo.name)
            label_action.setData(vo)
        global_pos = self.images_list_widget.viewport().mapToGlobal(pos)
        chosen = menu.exec_(global_pos)
        if chosen and isinstance(chosen.data(), LabelVO):
            self.change_image_labels(chosen.data())

    def change_image_labels(self, label: LabelVO):
        """Tag every selected image entry with ``label`` in a worker thread.

        :param label: label to assign to the selected entries.
        """
        selected_images = [
            item.tag for item in self.images_list_widget.selectedItems()
        ]
        if not selected_images:
            return

        @work_exception
        def do_work():
            self._ds_dao.tag_entries(selected_images, label)
            return 1, None

        @gui_exception
        def done_work(result):
            _status, err = result
            if err:
                raise err

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    def default_label_changed_slot(self, selection: QItemSelection):
        """Make the label selected in the label table the viewer's default.

        The label value object is read from column 2 of the first selected
        row.
        """
        selected_rows = self.treeview_labels.selectionModel().selectedRows(2)
        if len(selected_rows) > 0:
            index: QModelIndex = selected_rows[0]
            current_label: LabelVO = self.treeview_labels.model().data(index)
            self.viewer.current_label = current_label

    def image_list_sel_changed_slot(self, curr: CustomListWidgetItem,
                                    prev: CustomListWidgetItem):
        """Load the entry attached to the newly current list item."""
        if curr:
            self.source = curr.tag

    @property
    def image(self):
        """PIL image opened for the current source (``None`` before any)."""
        return self._image

    @property
    def source(self) -> DatasetEntryVO:
        """Dataset entry currently shown in the viewer."""
        return self._source

    @source.setter
    def source(self, value):
        """Switch to ``value``: open the image, show it, load annotations.

        :raises Exception: if ``value`` is not a ``DatasetEntryVO``.
        """
        if not isinstance(value, DatasetEntryVO):
            raise Exception("Invalid source")
        self._source = value
        image_path = self._source.file_path
        # NOTE(review): Image.open keeps the file handle open lazily and the
        # previous image is never closed; consider closing it once no
        # caller relies on the old `image` object.
        self._image = Image.open(image_path)
        self.viewer.pixmap = QPixmap(image_path)
        self.load_image_annotations()
        self.load_image_label()

    @gui_exception
    def load_images(self):
        """Fill the image list with the entries of the current dataset.

        The item whose file path matches the current source becomes the
        current item once all items have been added.
        """
        @work_exception
        def do_work():
            entries = self._ds_dao.fetch_entries(self.source.dataset)
            return entries, None

        @gui_exception
        def done_work(result):
            data, error = result
            if error:
                raise error
            current_path = self.source.file_path
            selected_item = None
            for vo in data or []:
                item = CustomListWidgetItem(vo.file_path)
                item.tag = vo
                if vo.file_path == current_path:
                    selected_item = item
                self.images_list_widget.addItem(item)
            # Select once, after the list is complete (previously this ran
            # on every iteration, frequently with None).
            if selected_item is not None:
                self.images_list_widget.setCurrentItem(selected_item)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @gui_exception
    def load_models(self):
        """Populate the models tree with every stored hub repository."""
        @work_exception
        def do_work():
            results = self._hub_dao.fetch_all()
            return results, None

        @gui_exception
        def done_work(result):
            models, _error = result
            if models:
                for model in models:
                    self.treeview_models.add_node(model)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @gui_exception
    def load_labels(self):
        """Populate the label table with the current dataset's labels."""
        @work_exception
        def do_work():
            results = self._labels_dao.fetch_all(self.source.dataset)
            return results, None

        @gui_exception
        def done_work(result):
            labels, error = result
            if error is None:
                for entry in labels or []:
                    self.treeview_labels.add_row(entry)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @gui_exception
    def load_image_annotations(self):
        """Add the current entry's stored boxes and polygons to the scene.

        Stored points are shifted by half the pixmap's scene bounding size,
        which is the same offset ``save_annotations`` adds back. Rows that
        cannot be parsed are skipped and logged.
        """
        import logging

        @work_exception
        def do_work():
            results = self._annot_dao.fetch_all(self.source.id)
            return results, None

        @gui_exception
        def done_work(result):
            annotations, error = result
            if error:
                raise error
            img_bbox: QRectF = self.viewer.pixmap.sceneBoundingRect()
            offset = QPointF(img_bbox.width() / 2, img_bbox.height() / 2)
            for vo in annotations or []:
                try:
                    coords = [float(v) for v in vo.points.split(",")]
                    points = list(more_itertools.chunked(coords, 2))
                    if vo.kind == "box":
                        (x1, y1), (x2, y2) = points[0], points[1]
                        # min() tolerates corners stored in either order.
                        x = min(x1, x2) - offset.x()
                        y = min(y1, y2) - offset.y()
                        w = math.fabs(x1 - x2)
                        h = math.fabs(y1 - y2)
                        roi: QRectF = QRectF(x, y, w, h)
                        rect = EditableBox(roi)
                        rect.label = vo.label
                        self.viewer.scene().addItem(rect)
                    elif vo.kind == "polygon":
                        polygon = EditablePolygon()
                        polygon.label = vo.label
                        # Added to the scene before its points, as before.
                        self.viewer.scene().addItem(polygon)
                        for p in points:
                            polygon.addPoint(
                                QPoint(p[0] - offset.x(), p[1] - offset.y()))
                except Exception as ex:
                    logging.getLogger(__name__).warning(
                        "Skipping unreadable %s annotation (points=%r): %s",
                        getattr(vo, "kind", None),
                        getattr(vo, "points", None), ex)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @gui_exception
    def load_image_label(self):
        """Show the current entry's label in the caption, or hide it."""
        @work_exception
        def do_work():
            label = self._annot_dao.get_label(self.source.id)
            return label, None

        @gui_exception
        def done_work(result):
            label_name, error = result
            if error:
                raise error
            if label_name:
                self._label.setVisible(True)
                self._label.setText(label_name)
            else:
                self._label.setVisible(False)
                self._label.setText("")

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @gui_exception
    def add_repository(self):
        """Ask for a hub repository, fetch it in the background, and store it.

        A modal loading dialog is shown while the worker runs. Errors from
        the worker are re-raised in the GUI-side ``@gui_exception`` handler
        instead of being dropped.
        """
        @work_exception
        def do_work(repo):
            hub_client = HubClientFactory.create(Framework.PyTorch)
            hub = hub_client.fetch_model(repo, force_reload=True)
            self._hub_dao.save(hub)
            return hub, None

        @gui_exception
        def done_work(result):
            self._loading_dialog.close()
            data, error = result
            if error:
                raise error
            self.treeview_models.add_node(data)

        form = NewRepoForm()
        if form.exec_() == QDialog.Accepted:
            repository = form.result
            worker = Worker(do_work, repository)
            worker.signals.result.connect(done_work)
            self._thread_pool.start(worker)
            self._loading_dialog.exec_()

    def bind(self):
        """Load the image list, the model tree and the label table."""
        self.load_images()
        self.load_models()
        self.load_labels()

    def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
        """Keyboard shortcuts.

        A/D save the current annotations and move to the previous/next image
        (wrapping around). W switches to polygon selection, S to box
        selection. The event is always forwarded to the base class.
        """
        key = event.key()
        images = self.images_list_widget
        row = images.currentRow()
        last_index = images.count() - 1
        if key == QtCore.Qt.Key_A:
            self.save_annotations()
            images.setCurrentRow(row - 1 if row > 0 else last_index)
        elif key == QtCore.Qt.Key_D:
            self.save_annotations()
            images.setCurrentRow(row + 1 if row < last_index else 0)
        elif key == QtCore.Qt.Key_W:
            self.viewer.selection_mode = SELECTION_MODE.POLYGON
        elif key == QtCore.Qt.Key_S:
            self.viewer.selection_mode = SELECTION_MODE.BOX
        super(ImageViewerWidget, self).keyPressEvent(event)

    @gui_exception
    def trv_models_action_click_slot(self, action: QAction):
        """Handle the models tree context menu (add repo / auto-label)."""
        if action.text() == self.treeview_models.CTX_MENU_NEW_DATASET_ACTION:
            self.add_repository()
        elif action.text() == self.treeview_models.CTX_MENU_AUTO_LABEL_ACTION:
            current_node = action.data()  # model name
            parent_node = current_node.parent  # repo
            repo, model = parent_node.get_data(0), current_node.get_data(0)
            self.autolabel(repo, model)

    @gui_exception
    def trv_labels_action_click_slot(self, action: QAction):
        """Handle the label table context menu (add / delete a label)."""
        model = self.treeview_labels.model()
        if action.text() == self.treeview_labels.CTX_MENU_ADD_LABEL:
            form = NewLabelForm()
            if form.exec_() == QDialog.Accepted:
                label_vo: LabelVO = form.result
                label_vo.dataset = self.source.dataset
                label_vo = self._labels_dao.save(label_vo)
                self.treeview_labels.add_row(label_vo)
        elif action.text() == self.treeview_labels.CTX_MENU_DELETE_LABEL:
            index: QModelIndex = action.data()
            if index:
                # Column 2 holds the LabelVO itself.
                label_vo = model.index(index.row(), 2).data()
                self._labels_dao.delete(label_vo.id)
                self.viewer.remove_annotations_by_label(label_vo.name)
                model.removeRow(index.row())

    def autolabel(self, repo, model_name):
        """Segment the current image with a torch.hub model and add polygons.

        The model is loaded through ``torch.hub.load(repo, model_name,
        pretrained=True)`` and its ``['out'][0]`` output is arg-maxed per
        pixel. Each non-zero class mask is traced with ``cv2.findContours``,
        and every 10th contour vertex becomes a polygon point. Failures are
        logged; ``done_work`` is not wrapped in ``@gui_exception``, so it
        must not raise.
        """
        import logging

        def do_work():
            try:
                from PIL import Image
                from torchvision import transforms
                import torch
                model = torch.hub.load(repo, model_name, pretrained=True)
                model.eval()
                input_image = Image.open(self.source.file_path)
                preprocess = transforms.Compose([
                    transforms.Resize(480),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ])
                input_tensor = preprocess(input_image)
                # Create a mini-batch of one, as expected by the model.
                input_batch = input_tensor.unsqueeze(0)
                # Move the input and model to the GPU when available.
                if torch.cuda.is_available():
                    input_batch = input_batch.to('cuda')
                    model.to('cuda')
                with torch.no_grad():
                    output = model(input_batch)['out'][0]
                output_predictions = output.argmax(0)
                # One palette colour per class (21 classes).
                palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
                colors = torch.arange(21)[:, None] * palette
                colors = (colors % 255).numpy().astype("uint8")
                predictions_array: np.ndarray = (
                    output_predictions.byte().cpu().numpy())
                predictions_image = Image.fromarray(predictions_array).resize(
                    input_image.size)
                predictions_image.putpalette(colors)
                labels_mask = np.asarray(predictions_image)
                # Class 0 is excluded.
                classes = [
                    c for c in np.unique(labels_mask).tolist() if c != 0
                ]
                classes_map = {c: [] for c in classes}
                for c in classes:
                    class_mask = np.zeros(labels_mask.shape, dtype=np.uint8)
                    class_mask[labels_mask == c] = 255
                    contour_list = cv2.findContours(class_mask.copy(),
                                                    cv2.RETR_LIST,
                                                    cv2.CHAIN_APPROX_SIMPLE)
                    contour_list = imutils.grab_contours(contour_list)
                    for contour in contour_list:
                        # reshape(-1, 2) keeps a list of [x, y] pairs even
                        # for one-point contours, where squeeze() collapsed
                        # them to a flat [x, y] and broke done_work.
                        classes_map[c].append(contour.reshape(-1, 2).tolist())
                return classes_map, None
            except Exception as ex:
                return None, ex

        def done_work(result):
            self._loading_dialog.close()
            classes_map, err = result
            if err:
                logging.getLogger(__name__).error(
                    "Auto-labelling with %s/%s failed: %s", repo, model_name,
                    err)
                return
            if not any(classes_map.values()):
                return
            # Loop-invariant: the pixmap does not change while adding items.
            bbox: QRectF = self.viewer.pixmap.boundingRect()
            offset = QPointF(bbox.width() / 2, bbox.height() / 2)
            for contours in classes_map.values():
                for contour in contours:
                    polygon = EditablePolygon()
                    self.viewer._scene.addItem(polygon)
                    # Keep every 10th vertex to lighten the polygon.
                    for x, y in contour[::10]:
                        polygon.addPoint(
                            QPoint(x - offset.x(), y - offset.y()))

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)
        self._loading_dialog.exec_()

    def create_actions_bar(self):
        """Create the toolbar buttons, add them to the layout, wire clicks."""
        self.btn_enable_polygon_selection = ImageButton(
            icon=GUIUtilities.get_icon("polygon.png"), size=QSize(32, 32))
        self.btn_enable_rectangle_selection = ImageButton(
            icon=GUIUtilities.get_icon("square.png"), size=QSize(32, 32))
        self.btn_enable_free_selection = ImageButton(
            icon=GUIUtilities.get_icon("highlighter.png"), size=QSize(32, 32))
        self.btn_enable_none_selection = ImageButton(
            icon=GUIUtilities.get_icon("cursor.png"), size=QSize(32, 32))
        self.btn_save_annotations = ImageButton(
            icon=GUIUtilities.get_icon("save-icon.png"), size=QSize(32, 32))
        self.btn_clear_annotations = ImageButton(
            icon=GUIUtilities.get_icon("clean.png"), size=QSize(32, 32))

        self.actions_layout.addWidget(self.btn_enable_rectangle_selection)
        self.actions_layout.addWidget(self.btn_enable_polygon_selection)
        self.actions_layout.addWidget(self.btn_enable_free_selection)
        self.actions_layout.addWidget(self.btn_enable_none_selection)
        self.actions_layout.addWidget(self.btn_clear_annotations)
        self.actions_layout.addWidget(self.btn_save_annotations)

        self.btn_save_annotations.clicked.connect(
            self.btn_save_annotations_clicked_slot)
        self.btn_enable_polygon_selection.clicked.connect(
            self.btn_enable_polygon_selection_clicked_slot)
        self.btn_enable_rectangle_selection.clicked.connect(
            self.btn_enable_rectangle_selection_clicked_slot)
        self.btn_enable_free_selection.clicked.connect(
            self.btn_enable_free_selection_clicked_slot)
        self.btn_enable_none_selection.clicked.connect(
            self.btn_enable_none_selection_clicked_slot)
        self.btn_clear_annotations.clicked.connect(
            self.btn_clear_annotations_clicked_slot)

    def btn_clear_annotations_clicked_slot(self):
        """Remove all annotations from the viewer."""
        self.viewer.remove_annotations()

    def btn_enable_polygon_selection_clicked_slot(self):
        """Switch the viewer to polygon selection."""
        self.viewer.selection_mode = SELECTION_MODE.POLYGON

    def btn_enable_rectangle_selection_clicked_slot(self):
        """Switch the viewer to box selection."""
        self.viewer.selection_mode = SELECTION_MODE.BOX

    def btn_enable_none_selection_clicked_slot(self):
        """Disable selection in the viewer."""
        self.viewer.selection_mode = SELECTION_MODE.NONE

    def btn_enable_free_selection_clicked_slot(self):
        """Switch the viewer to free-hand selection."""
        self.viewer.selection_mode = SELECTION_MODE.FREE

    def save_annotations(self):
        """Replace the stored annotations of the current entry.

        Boxes are saved as "x1,y1,x2,y2" and polygons as a flat "x,y,..."
        list. Points are shifted by half the pixmap's scene bounding size
        (the inverse of ``load_image_annotations``) and floored to integers.
        The replacement list is built before the old rows are deleted, so a
        failure while building no longer loses the stored annotations.
        """
        scene: QGraphicsScene = self.viewer.scene()
        items = scene.items()
        annot_list = []
        if items:
            img_bbox: QRectF = self.viewer.pixmap.sceneBoundingRect()
            offset = QPointF(img_bbox.width() / 2, img_bbox.height() / 2)
            dx, dy = offset.x(), offset.y()
            for item in items:
                if isinstance(item, EditableBox):
                    item_box: QRectF = item.sceneBoundingRect()
                    coords = [
                        math.floor(item_box.left() + dx),
                        math.floor(item_box.top() + dy),
                        math.floor(item_box.right() + dx),
                        math.floor(item_box.bottom() + dy),
                    ]
                    kind = "box"
                elif isinstance(item, EditablePolygon):
                    coords = [
                        value for pt in item.points
                        for value in (math.floor(pt.x() + dx),
                                      math.floor(pt.y() + dy))
                    ]
                    kind = "polygon"
                else:
                    continue
                vo = AnnotaVO()
                vo.label = item.label.id if item.label else None
                vo.entry = self.source.id
                vo.kind = kind
                vo.points = ",".join(map(str, coords))
                annot_list.append(vo)

        self._annot_dao.delete(self.source.id)
        self._annot_dao.save(annot_list)

    def btn_save_annotations_clicked_slot(self):
        """Persist the current annotations."""
        self.save_annotations()

    def _scene_item_added(self, item: QGraphicsItem):
        """Tag each item added to the scene with the current entry."""
        item.tag = self.source

    # NOTE(review): this second __init__ replaces the first one defined
    # above. It builds `self.image_viewer` (not `self.viewer`), stores the
    # annotation DAO as `self._ann_dao` and the caption as
    # `self._class_label`, and does not set `self._source` or call
    # create_actions_bar(). Several methods of this class therefore do not
    # match the constructor that actually runs.
    def __init__(self, parent=None):
        """Set up the viewer panel with image-adjustment controls and toolbox.

        Builds the generated UI, an ``ImageViewer`` with a
        ``points_selection_sgn`` hookup, a floating class label, the image
        list and model/label views, brightness/contrast/gamma sliders and a
        cluster-count spin box on ``img_adjust_page``, the DAOs and thread
        pool, and an ``ImageViewerToolBox`` in ``actions_layout``.
        """
        super(ImageViewerWidget, self).__init__(parent)
        self.setupUi(self)

        self.image_viewer = ImageViewer()
        self.image_viewer.points_selection_sgn.connect(
            self.extreme_points_selection_done_slot)
        self.center_layout.addWidget(self.image_viewer, 0, 0)

        self._class_label = QLabel()
        self._class_label.setVisible(False)
        self._class_label.setMargin(5)
        self._class_label.setStyleSheet('''
            QLabel{
            font: 12pt;
            border-radius: 25px;
            margin: 10px;
            color: black; 
            background-color: #FFFFDC;
            }
        ''')
        shadow = QGraphicsDropShadowEffect(self)
        shadow.setBlurRadius(8)
        shadow.setColor(QtGui.QColor(94, 93, 90).lighter())
        shadow.setOffset(2)
        self._class_label.setGraphicsEffect(shadow)
        self.center_layout.addWidget(self._class_label, 0, 0,
                                     QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft)

        self.actions_layout.setAlignment(QtCore.Qt.AlignTop
                                         | QtCore.Qt.AlignHCenter)
        self.actions_layout.setContentsMargins(0, 5, 0, 0)
        self.images_list_widget.setSelectionMode(
            QAbstractItemView.ExtendedSelection)
        self.images_list_widget.currentItemChanged.connect(
            self.image_list_sel_changed_slot)
        self.images_list_widget.setContextMenuPolicy(
            QtCore.Qt.CustomContextMenu)
        self.images_list_widget.customContextMenuRequested.connect(
            self.image_list_context_menu)
        self.images_list_widget.setCursor(QtCore.Qt.PointingHandCursor)

        self.treeview_models = ModelsTreeview()
        self.treeview_models.setColumnWidth(0, 300)
        self.tree_view_models_layout.addWidget(self.treeview_models)
        self.treeview_models.action_click.connect(
            self.trv_models_action_click_slot)

        self.treeview_labels = LabelsTableView()
        self.treeview_labels.action_click.connect(
            self.trv_labels_action_click_slot)
        self.tree_view_labels_layout.addWidget(self.treeview_labels)
        self.treeview_labels.selectionModel().selectionChanged.connect(
            self.default_label_changed_slot)

        # image adjustment controls
        img_adjust_controls_layout = QFormLayout()
        self.img_adjust_page.setLayout(img_adjust_controls_layout)
        self._brightness_slider = DoubleSlider()
        self._brightness_slider.setMinimum(0.0)
        self._brightness_slider.setMaximum(100.0)
        self._brightness_slider.setSingleStep(0.5)
        self._brightness_slider.setValue(self.image_viewer.img_brightness)
        self._brightness_slider.setOrientation(QtCore.Qt.Horizontal)

        self._contrast_slider = DoubleSlider()
        self._contrast_slider.setMinimum(1.0)
        self._contrast_slider.setMaximum(3.0)
        self._contrast_slider.setSingleStep(0.1)
        self._contrast_slider.setValue(self.image_viewer.img_contrast)
        self._contrast_slider.setOrientation(QtCore.Qt.Horizontal)

        self._gamma_slider = DoubleSlider()
        self._gamma_slider.setMinimum(1.0)
        self._gamma_slider.setMaximum(5.0)
        self._gamma_slider.setSingleStep(0.1)
        self._gamma_slider.setValue(self.image_viewer.img_gamma)
        self._gamma_slider.setOrientation(QtCore.Qt.Horizontal)
        self._number_of_clusters_spin = QSpinBox()
        self._number_of_clusters_spin.setMinimum(2)
        self._number_of_clusters_spin.setValue(5)

        self._contrast_slider.doubleValueChanged.connect(
            self._update_contrast_slot)
        self._brightness_slider.doubleValueChanged.connect(
            self._update_brightness_slot)
        self._gamma_slider.doubleValueChanged.connect(self._update_gamma_slot)

        img_adjust_controls_layout.addRow(QLabel("Brightness:"),
                                          self._brightness_slider)
        img_adjust_controls_layout.addRow(QLabel("Contrast:"),
                                          self._contrast_slider)
        img_adjust_controls_layout.addRow(QLabel("Gamma:"), self._gamma_slider)
        img_adjust_controls_layout.addRow(QLabel("Clusters:"),
                                          self._number_of_clusters_spin)

        self.img_adjust_page.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
        self.img_adjust_page.customContextMenuRequested.connect(
            self._image_processing_tools_ctx_menu)

        self._ds_dao = DatasetDao()
        self._hub_dao = HubDao()
        self._labels_dao = LabelDao()
        self._ann_dao = AnnotaDao()
        self._thread_pool = QThreadPool()
        self._loading_dialog = QLoadingDialog()
        self._tag = None
        self._curr_channel = 0
        self._channels = []
        self._toolbox = ImageViewerToolBox()
        self._toolbox.onAction.connect(self.on_action_toolbox_slot)
        self.actions_layout.addWidget(self._toolbox)
class ImageViewerWidget(QWidget, Ui_Image_Viewer_Widget):
    """Annotation workspace for the images of a single dataset.

    Combines the ``ImageViewer`` canvas with the dataset's image list, the
    model-hub tree, the labels table and the image-adjustment controls.
    DAO access and model inference are dispatched to ``QThreadPool``
    workers; their ``(result, error)`` tuples are applied back on the GUI
    side in ``done_work`` callbacks.
    """

    def __init__(self, parent=None):
        super(ImageViewerWidget, self).__init__(parent)
        self.setupUi(self)

        self.image_viewer = ImageViewer()
        self.image_viewer.points_selection_sgn.connect(
            self.extreme_points_selection_done_slot)
        self.center_layout.addWidget(self.image_viewer, 0, 0)

        # badge overlaid on the top-left corner of the viewer that shows the
        # class label of the current image (hidden until a label is loaded)
        self._class_label = QLabel()
        self._class_label.setVisible(False)
        self._class_label.setMargin(5)
        self._class_label.setStyleSheet('''
            QLabel{
            font: 12pt;
            border-radius: 25px;
            margin: 10px;
            color: black; 
            background-color: #FFFFDC;
            }
        ''')
        shadow = QGraphicsDropShadowEffect(self)
        shadow.setBlurRadius(8)
        shadow.setColor(QtGui.QColor(94, 93, 90).lighter())
        shadow.setOffset(2)
        self._class_label.setGraphicsEffect(shadow)
        self.center_layout.addWidget(self._class_label, 0, 0,
                                     QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft)

        self.actions_layout.setAlignment(QtCore.Qt.AlignTop
                                         | QtCore.Qt.AlignHCenter)
        self.actions_layout.setContentsMargins(0, 5, 0, 0)
        self.images_list_widget.setSelectionMode(
            QAbstractItemView.ExtendedSelection)
        self.images_list_widget.currentItemChanged.connect(
            self.image_list_sel_changed_slot)
        self.images_list_widget.setContextMenuPolicy(
            QtCore.Qt.CustomContextMenu)
        self.images_list_widget.customContextMenuRequested.connect(
            self.image_list_context_menu)
        self.images_list_widget.setCursor(QtCore.Qt.PointingHandCursor)

        self.treeview_models = ModelsTreeview()
        self.treeview_models.setColumnWidth(0, 300)
        self.tree_view_models_layout.addWidget(self.treeview_models)
        self.treeview_models.action_click.connect(
            self.trv_models_action_click_slot)

        self.treeview_labels = LabelsTableView()
        self.treeview_labels.action_click.connect(
            self.trv_labels_action_click_slot)
        self.tree_view_labels_layout.addWidget(self.treeview_labels)
        self.treeview_labels.selectionModel().selectionChanged.connect(
            self.default_label_changed_slot)

        # image adjustment controls
        img_adjust_controls_layout = QFormLayout()
        self.img_adjust_page.setLayout(img_adjust_controls_layout)
        self._brightness_slider = DoubleSlider()
        self._brightness_slider.setMinimum(0.0)
        self._brightness_slider.setMaximum(100.0)
        self._brightness_slider.setSingleStep(0.5)
        self._brightness_slider.setValue(self.image_viewer.img_brightness)
        self._brightness_slider.setOrientation(QtCore.Qt.Horizontal)

        self._contrast_slider = DoubleSlider()
        self._contrast_slider.setMinimum(1.0)
        self._contrast_slider.setMaximum(3.0)
        self._contrast_slider.setSingleStep(0.1)
        self._contrast_slider.setValue(self.image_viewer.img_contrast)
        self._contrast_slider.setOrientation(QtCore.Qt.Horizontal)

        self._gamma_slider = DoubleSlider()
        self._gamma_slider.setMinimum(1.0)
        self._gamma_slider.setMaximum(5.0)
        self._gamma_slider.setSingleStep(0.1)
        self._gamma_slider.setValue(self.image_viewer.img_gamma)
        self._gamma_slider.setOrientation(QtCore.Qt.Horizontal)
        self._number_of_clusters_spin = QSpinBox()
        self._number_of_clusters_spin.setMinimum(2)
        self._number_of_clusters_spin.setValue(5)

        self._contrast_slider.doubleValueChanged.connect(
            self._update_contrast_slot)
        self._brightness_slider.doubleValueChanged.connect(
            self._update_brightness_slot)
        self._gamma_slider.doubleValueChanged.connect(self._update_gamma_slot)

        img_adjust_controls_layout.addRow(QLabel("Brightness:"),
                                          self._brightness_slider)
        img_adjust_controls_layout.addRow(QLabel("Contrast:"),
                                          self._contrast_slider)
        img_adjust_controls_layout.addRow(QLabel("Gamma:"), self._gamma_slider)
        img_adjust_controls_layout.addRow(QLabel("Clusters:"),
                                          self._number_of_clusters_spin)

        self.img_adjust_page.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
        self.img_adjust_page.customContextMenuRequested.connect(
            self._image_processing_tools_ctx_menu)

        self._ds_dao = DatasetDao()
        self._hub_dao = HubDao()
        self._labels_dao = LabelDao()
        self._ann_dao = AnnotaDao()
        self._thread_pool = QThreadPool()
        self._loading_dialog = QLoadingDialog()
        self._tag = None
        self._curr_channel = 0
        self._channels = []
        self._toolbox = ImageViewerToolBox()
        self._toolbox.onAction.connect(self.on_action_toolbox_slot)
        self.actions_layout.addWidget(self._toolbox)

    @property
    def image(self):
        """The image currently displayed by the viewer."""
        return self.image_viewer.image

    @image.setter
    def image(self, value):
        self.image_viewer.image = value

    @property
    def tag(self):
        """The dataset entry currently shown (its ``dataset`` attribute is
        forwarded to the viewer when set)."""
        return self._tag

    @tag.setter
    def tag(self, value):
        self._tag = value
        if self._tag:
            dataset_id = self._tag.dataset
            self.image_viewer.dataset = dataset_id

    @property
    def channels(self):
        return self._channels

    @property
    def curr_channel(self):
        return self._curr_channel

    @curr_channel.setter
    def curr_channel(self, value):
        self._curr_channel = value

    def on_action_toolbox_slot(self, action_tag):
        """Handle a toolbox action: switch the viewer's selection tool, save
        the current annotations ("save") or clear them ("clean")."""
        tools_dict = {
            "polygon": SELECTION_TOOL.POLYGON,
            "box": SELECTION_TOOL.BOX,
            "ellipse": SELECTION_TOOL.ELLIPSE,
            "free": SELECTION_TOOL.FREE,
            "points": SELECTION_TOOL.EXTREME_POINTS,
            "pointer": SELECTION_TOOL.POINTER
        }
        if action_tag in tools_dict:
            self.image_viewer.current_tool = tools_dict[action_tag]
        elif action_tag == "save":

            @gui_exception
            def done_work(result):
                _, err = result
                if err:
                    raise err
                GUIUtilities.show_info_message(
                    "Annotations saved successfully", "Information")

            self.save_annotations(done_work)
        elif action_tag == "clean":
            self.image_viewer.remove_annotations()

    @gui_exception
    def bind(self):
        """Load the dataset's images, the hub models and the dataset labels
        in a background worker, then populate the side panels and select
        the list item whose path matches the current tag."""

        @work_exception
        def do_work():
            return dask.compute(
                *[self.load_images(),
                  self.load_models(),
                  self.load_labels()]), None

        @gui_exception
        def done_work(args):
            result, error = args
            if error:
                raise error
            if not result:
                return
            images, models, labels = result
            for model in models or []:
                self.treeview_models.add_node(model)
            for entry in labels or []:
                self.treeview_labels.add_row(entry)
            selected_image = None
            icon = GUIUtilities.get_icon("image.png")
            for img in images or []:
                # entries whose file no longer exists on disk are skipped
                if not os.path.isfile(img.file_path):
                    continue
                item = CustomListWidgetItem(img.file_path, tag=img)
                item.setIcon(icon)
                self.images_list_widget.addItem(item)
                if img.file_path == self.tag.file_path:
                    selected_image = item
            # select once, after every item has been added
            if selected_image is not None:
                self.images_list_widget.setCurrentItem(selected_image)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @dask.delayed
    def load_images(self) -> [DatasetEntryVO]:
        dataset_id = self.tag.dataset
        return self._ds_dao.fetch_entries(dataset_id)

    @dask.delayed
    def load_models(self) -> [HubVO]:
        return self._hub_dao.fetch_all()

    @dask.delayed
    def load_labels(self):
        dataset_id = self.tag.dataset
        return self._labels_dao.fetch_all(dataset_id)

    @gui_exception
    def image_list_sel_changed_slot(self, curr: CustomListWidgetItem,
                                    prev: CustomListWidgetItem):
        """Read the newly selected entry from disk and load its annotations."""
        # currentItemChanged passes None when the current item is cleared
        if curr is None:
            return
        self.image, self.tag = cv2.imread(curr.tag.file_path,
                                          cv2.IMREAD_COLOR), curr.tag
        self.load_image()

    @gui_exception
    def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
        """``A``/``D`` save the current annotations and then move to the
        previous/next image, wrapping around at either end."""
        row = self.images_list_widget.currentRow()
        last_index = self.images_list_widget.count() - 1
        key = event.key()
        if key in (QtCore.Qt.Key_A, QtCore.Qt.Key_D):
            if key == QtCore.Qt.Key_A:
                target_row = row - 1 if row > 0 else last_index
            else:
                target_row = row + 1 if row < last_index else 0

            @gui_exception
            def done_work(result):
                _, err = result
                if err:
                    raise err
                self.images_list_widget.setCurrentRow(target_row)

            self.save_annotations(done_work)
        super(ImageViewerWidget, self).keyPressEvent(event)

    def _update_contrast_slot(self, val):
        self.image_viewer.img_contrast = val
        self.image_viewer.update_viewer(fit_image=False)

    def _update_gamma_slot(self, val):
        self.image_viewer.img_gamma = val
        self.image_viewer.update_viewer(fit_image=False)

    def _update_brightness_slot(self, val):
        self.image_viewer.img_brightness = val
        self.image_viewer.update_viewer(fit_image=False)

    def _reset_sliders(self):
        self._gamma_slider.setValue(1.0)
        self._brightness_slider.setValue(50.0)
        self._contrast_slider.setValue(1.0)

    @gui_exception
    def _image_processing_tools_ctx_menu(self, pos: QPoint):
        """Show the image-processing context menu of the adjustment page and
        run the chosen operation."""
        menu = QMenu(self)
        # keep references in a list so the parentless QActions outlive exec_
        actions = []
        for text, key in (("Reset Image", "reset"),
                          ("Equalize Histogram", "equalize_histo"),
                          ("Correct Lightness", "correct_l"),
                          ("Cluster Image", "clustering")):
            act = QAction(text)
            act.setData(key)
            actions.append(act)
        menu.addActions(actions)
        action = menu.exec_(self.img_adjust_page.mapToGlobal(pos))
        if action:
            self._process_image_adjust_oper(action)

    def _process_image_adjust_oper(self, action: QAction):
        """Run the image operation identified by ``action.data()`` in a
        worker while the loading dialog is shown, then refresh the viewer."""
        curr_action = action.data()
        k = self._number_of_clusters_spin.value()
        if curr_action == "reset":
            # slider widgets must only be touched from the GUI thread, so
            # reset them here instead of inside the worker
            self._reset_sliders()

        @work_exception
        def do_work():
            if curr_action == "reset":
                self.image_viewer.reset_viewer()
            elif curr_action == "equalize_histo":
                self.image_viewer.equalize_histogram()
            elif curr_action == "correct_l":
                self.image_viewer.correct_lightness()
            elif curr_action == "clustering":
                self.image_viewer.clusterize(k=k)
            return None, None

        @gui_exception
        def done_work(result):
            # always dismiss the dialog, even when the operation failed
            self._loading_dialog.hide()
            _, err = result
            if err:
                raise err
            self.image_viewer.update_viewer()

        self._loading_dialog.show()
        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    def default_label_changed_slot(self, selection: QItemSelection):
        """Make the label selected in the labels table the viewer's current
        label (the LabelVO is read from column 2 of the selected row)."""
        selected_rows = self.treeview_labels.selectionModel().selectedRows(2)
        if selected_rows:
            index: QModelIndex = selected_rows[0]
            current_label: LabelVO = self.treeview_labels.model().data(index)
            self.image_viewer.current_label = current_label

    @gui_exception
    def extreme_points_selection_done_slot(self, points: []):
        self.predict_annotations_using_extr_points(points)

    @gui_exception
    def trv_labels_action_click_slot(self, action: QAction):
        """Handle the labels table context menu (add / delete label)."""
        model = self.treeview_labels.model()
        if action.text() == self.treeview_labels.CTX_MENU_ADD_LABEL:
            form = NewLabelForm()
            if form.exec_() == QDialog.Accepted:
                label_vo: LabelVO = form.result
                label_vo.dataset = self.tag.dataset
                label_vo = self._labels_dao.save(label_vo)
                self.treeview_labels.add_row(label_vo)
        elif action.text() == self.treeview_labels.CTX_MENU_DELETE_LABEL:
            reply = QMessageBox.question(self, 'Delete Label', "Are you sure?",
                                         QMessageBox.Yes | QMessageBox.No,
                                         QMessageBox.No)
            if reply == QMessageBox.No:
                return
            index: QModelIndex = action.data()
            if index:
                label_vo = model.index(index.row(), 2).data()
                # TODO: this method should be called from another thread
                self._labels_dao.delete(label_vo.id)
                self.image_viewer.remove_annotations_by_label(label_vo.name)
                model.removeRow(index.row())

    @gui_exception
    def add_repository(self, repo_path):
        """Fetch a PyTorch hub repository in a worker, persist it and add it
        to the models tree; the loading dialog is shown modally meanwhile."""

        @work_exception
        def do_work():
            hub_client = HubClientFactory.create(Framework.PyTorch)
            hub = hub_client.fetch_model(repo_path, force_reload=True)
            hub.id = self._hub_dao.save(hub)
            return hub, None

        @gui_exception
        def done_work(result):
            self._loading_dialog.close()
            hub, error = result
            if error:
                raise error
            self.treeview_models.add_node(hub)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)
        self._loading_dialog.exec_()

    @gui_exception
    def update_repository(self, hub: HubVO):
        """Re-fetch an existing hub repository, update it in the DAO keeping
        its id, and re-add it to the models tree."""
        hub_path = hub.path
        hub_id = hub.id

        @work_exception
        def do_work():
            hub_client = HubClientFactory.create(Framework.PyTorch)
            new_hub = hub_client.fetch_model(hub_path, force_reload=True)
            new_hub.id = hub_id
            self._hub_dao.update(new_hub)
            return new_hub, None

        @gui_exception
        def done_work(result):
            self._loading_dialog.close()
            new_hub, error = result
            if error:
                raise error
            self.treeview_models.add_node(new_hub)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)
        self._loading_dialog.exec_()

    @gui_exception
    def trv_models_action_click_slot(self, action: QAction):
        """Handle the models tree context menu: add, update or delete a
        repository, or auto-label the current image with a model."""
        model = self.treeview_models.model()
        text = action.text()
        if text == self.treeview_models.CTX_MENU_ADD_REPOSITORY_ACTION:
            form = NewRepoForm()
            if form.exec_() == QDialog.Accepted:
                repository = form.result
                self.add_repository(repository)
        elif text == self.treeview_models.CTX_MENU_UPDATE_REPO_ACTION:
            index: QModelIndex = action.data()
            node: CustomNode = index.internalPointer()
            hub: HubVO = node.tag
            if hub and hub.id:
                model.removeChild(index)
                self.update_repository(hub)
        elif text == self.treeview_models.CTX_MENU_AUTO_LABEL_ACTION:
            index: QModelIndex = action.data()
            node: CustomNode = index.internalPointer()
            parent_node = node.parent  # repo
            repo, model_name = parent_node.get_data(0), node.get_data(0)
            self.predict_annotations_using_pytorch_thub_model(repo, model_name)
        elif text == self.treeview_models.CTX_MENU_DELETE_REPO_ACTION:
            reply = QMessageBox.question(self, 'Delete Repository',
                                         "Are you sure?",
                                         QMessageBox.Yes | QMessageBox.No,
                                         QMessageBox.No)
            if reply == QMessageBox.No:
                return
            index: QModelIndex = action.data()
            node: CustomNode = index.internalPointer()
            hub: HubVO = node.tag
            if hub:
                # TODO: this method should be called from another thread
                self._hub_dao.delete(hub.id)
                model.removeChild(index)

    def image_list_context_menu(self, pos: QPoint):
        """Offer the dataset's labels as a sub-menu and tag the selected
        images with the chosen one."""
        menu = QMenu()
        result = self._labels_dao.fetch_all(self.tag.dataset)
        if result:
            labels_menu = menu.addMenu("labels")
            for vo in result:
                action = labels_menu.addAction(vo.name)
                action.setData(vo)
        action = menu.exec_(QCursor.pos())
        if action and isinstance(action.data(), LabelVO):
            self.change_image_labels(action.data())

    def change_image_labels(self, label: LabelVO):
        """Tag every selected image entry with ``label`` in a worker."""
        selected_images = [item.tag
                           for item in self.images_list_widget.selectedItems()]

        @work_exception
        def do_work():
            self._ds_dao.tag_entries(selected_images, label)
            return 1, None

        @gui_exception
        def done_work(result):
            _, err = result
            if err:
                raise err

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @dask.delayed
    def load_image_label(self):
        return self._ann_dao.get_label(self.tag.id)

    @dask.delayed
    def load_image_annotations(self):
        return self._ann_dao.fetch_all(self.tag.id)

    def load_image(self):
        """Clear the viewer's annotations, then fetch the current entry's
        class label and stored annotations in a worker and draw them."""

        @work_exception
        def do_work():
            return dask.compute(
                *[self.load_image_label(),
                  self.load_image_annotations()]), None

        @gui_exception
        def done_work(args):
            result, error = args
            if error:
                raise error
            if not result:
                return
            label, annotations = result
            if label:
                self._class_label.setVisible(True)
                self._class_label.setText(label)
            else:
                self._class_label.setVisible(False)
                self._class_label.setText("")

            if not annotations:
                return
            img_bbox: QRectF = self.image_viewer.pixmap.sceneBoundingRect()
            offset = QPointF(img_bbox.width() / 2, img_bbox.height() / 2)
            for vo in annotations:
                # one malformed annotation must not prevent the others from
                # being drawn; report it and continue
                try:
                    self._add_annotation_item(vo, offset)
                except Exception as ex:
                    GUIUtilities.show_error_message(
                        "Error loading the annotations: {}".format(ex),
                        "Error")

        self.image_viewer.remove_annotations()
        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    def _add_annotation_item(self, vo: AnnotaVO, offset: QPointF):
        """Create and add the editable scene item for one stored annotation.

        ``vo.points`` is a comma-separated list of numbers read as (x, y)
        pairs; ``offset`` is subtracted from each coordinate. Kinds other
        than "box", "ellipse" and "polygon" are ignored.
        """
        values = map(float, vo.points.split(","))
        points = list(more_itertools.chunked(values, 2))
        if vo.kind == "box" or vo.kind == "ellipse":
            x = points[0][0] - offset.x()
            y = points[0][1] - offset.y()
            w = math.fabs(points[0][0] - points[1][0])
            h = math.fabs(points[0][1] - points[1][1])
            roi: QRectF = QRectF(x, y, w, h)
            if vo.kind == "box":
                item = EditableBox(roi)
            else:
                item = EditableEllipse()
            item.tag = self.tag.dataset
            item.setRect(roi)
            item.label = vo.label
            self.image_viewer.scene().addItem(item)
        elif vo.kind == "polygon":
            item = EditablePolygon()
            item.label = vo.label
            item.tag = self.tag.dataset
            self.image_viewer.scene().addItem(item)
            for p in points:
                # QPoint takes ints; convert explicitly instead of relying
                # on implicit float->int conversion by the bindings
                item.addPoint(
                    QPoint(int(p[0] - offset.x()), int(p[1] - offset.y())))

    @gui_exception
    def save_annotations(self, done_work_callback):
        """Collect every ``EditableItem`` in the viewer's scene as an
        ``AnnotaVO`` and save them for the current entry in a worker.

        Args:
            done_work_callback: slot connected to the worker's result signal;
                it receives the worker's ``(result, error)`` tuple.
        """
        scene: QGraphicsScene = self.image_viewer.scene()
        # the offset is the same for every item, so compute it once
        image_rect: QRectF = self.image_viewer.pixmap.sceneBoundingRect()
        image_offset = QPointF(image_rect.width() / 2,
                               image_rect.height() / 2)
        annotations = []
        for item in scene.items():
            if not isinstance(item, EditableItem):
                continue
            a = AnnotaVO()
            a.label = item.label.id if item.label else None
            a.entry = self.tag.id
            a.kind = item.shape_type
            a.points = item.coordinates(image_offset)
            annotations.append(a)

        @work_exception
        def do_work():
            self._ann_dao.save(self.tag.id, annotations)
            return None, None

        worker = Worker(do_work)
        worker.signals.result.connect(done_work_callback)
        self._thread_pool.start(worker)

    @staticmethod
    def invoke_tf_hub_model(image_path, repo, model_name):
        """Run a model loaded with ``torch.hub.load`` on an image.

        Returns:
            ``("mask", {class_id: [contour, ...]})`` when the model output is
            an ``OrderedDict`` (segmentation logits under ``"out"``; class 0
            is skipped), each contour being a list of ``[x, y]`` points in
            original-image coordinates; otherwise ``("label", entry)`` where
            ``entry`` is ``class_map[str(class_id)]`` from
            ``./data/imagenet_class_index.json``.
        """
        from PIL import Image
        from torchvision import transforms
        import torch
        gpu_id = 0
        device = torch.device(
            "cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")
        model = torch.hub.load(repo, model_name, pretrained=True)
        model.eval()
        # force 3 channels so grayscale/palette/RGBA images match the
        # 3-channel normalization below (as invoke_cov19_model does)
        input_image = Image.open(image_path).convert("RGB")
        preprocess = transforms.Compose([
            transforms.Resize(480),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        input_tensor = preprocess(input_image)
        # create a mini-batch as expected by the model
        input_batch = input_tensor.unsqueeze(0)
        # move the input and model to GPU for speed if available
        if torch.cuda.is_available():
            input_batch = input_batch.to(device)
            model.to(device)
        with torch.no_grad():
            output = model(input_batch)

        if isinstance(output, OrderedDict):
            output = output["out"][0]
            predictions_tensor = output.argmax(0)
            # move predictions to the cpu and convert into a numpy array
            predictions_arr: np.ndarray = (
                predictions_tensor.byte().cpu().numpy())
            classes_ids = np.unique(predictions_arr).tolist()
            # ignore the 0 value
            classes_idx = [c for c in classes_ids if c != 0]
            # scale the prediction map back to the original image size
            predictions_arr = np.asarray(
                Image.fromarray(predictions_arr).resize(input_image.size))
            predicted_mask = {c: [] for c in classes_idx}
            for idx in classes_idx:
                class_mask = np.zeros(predictions_arr.shape, dtype=np.uint8)
                class_mask[predictions_arr == idx] = 255
                contour_list = cv2.findContours(class_mask.copy(),
                                                cv2.RETR_LIST,
                                                cv2.CHAIN_APPROX_SIMPLE)
                contour_list = imutils.grab_contours(contour_list)
                for contour in contour_list:
                    points = np.vstack(contour).squeeze().tolist()
                    predicted_mask[idx].append(points)
            # return only after every class has been processed (returning
            # inside the loop dropped all classes but the first)
            return "mask", predicted_mask

        with open("./data/imagenet_class_index.json") as fp:
            class_map = json.load(fp)
        _, argmax = output.data.squeeze().max(0)
        class_id = argmax.item()
        return "label", class_map[str(class_id)]

    @staticmethod
    def invoke_dextr_pascal_model(image_path, points):
        """Segment an object from user-clicked extreme points with DEXTR.

        Loads ``./models/dextr_pascal-sbd.pth``, thresholds the predicted
        mask at 0.8 and returns its contours, each a list of ``[x, y]``
        points in original-image coordinates.
        """
        import torch
        from collections import OrderedDict
        from PIL import Image
        import numpy as np
        # `upsample` is deprecated in favour of `interpolate` (same arguments)
        from torch.nn.functional import interpolate
        from contrib.dextr import deeplab_resnet as resnet
        from contrib.dextr import helpers
        modelName = 'dextr_pascal-sbd'
        pad = 50
        thres = 0.8
        gpu_id = 0
        device = torch.device(
            "cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")
        #  Create the network and load the weights
        model = resnet.resnet101(1, nInputChannels=4, classifier='psp')
        model_path = os.path.abspath("./models/{}.pth".format(modelName))
        state_dict_checkpoint = torch.load(
            model_path, map_location=lambda storage, loc: storage)
        if 'module.' in list(state_dict_checkpoint.keys())[0]:
            # remove the `module.` prefix added by multi-gpu training
            new_state_dict = OrderedDict(
                (k[7:], v) for k, v in state_dict_checkpoint.items())
        else:
            new_state_dict = state_dict_checkpoint
        model.load_state_dict(new_state_dict)
        model.eval()
        model.to(device)
        #  Read image and the clicked points
        image = np.array(Image.open(image_path))
        # builtin int: the `np.int` alias was removed in NumPy 1.24
        extreme_points_ori = np.asarray(points).astype(int)
        with torch.no_grad():
            #  Crop image to the bounding box from the extreme points and resize
            bbox = helpers.get_bbox(image,
                                    points=extreme_points_ori,
                                    pad=pad,
                                    zero_pad=True)
            crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
            resize_image = helpers.fixed_resize(crop_image,
                                                (512, 512)).astype(np.float32)

            #  Generate extreme point heat map normalized to image values
            extreme_points = extreme_points_ori - [
                np.min(extreme_points_ori[:, 0]),
                np.min(extreme_points_ori[:, 1])
            ] + [pad, pad]
            extreme_points = (
                512 * extreme_points *
                [1 / crop_image.shape[1], 1 / crop_image.shape[0]]).astype(int)
            extreme_heatmap = helpers.make_gt(resize_image,
                                              extreme_points,
                                              sigma=10)
            extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

            #  Concatenate inputs and convert to tensor
            input_dextr = np.concatenate(
                (resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
            inputs = torch.from_numpy(
                input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

            # Run a forward pass
            inputs = inputs.to(device)
            outputs = model.forward(inputs)
            outputs = interpolate(outputs,
                                  size=(512, 512),
                                  mode='bilinear',
                                  align_corners=True)
            outputs = outputs.to(torch.device('cpu'))

            pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
            pred = 1 / (1 + np.exp(-pred))
            pred = np.squeeze(pred)
            result = helpers.crop2fullmask(
                pred, bbox, im_size=image.shape[:2], zero_pad=True,
                relax=pad) > thres
            binary = np.zeros_like(result, dtype=np.uint8)
            binary[result] = 255
            contour_list = cv2.findContours(binary.copy(), cv2.RETR_LIST,
                                            cv2.CHAIN_APPROX_SIMPLE)
            contour_list = imutils.grab_contours(contour_list)
            contours = [np.vstack(contour).squeeze().tolist()
                        for contour in contour_list]
        return contours

    @staticmethod
    def invoke_cov19_model(image_path, repo, model_name):
        """Classify an image with a ``torch.hub`` model into the two classes
        of ``labels_map``; returns ``("label", [class_idx, class_name])``."""
        from PIL import Image
        from torchvision import transforms
        import torch

        model = torch.hub.load(repo,
                               model_name,
                               pretrained=True,
                               force_reload=False)
        model.eval()
        input_image = Image.open(image_path)
        input_image = input_image.convert("RGB")
        preprocess = transforms.Compose([
            transforms.Resize(255),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        input_tensor = preprocess(input_image)
        input_batch = input_tensor.unsqueeze(0)
        # inference only: no autograd graph needed
        with torch.no_grad():
            output = model(input_batch)
        _, preds_tensor = torch.max(output, 1)
        top_pred = np.squeeze(preds_tensor.numpy())
        labels_map = {0: "conv19", 1: "normal"}
        class_idx = top_pred.item()
        # show top class
        return "label", [class_idx, labels_map[class_idx]]

    @gui_exception
    def predict_annotations_using_pytorch_thub_model(self, repo, model_name):
        """Run a hub model on the current image in a worker. Mask results are
        added as polygons (every 10th contour point, contours with more than
        5 sampled points only); label results are shown in a message box."""

        @work_exception
        def do_work():
            if model_name == "covid19":
                pred_result = self.invoke_cov19_model(self.tag.file_path, repo,
                                                      model_name)
            else:
                pred_result = self.invoke_tf_hub_model(self.tag.file_path,
                                                       repo, model_name)
            return pred_result, None

        @gui_exception
        def done_work(result):
            self._loading_dialog.hide()
            pred_out, err = result
            if err:
                raise err
            if not pred_out:
                return
            pred_type, pred_res = pred_out
            if pred_type == "mask":
                # the pixmap geometry is identical for every polygon
                bbox: QRectF = self.image_viewer.pixmap.boundingRect()
                offset = QPointF(bbox.width() / 2, bbox.height() / 2)
                for contours in pred_res.values():
                    for c in contours:
                        points = c[::10]
                        if len(points) <= 5:
                            continue
                        polygon = EditablePolygon()
                        polygon.tag = self.tag.dataset
                        self.image_viewer._scene.addItem(polygon)
                        for point in points:
                            if isinstance(point, list) and len(point) == 2:
                                polygon.addPoint(
                                    QPoint(int(point[0] - offset.x()),
                                           int(point[1] - offset.y())))
            else:
                class_id, class_name = pred_res
                GUIUtilities.show_info_message(
                    "predicted label : `{}`".format(class_name),
                    "prediction result")

        self._loading_dialog.show()
        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)

    @gui_exception
    def predict_annotations_using_extr_points(self, points):
        """Run DEXTR on the current image with the given extreme points in a
        worker and add each resulting contour (every 10th point, more than 5
        sampled points only) as a polygon."""

        @work_exception
        def do_work():
            contours = self.invoke_dextr_pascal_model(self.tag.file_path,
                                                      points)
            return contours, None

        @gui_exception
        def done_work(result):
            self._loading_dialog.hide()
            pred_out, err = result
            if err:
                raise err
            if not pred_out:
                return
            bbox: QRectF = self.image_viewer.pixmap.boundingRect()
            offset = QPointF(bbox.width() / 2, bbox.height() / 2)
            for c in pred_out:
                c_points = c[::10]
                if len(c_points) <= 5:
                    continue
                polygon = EditablePolygon()
                polygon.tag = self.tag.dataset
                self.image_viewer._scene.addItem(polygon)
                for point in c_points:
                    polygon.addPoint(
                        QPoint(int(point[0] - offset.x()),
                               int(point[1] - offset.y())))

        self._loading_dialog.show()
        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self._thread_pool.start(worker)
# Example #6
class DatasetTabWidget(QScrollArea):
    """Scrollable tab that lists datasets and reacts to the grid's actions.

    ``__init__`` wires the ``DatasetGridWidget`` signals (new, delete,
    refresh, edit, open, download annotations) to the slots below. Database
    reads run on ``thread_pool`` through ``Worker``.
    """

    # Export format names. Only JSON and PASCAL_VOC are currently added to the
    # download menu; the other two are defined but unused here.
    JSON = "JSON"
    PASCAL_VOC = "Pascal VOC"
    TENSORFLOW_OBJECT_DETECTION = "TensorFlow Object Detection"
    YOLO = "YOLO"

    def __init__(self, parent=None):
        super(DatasetTabWidget, self).__init__(parent)
        self.setCursor(QtCore.Qt.PointingHandCursor)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.data_grid = DatasetGridWidget()
        self.data_grid.new_dataset_action_signal.connect(self.btn_new_dataset_on_slot)
        self.data_grid.delete_dataset_action_signal.connect(self.btn_delete_dataset_on_slot)
        self.data_grid.refresh_dataset_action_signal.connect(self.refresh_dataset_action_slot)
        self.data_grid.edit_dataset_action_signal.connect(self.edit_dataset_action_slot)
        self.data_grid.open_dataset_action_signal.connect(self.open_dataset_action_slot)
        self.data_grid.download_anno_action_signal.connect(self.download_annot_action_slot)
        self.setWidget(self.data_grid)
        self.setWidgetResizable(True)
        self.thread_pool = QThreadPool()
        self.loading_dialog = QLoadingDialog()
        self.ds_dao = DatasetDao()
        self.annot_dao = AnnotaDao()
        self.load()

    @gui_exception
    def load(self):
        """Fetch all datasets (with sizes) in the background and rebind the grid."""
        @work_exception
        def do_work():
            results = self.ds_dao.fetch_all_with_size()
            return results, None

        @gui_exception
        def done_work(result):
            data, error = result
            if error:
                raise error
            self.data_grid.data_source = data
            self.data_grid.bind()

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)

    def _close_tab(self, tab_class):
        """Remove every tab of the window's tab manager that is a ``tab_class``."""
        tab_widget_manager: QTabWidget = self.window().tab_widget_manager
        # Walk backwards: removing a tab shifts all later tabs down by one, so
        # a forward walk would skip the tab right after each removed one.
        for i in reversed(range(tab_widget_manager.count())):
            if isinstance(tab_widget_manager.widget(i), tab_class):
                tab_widget_manager.removeTab(i)

    @QtCore.pyqtSlot()
    @gui_exception
    def btn_new_dataset_on_slot(self):
        """Show the dataset form and, if accepted, save the new dataset and reload."""
        form = DatasetForm()
        if form.exec_() == QDialog.Accepted:
            vo: DatasetVO = form.value
            self.ds_dao.save(vo)
            self.load()

    @QtCore.pyqtSlot(DatasetVO)
    @gui_exception
    def btn_delete_dataset_on_slot(self, vo: DatasetVO):
        """After confirmation, delete the dataset row and its folder, then reload."""
        reply = QMessageBox.question(self, 'Confirmation', "Are you sure?", QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply == QMessageBox.Yes:
            self.ds_dao.delete(vo.id)
            usr_folder = FileUtilities.get_usr_folder()
            ds_folder = os.path.join(usr_folder, vo.folder)
            FileUtilities.delete_folder(ds_folder)
            self.load()

    @QtCore.pyqtSlot(DatasetVO)
    @gui_exception
    def edit_dataset_action_slot(self, vo: DatasetVO):
        """Show the dataset form pre-filled with ``vo``; save and reload if accepted."""
        form = DatasetForm(vo)
        if form.exec_() == QDialog.Accepted:
            vo: DatasetVO = form.value
            self.ds_dao.save(vo)
            self.load()

    @QtCore.pyqtSlot(DatasetVO)
    def refresh_dataset_action_slot(self, vo: DatasetVO):
        """Reload the whole dataset list (``vo`` is not used)."""
        self.load()

    @QtCore.pyqtSlot(DatasetVO)
    def open_dataset_action_slot(self, vo: DatasetVO):
        """Close all open tabs and show a ``MediaTabWidget`` for ``vo``."""
        tab_widget_manager: QTabWidget = self.window().tab_widget_manager
        tab_widget = MediaTabWidget(vo)
        # Remove from the end: a forward ``removeTab(i)`` loop over the
        # original count only closes every other tab because indices shift.
        for i in reversed(range(tab_widget_manager.count())):
            tab_widget_manager.removeTab(i)
        index = tab_widget_manager.addTab(tab_widget, vo.name)
        tab_widget_manager.setCurrentIndex(index)

    def annotations2json(self, images, selected_folder):
        """Write one ``<image stem>.json`` file per image into ``selected_folder``.

        :param images: iterable of ``(img_path, annotations)`` pairs (e.g. an
            ``itertools.groupby`` result); each annotation is a mapping with
            ``annot_kind``, ``annot_points``, ``label_name`` and
            ``label_color`` keys.
        :param selected_folder: destination directory.
        """
        def export_template(img_path, img_annotations):
            # Build the document directly instead of rendering JSON text from a
            # template: interpolated values were not escaped, so Windows paths
            # (backslashes) or quotes in label names produced invalid or
            # silently altered JSON.
            document = {
                "path": str(img_path),
                "regions": [
                    {
                        "kind": str(region["annot_kind"]),
                        "points": str(region["annot_points"]),
                        "label": str(region["label_name"]),
                        "color": str(region["label_color"]),
                    }
                    for region in img_annotations
                ],
            }
            filename = os.path.split(img_path)[1]
            file_name, _ = os.path.splitext(filename)
            output_file = os.path.join(selected_folder, "{}.json".format(file_name))
            with open(output_file, 'w') as f:
                json.dump(document, f, indent=3)

        # Materialize each group now: groupby sub-iterators become invalid as
        # soon as the outer iterator advances, so handing them lazily to dask
        # made every exported file (except possibly the last) empty.
        delayed_tasks = [
            dask.delayed(export_template)(img_path, list(img_annotations))
            for img_path, img_annotations in images
        ]
        dask.compute(*delayed_tasks)

    def annotations2pascal(self, images, selected_folder):
        """Write one Pascal VOC ``<image stem>.xml`` file per image with boxes.

        Only annotations whose ``annot_kind`` is ``"box"`` are exported; their
        ``annot_points`` are read as ``"xmin,ymin,xmax,ymax"``. Images without
        boxes get no file.

        :param images: iterable of ``(img_path, annotations)`` pairs.
        :param selected_folder: destination directory.
        :raises IOError: if an image file cannot be read by OpenCV.
        """
        def export_template(img_path, img_annotations):
            # ``| x`` XML-escapes values so names/paths containing &, < or >
            # still yield well-formed XML.
            str_template = '''
            <annotation>
                <folder>${folder | x}</folder>
                <filename>${filename | x}</filename>
                <path>${path | x}</path>
                <source>
                    <database>Unknown</database>
                </source>
                <size>
                    <width>${width}</width>
                    <height>${height}</height>
                    <depth>${depth}</depth>
                </size>
                <segmented>0</segmented>
                % for i, region in enumerate(annotations):
                    <object>
                        <name>${region["name"] | x}</name>
                        <pose>Unspecified</pose>
                        <truncated>0</truncated>
                        <difficult>0</difficult>
                        <bndbox>
                            <xmin>${region["xmin"]}</xmin>
                            <ymin>${region["ymin"]}</ymin>
                            <xmax>${region["xmax"]}</xmax>
                            <ymax>${region["ymax"]}</ymax>
                        </bndbox>
                    </object>
                % endfor
            </annotation>
            '''
            filename = os.path.split(img_path)[1]
            folder = os.path.split(os.path.dirname(img_path))[1]
            img = cv2.imread(img_path)
            # cv2.imread signals failure by returning None rather than raising.
            if img is None:
                raise IOError("Unable to read image {}".format(img_path))
            h, w, c = img.shape

            xml_str = Template(str_template).render(
                path=img_path,
                folder=folder,
                filename=filename,
                width=w,
                height=h,
                depth=c,
                annotations=img_annotations)

            file_name, _ = os.path.splitext(filename)
            output_file = os.path.join(selected_folder, "{}.xml".format(file_name))
            with open(output_file, 'w') as f:
                f.write(xml_str)

        delayed_tasks = []
        for img_path, img_annotations in images:
            boxes = []
            for annot in img_annotations:
                if annot["annot_kind"] == "box":
                    xmin, ymin, xmax, ymax = list(
                        map(int, annot["annot_points"].split(",")))[:4]
                    boxes.append({
                        "name": annot["label_name"],
                        "xmin": xmin,
                        "ymin": ymin,
                        "xmax": xmax,
                        "ymax": ymax,
                    })
            if boxes:
                delayed_tasks.append(dask.delayed(export_template)(img_path, boxes))
        dask.compute(*delayed_tasks)

    def annotations2Yolo(self, images, selected_folder):
        """Placeholder for YOLO export; not implemented and does nothing."""

    @gui_exception
    def download_annot_action_slot(self, vo: DatasetVO):
        """Let the user pick an export format and folder, then export ``vo``'s annotations."""
        menu = QMenu()
        menu.setCursor(QtCore.Qt.PointingHandCursor)
        menu.addAction(self.JSON)
        menu.addAction(self.PASCAL_VOC)
        action = menu.exec_(QCursor.pos())
        if not action:
            return

        selected_folder = str(QFileDialog.getExistingDirectory(None, "select the folder"))
        if not selected_folder:
            return
        action_text = action.text()

        @work_exception
        def do_work():
            results = self.annot_dao.fetch_all_by_dataset(vo.id)
            return results, None

        @gui_exception
        def done_work(result):
            data, error = result
            if error:
                raise error
            # groupby only merges *consecutive* rows with the same key; sort
            # first so each image forms exactly one group (otherwise a later
            # group for the same image overwrites the earlier output file).
            rows = sorted(data, key=lambda x: x["image"])
            images = itertools.groupby(rows, lambda x: x["image"])
            if action_text == self.JSON:
                self.annotations2json(images, selected_folder)
            elif action_text == self.PASCAL_VOC:
                self.annotations2pascal(images, selected_folder)

            GUIUtilities.show_info_message("Annotations exported successfully", "Done")

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)
class DatasetTabWidget(QScrollArea):
    """Scrollable tab that lists datasets and reacts to the grid's actions.

    Besides creating, editing, deleting and opening datasets, it imports
    Pascal VOC box annotations and exports labels, boxes and polygon masks in
    several formats. Database and file work is submitted to ``thread_pool``
    as ``Worker`` jobs whose results are handled in ``done_work`` callbacks.
    """

    JSON = "JSON"
    PASCAL_VOC = "Pascal VOC"
    TENSORFLOW_OBJECT_DETECTION = "TensorFlow Object Detection"
    YOLO = "YOLO"

    def __init__(self, parent=None):
        super(DatasetTabWidget, self).__init__(parent)
        self.setCursor(QtCore.Qt.PointingHandCursor)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.data_grid = DatasetGridWidget()
        self.data_grid.new_dataset_action_signal.connect(
            self.btn_new_dataset_on_slot)
        self.data_grid.delete_dataset_action_signal.connect(
            self.btn_delete_dataset_on_slot)
        self.data_grid.refresh_dataset_action_signal.connect(
            self.refresh_dataset_action_slot)
        self.data_grid.edit_dataset_action_signal.connect(
            self.edit_dataset_action_slot)
        self.data_grid.open_dataset_action_signal.connect(
            self.open_dataset_action_slot)
        self.data_grid.download_anno_action_signal.connect(
            self.download_annot_action_slot)
        self.data_grid.import_anno_action_signal.connect(
            self.import_annot_action_slot)

        self.setWidget(self.data_grid)
        self.setWidgetResizable(True)
        self.thread_pool = QThreadPool()
        self.loading_dialog = QLoadingDialog()
        self._ds_dao = DatasetDao()
        self._labels_dao = LabelDao()
        self._annot_dao = AnnotaDao()
        self.load()

    @staticmethod
    def _json_default(obj):
        """``json.dumps`` fallback: use ``obj.toJSON()`` if it works, else ``obj.__dict__``."""
        try:
            return obj.toJSON()
        except Exception:
            return obj.__dict__

    @gui_exception
    def load(self):
        """Fetch all datasets (with sizes) in the background and rebind the grid."""
        @work_exception
        def do_work():
            results = self._ds_dao.fetch_all_with_size()
            return results, None

        @gui_exception
        def done_work(result):
            data, error = result
            if error:
                raise error
            self.data_grid.data_source = data
            self.data_grid.bind()

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)

    def _close_tab(self, tab_class):
        """Remove every tab of the window's tab manager that is a ``tab_class``."""
        tab_widget_manager: QTabWidget = self.window().tab_widget_manager
        # Walk backwards: removing a tab shifts all later tabs down by one, so
        # a forward walk would skip the tab right after each removed one.
        for i in reversed(range(tab_widget_manager.count())):
            if isinstance(tab_widget_manager.widget(i), tab_class):
                tab_widget_manager.removeTab(i)

    @gui_exception
    def btn_new_dataset_on_slot(self):
        """Show the dataset form and, if accepted, save the new dataset and reload."""
        form = DatasetForm()
        if form.exec_() == QDialog.Accepted:
            vo: DatasetVO = form.value
            self._ds_dao.save(vo)
            self.load()

    @gui_exception
    def btn_delete_dataset_on_slot(self, vo: DatasetVO):
        """After confirmation, delete the dataset row and its folder, then reload."""
        reply = QMessageBox.question(self, 'Confirmation', "Are you sure?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply == QMessageBox.Yes:
            self._ds_dao.delete(vo.id)
            usr_folder = FileUtilities.get_usr_folder()
            ds_folder = os.path.join(usr_folder, vo.folder)
            FileUtilities.delete_folder(ds_folder)
            self.load()

    @gui_exception
    def edit_dataset_action_slot(self, vo: DatasetVO):
        """Show the dataset form pre-filled with ``vo``; save and reload if accepted."""
        form = DatasetForm(vo)
        if form.exec_() == QDialog.Accepted:
            vo: DatasetVO = form.value
            self._ds_dao.save(vo)
            self.load()

    def refresh_dataset_action_slot(self, vo: DatasetVO):
        """Reload the whole dataset list (``vo`` is not used)."""
        self.load()

    def open_dataset_action_slot(self, vo: DatasetVO):
        """Close all open tabs and show a ``MediaTabWidget`` for ``vo``."""
        tab_widget_manager: QTabWidget = self.window().tab_widget_manager
        tab_widget = MediaTabWidget(vo)
        # Remove from the end: a forward ``removeTab(i)`` loop over the
        # original count only closes every other tab because indices shift.
        for i in reversed(range(tab_widget_manager.count())):
            tab_widget_manager.removeTab(i)
        index = tab_widget_manager.addTab(tab_widget, vo.name)
        tab_widget_manager.setCurrentIndex(index)

    @gui_exception
    def download_annot_action_slot(self, vo: DatasetVO):
        """Show the export menu (labels / boxes / masks / all) and run the chosen export."""
        menu = QMenu()
        menu.setCursor(QtCore.Qt.PointingHandCursor)
        labels_menu = QMenu("labels")
        menu.addMenu(labels_menu)
        boxes_menu = QMenu("boxes")
        menu.addMenu(boxes_menu)
        masks_menu = QMenu("masks")
        menu.addMenu(masks_menu)

        # NOTE: CSV export for boxes/masks, COCO and YOLO are not offered yet.
        labels_menu_export2csv_action = labels_menu.addAction(".csv")
        labels_menu_export2json_action = labels_menu.addAction(".json")
        boxes_menu_export2json_action = boxes_menu.addAction(
            ".json (cv-studio)")
        boxes_menu_export2pascal_action = boxes_menu.addAction(
            ".xml (pascal voc)")
        polygons_menu_export2json_action = masks_menu.addAction(
            ".json (cv-studio)")
        polygons_menu_export2png_action = masks_menu.addAction(".png")
        export_all_action = menu.addAction("all")

        action = menu.exec_(QCursor.pos())
        if not action:
            return
        if action == labels_menu_export2csv_action:
            self.export_labels_annots(vo, "csv")
        elif action == labels_menu_export2json_action:
            self.export_labels_annots(vo, "json")
        elif action == boxes_menu_export2json_action:
            self.export_boxes_annots(vo, "json")
        elif action == boxes_menu_export2pascal_action:
            self.export_boxes_annots2pascal(vo)
        elif action == polygons_menu_export2json_action:
            self.export_polygons_annots(vo, "json")
        elif action == polygons_menu_export2png_action:
            self.export_polygons_annots2png(vo)
        elif action == export_all_action:
            self.export_all_annot(vo)

    @gui_exception
    def import_annot_action_slot(self, dataset_vo: DatasetVO):
        """Import Pascal VOC ``.xml`` box annotations into ``dataset_vo``.

        Each file's ``<path>`` is looked up among the dataset's images; files
        whose image is not found are skipped. Label names are title-cased and
        unknown labels are created with a random color from a rainbow
        gradient. All collected boxes are saved in one DAO call.
        """
        menu = QMenu()
        menu.setCursor(QtCore.Qt.PointingHandCursor)
        menu.addAction(self.PASCAL_VOC)
        action = menu.exec_(QCursor.pos())
        if not action or action.text() != self.PASCAL_VOC:
            return
        colors = ColorUtilities.rainbow_gradient(1000)["hex"]
        files = GUIUtilities.select_files(
            ".xml", "Select the annotations files")
        if len(files) == 0:
            return

        @work_exception
        def do_work():
            annotations = []
            for xml_file in files:
                root = ET.parse(xml_file).getroot()
                image_vo = self._ds_dao.find_by_path(
                    dataset_vo.id, root.find('path').text)
                if not image_vo:
                    continue
                for roi in root.findall('object'):
                    label_name = roi.find('name').text.title()
                    label_vo = self._labels_dao.find_by_name(
                        dataset_vo.id, label_name)
                    if label_vo is None:
                        label_vo = LabelVO()
                        label_vo.name = label_name
                        label_vo.dataset = dataset_vo.id
                        # randint() is inclusive on both ends, so the former
                        # randint(0, len(colors)) could index past the end.
                        label_vo.color = random.choice(colors)
                        label_vo = self._labels_dao.save(label_vo)
                    bndbox = roi.find("bndbox")
                    # Element truthiness ("has children") is deprecated; test
                    # presence and non-emptiness explicitly.
                    if bndbox is None or len(bndbox) == 0:
                        continue
                    # VOC tools sometimes write coordinates as floats ("12.0");
                    # parse through float before truncating to int.
                    x1, y1, x2, y2 = (
                        int(float(bndbox.find(tag).text))
                        for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
                    annot_vo = AnnotaVO()
                    annot_vo.label = label_vo.id
                    annot_vo.entry = image_vo.id
                    annot_vo.kind = "box"
                    annot_vo.points = ",".join(map(str, [x1, y1, x2, y2]))
                    annotations.append(annot_vo)
            if len(annotations) > 0:
                self._annot_dao.save(dataset_vo.id, annotations)
            return annotations, None

        @gui_exception
        def done_work(result):
            data, error = result
            if error:
                raise error
            if len(data) > 0:
                GUIUtilities.show_info_message(
                    "Annotations imported successfully",
                    "Import annotations status")
            else:
                GUIUtilities.show_info_message(
                    "No annotations found",
                    "Import annotations status")

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)

    def _export_annots(self, dataset_vo: DatasetVO, export_format, fetch):
        """Ask for a folder and write ``annotations.<export_format>`` into it.

        :param dataset_vo: dataset whose annotations are exported.
        :param export_format: ``"csv"`` writes a CSV through pandas; any other
            value writes indented JSON.
        :param fetch: callable taking the dataset id and returning the
            annotations; it is run as a ``Worker`` job.
        """
        selected_folder = str(
            QFileDialog.getExistingDirectory(None, "select the folder"))
        if not selected_folder:
            return

        @work_exception
        def do_work():
            return fetch(dataset_vo.id), None

        @gui_exception
        def done_work(result):
            annots, err = result
            if err:
                raise err
            if annots:
                output_file = Path(selected_folder) \
                    .joinpath("annotations.{}".format(export_format))
                if export_format == "csv":
                    pd.DataFrame(annots).to_csv(str(output_file), index=False)
                else:
                    with open(output_file, "w") as f:
                        f.write(json.dumps(annots, default=self._json_default,
                                           indent=2))
                GUIUtilities.show_info_message(
                    "Annotations exported successfully", "Done")
            else:
                GUIUtilities.show_info_message(
                    "Not annotations found for the dataset {}".format(
                        dataset_vo.name), "Done")

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)

    @gui_exception
    def export_labels_annots(self, dataset_vo: DatasetVO, export_format):
        """Export the dataset's label annotations as ``"csv"`` or JSON."""
        self._export_annots(dataset_vo, export_format,
                            self._annot_dao.fetch_labels)

    @gui_exception
    def export_boxes_annots(self, dataset_vo: DatasetVO, export_format):
        """Export the dataset's box annotations as ``"csv"`` or JSON."""
        self._export_annots(dataset_vo, export_format,
                            self._annot_dao.fetch_boxes)

    @gui_exception
    def export_polygons_annots(self, dataset_vo: DatasetVO, export_format):
        """Export the dataset's polygon annotations as ``"csv"`` or JSON."""
        self._export_annots(dataset_vo, export_format,
                            self._annot_dao.fetch_polygons)

    @gui_exception
    def export_boxes_annots2pascal(self, dataset_vo: DatasetVO):
        """Write one Pascal VOC ``<image stem>.xml`` per annotated image.

        Box ``annot_points`` are read as ``"xmin,ymin,xmax,ymax"``. Files are
        rendered in parallel with dask inside the worker job.
        """
        output_folder = str(
            QFileDialog.getExistingDirectory(None, "select the folder"))
        if not output_folder:
            return
        output_folder = Path(output_folder)

        @dask.delayed
        def export_xml(img_path, img_annotations):
            # ``| x`` XML-escapes values so names/paths containing &, < or >
            # still yield well-formed XML.
            str_template = '''
               <annotation>
                   <folder>${folder | x}</folder>
                   <filename>${filename | x}</filename>
                   <path>${path | x}</path>
                   <source>
                       <database>Unknown</database>
                   </source>
                   <size>
                       <width>${width}</width>
                       <height>${height}</height>
                       <depth>${depth}</depth>
                   </size>
                   <segmented>0</segmented>
                   % for i, region in enumerate(annotations):
                       <object>
                           <name>${region["name"] | x}</name>
                           <pose>Unspecified</pose>
                           <truncated>0</truncated>
                           <difficult>0</difficult>
                           <bndbox>
                               <xmin>${region["xmin"]}</xmin>
                               <ymin>${region["ymin"]}</ymin>
                               <xmax>${region["xmax"]}</xmax>
                               <ymax>${region["ymax"]}</ymax>
                           </bndbox>
                       </object>
                   % endfor
               </annotation>
               '''
            img_path = Path(img_path)
            # Close the image file once size/bands are read (was leaked).
            with Image.open(str(img_path)) as img:
                w, h = img.size
                c = len(img.getbands())

            xml_str = Template(str_template).render(
                path=img_path,
                folder=str(img_path.parent),
                filename=img_path.name,
                width=w,
                height=h,
                depth=c,
                annotations=img_annotations)

            output_file = output_folder.joinpath("{}.xml".format(img_path.stem))
            with open(output_file, 'w') as f:
                f.write(xml_str)

        @work_exception
        def do_work():
            annotations = self._annot_dao.fetch_boxes(dataset_vo.id)
            if not annotations:
                return False, None
            # groupby needs rows sorted by the grouping key.
            rows = sorted(annotations, key=lambda ann: ann["image"])
            delayed_tasks = []
            for img_path, img_annotations in itertools.groupby(
                    rows, key=lambda ann: ann["image"]):
                boxes = []
                for annot in img_annotations:
                    xmin, ymin, xmax, ymax = list(
                        map(int, annot["annot_points"].split(",")))[:4]
                    boxes.append({
                        "name": annot["label_name"],
                        "xmin": xmin,
                        "ymin": ymin,
                        "xmax": xmax,
                        "ymax": ymax,
                    })
                if boxes:
                    delayed_tasks.append(export_xml(img_path, boxes))
            dask.compute(*delayed_tasks)
            return True, None

        @gui_exception
        def done_work(result):
            success, err = result
            if err:
                raise err
            if success:
                GUIUtilities.show_info_message(
                    "Annotations exported successfully", "Done")
            else:
                GUIUtilities.show_info_message(
                    "Not annotations found for the dataset {}".format(
                        dataset_vo.name), "Done")

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)

    @gui_exception
    def export_polygons_annots2png(self, dataset_vo: DatasetVO):
        """Export polygon annotations as palette-indexed PNG masks.

        For each annotated image a ``P``-mode ``<image stem>.png`` the size of
        the image is written. Pixel value 0 is background; label ``i``
        (1-based, labels sorted by name) is drawn with value ``i``, and palette
        entry ``i`` is that label's color. ``labels_map.json`` (name -> value)
        and ``colors_map.json`` (name -> RGB) are written alongside.
        """
        selected_folder = str(
            QFileDialog.getExistingDirectory(None, "select the folder"))
        if not selected_folder:
            return
        selected_folder = Path(selected_folder)

        @work_exception
        def do_work():
            annotations = self._annot_dao.fetch_polygons(dataset_vo.id)
            if not annotations:
                return annotations, None

            # Build the palette per label, in the same sorted order as
            # labels_map. The former palette came from an unordered *set* of
            # colors, so pixel value i did not correspond to label i's color,
            # and labels sharing a color caused an IndexError.
            label_colors = {ann["label_name"]: ann["label_color"]
                            for ann in annotations}
            label_names = sorted(label_colors)
            labels_map = {name: i + 1 for i, name in enumerate(label_names)}
            color_palette = np.asarray(
                [[0, 0, 0]] +
                [ColorUtilities.hex2RGB(label_colors[name])
                 for name in label_names])
            palette_rows = color_palette.tolist()
            colors_map = {name: palette_rows[value]
                          for name, value in labels_map.items()}
            flat_palette = color_palette.flatten().tolist()

            rows = sorted(annotations, key=lambda ann: ann["image"])
            for img_path, img_annotations in itertools.groupby(
                    rows, key=lambda ann: ann["image"]):
                image_name = Path(img_path).stem
                # Only the size is needed; avoid decoding/leaking the image.
                with Image.open(img_path) as image:
                    width, height = image.size
                mask = Image.new("P", (width, height), 0)
                mask.putpalette(flat_palette)
                drawable_image = ImageDraw.Draw(mask)
                for region in img_annotations:
                    label = region["label_name"]
                    coords = map(float, region["annot_points"].split(","))
                    points = [tuple(pt)
                              for pt in more_itertools.chunked(coords, 2)]
                    if points and label in labels_map:
                        drawable_image.polygon(points, fill=labels_map[label])
                mask.save(
                    str(selected_folder.joinpath("{}.png".format(image_name))),
                    "PNG")

            with open(str(selected_folder.joinpath("colors_map.json")),
                      "w") as f:
                f.write(json.dumps(colors_map, default=self._json_default,
                                   indent=2))
            with open(str(selected_folder.joinpath("labels_map.json")),
                      "w") as f:
                f.write(json.dumps(labels_map, default=self._json_default,
                                   indent=2))
            return annotations, None

        @gui_exception
        def done_work(result):
            annotations, err = result
            if err:
                raise err
            if annotations:
                GUIUtilities.show_info_message(
                    "Annotations exported successfully", "Done")
            else:
                GUIUtilities.show_info_message(
                    "Not annotations found for the dataset {}".format(
                        dataset_vo.name), "Done")

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)

    @gui_exception
    def export_all_annot(self, dataset_vo: DatasetVO):
        """Write every annotation of the dataset to ``annotations.json``."""
        selected_folder = str(
            QFileDialog.getExistingDirectory(None, "select the folder"))
        if not selected_folder:
            return

        @work_exception
        def do_work():
            annotations = self._annot_dao.fetch_all_by_dataset(
                dataset_vo.id)
            return annotations, None

        @gui_exception
        def done_work(result):
            annots, err = result
            if err:
                raise err
            if annots:
                output_file = Path(selected_folder) \
                    .joinpath("annotations.json")
                with open(output_file, "w") as f:
                    f.write(json.dumps(annots, default=self._json_default,
                                       indent=2))

                GUIUtilities.show_info_message(
                    "Annotations exported successfully", "Done")
            else:
                GUIUtilities.show_info_message(
                    "Not annotations found for the dataset {}".format(
                        dataset_vo.name), "Done")

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)
Beispiel #8
0
class DatasetTabWidget(QScrollArea):
    def __init__(self, parent=None):
        """Build the scrollable dataset grid, wire its actions and load data.

        :param parent: optional Qt parent widget.
        """
        super().__init__(parent)
        self.setCursor(QtCore.Qt.PointingHandCursor)
        self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)

        grid = DatasetGridWidget()
        self.data_grid = grid

        # Each grid action signal is routed to its handler on this widget.
        wiring = (
            (grid.new_dataset_action_signal, self.btn_new_dataset_on_slot),
            (grid.delete_dataset_action_signal,
             self.btn_delete_dataset_on_slot),
            (grid.refresh_dataset_action_signal,
             self.refresh_dataset_action_slot),
            (grid.edit_dataset_action_signal, self.edit_dataset_action_slot),
            (grid.open_dataset_action_signal, self.open_dataset_action_slot),
            (grid.download_anno_action_signal,
             self.download_annot_action_slot),
        )
        for signal, slot in wiring:
            signal.connect(slot)

        self.setWidget(grid)
        self.setWidgetResizable(True)

        # Background execution + data access helpers.
        self.thread_pool = QThreadPool()
        self.loading_dialog = QLoadingDialog()
        self.ds_dao = DatasetDao()
        self.annot_dao = AnnotaDao()

        # Populate the grid right away.
        self.load()

    @gui_exception
    def load(self):
        """Reload all datasets (with their sizes) off the GUI thread.

        ``ds_dao.fetch_all_with_size()`` runs in a ``Worker`` on
        ``self.thread_pool``. Its result becomes the grid's ``data_source``,
        after which the grid is re-bound. A worker error is re-raised in the
        result handler.
        """
        @work_exception
        def fetch_datasets():
            return self.ds_dao.fetch_all_with_size(), None

        @gui_exception
        def bind_datasets(outcome):
            datasets, failure = outcome
            if failure:
                raise failure
            grid = self.data_grid
            grid.data_source = datasets
            grid.bind()

        job = Worker(fetch_datasets)
        job.signals.result.connect(bind_datasets)
        self.thread_pool.start(job)

    def _close_tab(self, tab_class):
        """Remove every tab of the window's tab manager holding a *tab_class*.

        Tabs are removed with ``QTabWidget.removeTab``, which detaches the
        page but does not delete the widget.

        :param tab_class: class (or tuple of classes) accepted by
            ``isinstance``; matching tab widgets are removed.
        """
        tab_widget_manager: QTabWidget = self.window().tab_widget_manager
        # Walk indices from the end: removing a tab shifts every later tab one
        # position left, so a forward walk would skip the tab right after each
        # removed one.
        for i in reversed(range(tab_widget_manager.count())):
            if isinstance(tab_widget_manager.widget(i), tab_class):
                tab_widget_manager.removeTab(i)

    @QtCore.pyqtSlot()
    @gui_exception
    def btn_new_dataset_on_slot(self):
        """Show an empty dataset form; save the new dataset if accepted.

        On acceptance the form's value is persisted via ``ds_dao.save`` and
        the grid is reloaded. A cancelled form changes nothing.
        """
        form = DatasetForm()
        if form.exec_() != QDialog.Accepted:
            return
        new_dataset: DatasetVO = form.value
        self.ds_dao.save(new_dataset)
        self.load()

    @QtCore.pyqtSlot(DatasetVO)
    @gui_exception
    def btn_delete_dataset_on_slot(self, vo: DatasetVO):
        """Delete *vo* after a Yes/No confirmation, then reload the grid.

        Removes the dataset record (``ds_dao.delete``) and its data folder,
        located at ``<usr folder>/<vo.folder>``.

        :param vo: dataset to delete.
        """
        reply = QMessageBox.question(self, 'Confirmation', "Are you sure?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply != QMessageBox.Yes:
            return
        self.ds_dao.delete(vo.id)
        # Guard: with an empty/None folder name os.path.join() would resolve
        # to the user folder itself (or fail), and deleting that would wipe
        # every dataset, not just this one.
        if vo.folder:
            usr_folder = FileUtilities.get_usr_folder()
            ds_folder = os.path.join(usr_folder, vo.folder)
            FileUtilities.delete_folder(ds_folder)
        self.load()

    @QtCore.pyqtSlot(DatasetVO)
    @gui_exception
    def edit_dataset_action_slot(self, vo: DatasetVO):
        """Open *vo* in the dataset form and persist the edits if accepted.

        :param vo: dataset to edit; pre-fills the form.
        """
        form = DatasetForm(vo)
        accepted = form.exec_() == QDialog.Accepted
        if not accepted:
            return
        edited: DatasetVO = form.value
        self.ds_dao.save(edited)
        self.load()

    @QtCore.pyqtSlot(DatasetVO)
    def refresh_dataset_action_slot(self, vo: DatasetVO):
        """Reload the whole dataset grid.

        :param vo: dataset whose refresh was requested. It is not used:
            every dataset is reloaded.
        """
        self.load()

    @QtCore.pyqtSlot(DatasetVO)
    def open_dataset_action_slot(self, vo: DatasetVO):
        """Replace all open tabs with a media tab for *vo* and focus it.

        :param vo: dataset to open; its ``name`` is used as the tab title.
        """
        tab_widget_manager: QTabWidget = self.window().tab_widget_manager
        tab_widget = MediaTabWidget(vo)
        # Remove every existing tab, iterating from the last index: removing
        # a tab shifts the later ones left, so a forward loop over
        # range(count()) would only remove every other tab.
        for i in reversed(range(tab_widget_manager.count())):
            tab_widget_manager.removeTab(i)
        index = tab_widget_manager.addTab(tab_widget, vo.name)
        tab_widget_manager.setCurrentIndex(index)

    @gui_exception
    def download_annot_action_slot(self, vo: DatasetVO):
        """Export every annotation of *vo* to a JSON file chosen by the user.

        The annotation rows are fetched on a worker thread. Each row is a
        mapping with keys ``image``, ``annot_kind``, ``annot_points``,
        ``label_name`` and ``label_color``. Rows are grouped per image into
        ``ImageSchemeVO`` objects, each holding its ``AnnotSchemeVO``
        regions. A save dialog is then shown, and
        ``ImageScheme(many=True).dump`` of the list is written to the file.
        Cancelling the dialog writes nothing.

        :param vo: dataset to export; its ``id`` drives the query.
        """
        @work_exception
        def do_work():
            results = self.annot_dao.fetch_all_by_dataset(vo.id)
            return results, None

        @gui_exception
        def done_work(result):
            data, error = result
            if error:
                raise error
            # Group rows by image path. A dict keeps the first-seen order of
            # images and, unlike itertools.groupby (which only merges
            # *adjacent* rows), never emits the same image twice when the
            # rows are not sorted by "image".
            images = {}
            for annot_dict in data:
                key = annot_dict["image"]
                image = images.get(key)
                if image is None:
                    image = ImageSchemeVO()
                    image.path = key
                    images[key] = image
                annot = AnnotSchemeVO()
                annot.kind = annot_dict["annot_kind"]
                annot.points = annot_dict["annot_points"]
                annot.label_name = annot_dict["label_name"]
                annot.label_color = annot_dict["label_color"]
                image.regions.append(annot)
            annot_list = list(images.values())
            scheme = ImageScheme(many=True)
            options = QFileDialog.Options()
            options |= QFileDialog.DontUseNativeDialog
            default_file = os.path.join(os.path.expanduser('~'),
                                        "annotations.json")
            fileName, _ = QFileDialog.getSaveFileName(self,
                                                      "Export annotations",
                                                      default_file,
                                                      "Json Files (*.json)",
                                                      options=options)
            if fileName:
                # Serialize before opening so a dump error does not leave a
                # truncated file behind.
                payload = scheme.dump(annot_list)
                with open(fileName, 'w') as f:
                    json.dump(payload, f, indent=3)

        worker = Worker(do_work)
        worker.signals.result.connect(done_work)
        self.thread_pool.start(worker)