コード例 #1
0
ファイル: runner.py プロジェクト: DiamondLightSource/SuRVoS2
    def get_run_fields(self):
        """Build the group box holding the controls used to start SuRVoS.

        Grid layout:
            row 0: "Workspace Name:" label (col 0), workspace combo (col 1),
                   workspace-name line edit (col 2)
            row 1: "Advanced" toggle
            row 2: advanced run fields (``self.adv_run_fields``)
            row 3: "Run SuRVoS" button spanning three columns

        Returns:
            PyQt5.QtWidgets.QGroupBox: group box with the run fields.
        """
        self.run_button = QPushButton("Run SuRVoS")
        advanced_button = QRadioButton("Advanced")

        run_fields = QGroupBox("Run SuRVoS:")
        run_layout = QGridLayout()
        run_layout.addWidget(QLabel("Workspace Name:"), 0, 0)

        workspaces = os.listdir(CHROOT)
        self.workspaces_list = ComboBox()
        for s in workspaces:
            self.workspaces_list.addItem(key=s)
        # Column 1, between the label (col 0) and the name edit (col 2);
        # placing it at (0, 0) would stack it on top of the label.
        run_layout.addWidget(self.workspaces_list, 0, 1)

        self.ws_name_linedt_2 = QLineEdit(self.workspace_config["workspace_name"])
        self.ws_name_linedt_2.setAlignment(Qt.AlignLeft)
        run_layout.addWidget(self.ws_name_linedt_2, 0, 2)
        run_layout.addWidget(advanced_button, 1, 0)
        run_layout.addWidget(self.adv_run_fields, 2, 1)
        run_layout.addWidget(self.run_button, 3, 0, 1, 3)
        run_fields.setLayout(run_layout)

        advanced_button.toggled.connect(self.toggle_advanced)
        self.run_button.clicked.connect(self.run_clicked)

        return run_fields
コード例 #2
0
    def __init__(self, parent=None):
        """Build the SVM options panel.

        Shows a kernel-type combo ("linear", "poly", "rbf", "sigmoid") under a
        "Kernel Type:" category, followed by one row with the float-valued
        "Penalty C:" and "Gamma:" editors (both defaulting to 1.0).
        """
        super(SVMWidget, self).__init__(parent=parent)

        layout = QtWidgets.QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(layout)

        self.type_combo = ComboBox()
        self.type_combo.addCategory("Kernel Type:")
        for kernel_name in ("linear", "poly", "rbf", "sigmoid"):
            self.type_combo.addItem(kernel_name)
        layout.addWidget(self.type_combo)

        self.penaltyc = LineEdit(default=1.0, parse=float)
        self.gamma = LineEdit(default=1.0, parse=float)

        params_row = HWidgets(
            QtWidgets.QLabel("Penalty C:"),
            self.penaltyc,
            QtWidgets.QLabel("Gamma:"),
            self.gamma,
            stretch=[0, 1, 0, 1],
        )
        layout.addWidget(params_row)
コード例 #3
0
 def __init__(self, parent=None):
     """Create the pipelines panel: an "add pipeline" combo in a vertical box."""
     super().__init__(parent=parent)
     combo = ComboBox()
     self.pipeline_combo = combo
     self.vbox = VBox(self, spacing=4)
     self.vbox.addWidget(combo)
     combo.currentIndexChanged.connect(self.add_pipeline)
     # Pipeline cards currently shown, presumably keyed by pipeline id.
     self.existing_pipelines = {}
     self._populate_pipelines()
コード例 #4
0
 def _add_model_type(self):
     """Add a "Model type:" row offering the unet3d and fpn3d models."""
     self.model_type = ComboBox()
     for model_key in ("unet3d", "fpn3d"):
         self.model_type.addItem(key=model_key)
     row = HWidgets("Model type:", self.model_type, Spacing(35), stretch=0)
     self.add_row(row)
コード例 #5
0
 def _add_projection_choice(self):
     """Add a "Projection:" row listing the available projection methods."""
     self.projection_type = ComboBox()
     for projection_key in ("None", "pca", "rbp", "rproj", "std"):
         self.projection_type.addItem(key=projection_key)
     row = HWidgets("Projection:", self.projection_type, Spacing(35), stretch=0)
     self.add_row(row)
コード例 #6
0
    def __init__(self, parent=None):
        """Build the ensemble-classifier options panel.

        Contains:
            - an ensemble-type combo under an "Ensemble Type:" category;
            - rows for the tree count, max depth, learning rate and subsample;
            - a "Num Jobs" editor.
        """
        super(EnsembleWidget, self).__init__(parent=parent)

        layout = QtWidgets.QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(layout)

        self.type_combo = ComboBox()
        self.type_combo.addCategory("Ensemble Type:")
        for ensemble_name in (
            "Random Forest",
            "ExtraRandom Forest",
            "AdaBoost",
            "GradientBoosting",
        ):
            self.type_combo.addItem(ensemble_name)
        # Connected only once the combo has been populated.
        self.type_combo.currentIndexChanged.connect(self.on_ensemble_changed)
        layout.addWidget(self.type_combo)

        self.ntrees = LineEdit(default=100, parse=int)
        self.depth = LineEdit(default=15, parse=int)
        self.lrate = LineEdit(default=1.0, parse=float)
        self.subsample = LineEdit(default=1.0, parse=float)

        def labelled_pair(left_text, left_edit, right_text, right_edit):
            # Two label/editor pairs on one row; editors take the spare width.
            return HWidgets(
                QtWidgets.QLabel(left_text),
                left_edit,
                QtWidgets.QLabel(right_text),
                right_edit,
                stretch=[0, 1, 0, 1],
            )

        layout.addWidget(
            labelled_pair("# Trees:", self.ntrees, "Max Depth:", self.depth))
        layout.addWidget(
            labelled_pair("Learn Rate:", self.lrate, "Subsample:", self.subsample))

        self.n_jobs = LineEdit(default=10, parse=int)
        layout.addWidget(HWidgets("Num Jobs", self.n_jobs))
コード例 #7
0
    def __init__(self, *args, **kwargs):
        """Build the viewer button panel.

        Rows, top to bottom:
            - the slice slider (hidden until slice mode is enabled);
            - the "Transfer Layer" / "Slice mode" buttons;
            - the editable workspace switcher with a "Refresh" button.
        """
        QtWidgets.QWidget.__init__(self, *args, **kwargs)
        self.slice_mode = False

        self.slider = Slider(value=0, vmax=cfg.slice_max - 1)
        self.slider.setMinimumWidth(150)
        self.slider.sliderReleased.connect(self._params_updated)

        refresh_btn = QPushButton("Refresh", self)
        refresh_btn.clicked.connect(self.button_refresh_clicked)

        transfer_btn = QPushButton("Transfer Layer")
        transfer_btn.clicked.connect(self.button_transfer_clicked)

        self.button_slicemode = QPushButton("Slice mode", self)
        self.button_slicemode.clicked.connect(self.button_slicemode_clicked)

        ws_names = os.listdir(DataModel.g.CHROOT)
        self.workspaces_list = ComboBox()
        for ws_name in ws_names:
            self.workspaces_list.addItem(key=ws_name)
        switcher = HWidgets("Switch Workspaces:", self.workspaces_list)
        self.workspaces_list.setEditable(True)
        self.workspaces_list.activated[str].connect(self.workspaces_selected)

        self.hbox_layout0 = QtWidgets.QHBoxLayout()
        workspace_row = QtWidgets.QHBoxLayout()
        buttons_row = QtWidgets.QHBoxLayout()

        self.hbox_layout0.addWidget(self.slider)
        self.slider.hide()

        workspace_row.addWidget(switcher)
        workspace_row.addWidget(refresh_btn)

        buttons_row.addWidget(transfer_btn)
        buttons_row.addWidget(self.button_slicemode)

        outer = VBox(self, margin=(1, 1, 1, 1), spacing=2)
        for row in (self.hbox_layout0, buttons_row, workspace_row):
            outer.addLayout(row)
コード例 #8
0
    def _add_classifier_choice(self):
        """Add the classifier selector row plus a container for its options.

        The container starts out holding ``self.ensembles``, which must
        already exist. ``_on_classifier_changed`` swaps the shown options
        panel when the selection changes.
        """
        self.classifier_type = ComboBox()
        self.classifier_type.addItem(key="Ensemble")
        self.classifier_type.addItem(key="SVM")
        widget = HWidgets("Classifier:",
                          self.classifier_type,
                          Spacing(35),
                          stretch=0)

        self.classifier_type.currentIndexChanged.connect(
            self._on_classifier_changed)

        self.clf_container = QtWidgets.QWidget()
        # Parent the layout to the container it is installed on. This assumes
        # VBox forwards its first argument to QVBoxLayout, as its other uses in
        # this file suggest. With the card (``self``) as parent, Qt's
        # setLayout rejects a layout that already has another widget parent,
        # which can leave ``clf_container.layout()`` as None.
        clf_vbox = VBox(self.clf_container, spacing=4)
        clf_vbox.setContentsMargins(0, 0, 0, 0)
        self.clf_container.setLayout(clf_vbox)

        self.add_row(widget)
        self.add_row(self.clf_container, max_height=500)
        self.clf_container.layout().addWidget(self.ensembles)
コード例 #9
0
    def setup_annotation_widgets(self):
        """Build the "Annotations:" export group.

        Row 0 holds the column headers. Row 1 holds:
            - the annotation selector (spanning two columns);
            - the file-type combo;
            - the save button, wired to ``self.save_anno``.

        Returns:
            QGroupBox: the populated group box.
        """
        grid = QGridLayout()
        grid.addWidget(QLabel("Annotation"), 0, 0, 1, 2)
        grid.addWidget(QLabel("File type"), 0, 2)

        self.anno_source = AnnoComboBox()
        self.anno_ftype_combo = ComboBox()
        self.add_filetypes_to_combo(self.anno_ftype_combo)
        self.anno_export_btn = IconButton("fa.save", "Export data", accent=True)
        self.anno_export_btn.clicked.connect(self.save_anno)

        grid.addWidget(self.anno_source, 1, 0, 1, 2)
        grid.addWidget(self.anno_ftype_combo, 1, 2)
        grid.addWidget(self.anno_export_btn, 1, 3)

        group = QGroupBox("Annotations:")
        group.setLayout(grid)
        return group
コード例 #10
0
    def setup_feature_widgets(self):
        """Build the "Features:" export group.

        Row 0 holds the column headers. Row 1 holds:
            - the feature selector (spanning two columns);
            - the file-type combo;
            - the save button, wired to ``self.save_feature``.

        Returns:
            QGroupBox: the populated group box.
        """
        grid = QGridLayout()
        grid.addWidget(QLabel("Feature"), 0, 0, 1, 2)
        grid.addWidget(QLabel("File type"), 0, 2)

        self.feat_source = FeatureComboBox()
        self.feat_ftype_combo = ComboBox()
        self.add_filetypes_to_combo(self.feat_ftype_combo)
        self.feat_export_btn = IconButton("fa.save", "Export data", accent=True)
        self.feat_export_btn.clicked.connect(self.save_feature)

        grid.addWidget(self.feat_source, 1, 0, 1, 2)
        grid.addWidget(self.feat_ftype_combo, 1, 2)
        grid.addWidget(self.feat_export_btn, 1, 3)

        group = QGroupBox("Features:")
        group.setLayout(grid)
        return group
コード例 #11
0
    def setup_pipeline_widgets(self):
        """Build the "Pipeline output:" export group.

        Row 0 holds the column headers. Row 1 holds:
            - the pipeline-output selector (spanning two columns);
            - the file-type combo;
            - the save button, wired to ``self.save_pipe``.

        Returns:
            QGroupBox: the populated group box.
        """
        grid = QGridLayout()
        grid.addWidget(QLabel("Pipeline"), 0, 0, 1, 2)
        grid.addWidget(QLabel("File type"), 0, 2)

        self.pipe_source = SuperRegionSegmentComboBox()
        self.pipe_ftype_combo = ComboBox()
        self.add_filetypes_to_combo(self.pipe_ftype_combo)
        self.pipe_export_btn = IconButton("fa.save", "Export data", accent=True)
        self.pipe_export_btn.clicked.connect(self.save_pipe)

        grid.addWidget(self.pipe_source, 1, 0, 1, 2)
        grid.addWidget(self.pipe_ftype_combo, 1, 2)
        grid.addWidget(self.pipe_export_btn, 1, 3)

        group = QGroupBox("Pipeline output:")
        group.setLayout(grid)
        return group
コード例 #12
0
class ButtonPanelWidget(QtWidgets.QWidget):
    """Viewer control panel.

    Provides a slice slider, layer transfer, a slice/volume toggle and a
    workspace switcher. User actions are forwarded as dict events on
    ``cfg.ppw.clientEvent``.
    """

    clientEvent = Signal(object)

    def __init__(self, *args, **kwargs):
        QtWidgets.QWidget.__init__(self, *args, **kwargs)
        self.slice_mode = False

        self.slider = Slider(value=0, vmax=cfg.slice_max - 1)
        self.slider.setMinimumWidth(150)
        self.slider.sliderReleased.connect(self._params_updated)

        button_refresh = QPushButton("Refresh", self)
        button_refresh.clicked.connect(self.button_refresh_clicked)

        button_transfer = QPushButton("Transfer Layer")
        button_transfer.clicked.connect(self.button_transfer_clicked)

        # Same label text that button_slicemode_clicked restores when leaving
        # slice mode, so the button reads identically before and after toggling.
        self.button_slicemode = QPushButton("Slice Mode", self)
        self.button_slicemode.clicked.connect(self.button_slicemode_clicked)

        self.workspaces_list = ComboBox()
        for s in self._workspace_names():
            self.workspaces_list.addItem(key=s)
        workspaces_widget = HWidgets("Switch Workspaces:",
                                     self.workspaces_list)
        self.workspaces_list.setEditable(True)
        self.workspaces_list.activated[str].connect(self.workspaces_selected)

        self.hbox_layout0 = QtWidgets.QHBoxLayout()
        hbox_layout_ws = QtWidgets.QHBoxLayout()
        hbox_layout1 = QtWidgets.QHBoxLayout()

        # The slider only becomes visible in slice mode.
        self.hbox_layout0.addWidget(self.slider)
        self.slider.hide()

        hbox_layout_ws.addWidget(workspaces_widget)
        hbox_layout_ws.addWidget(button_refresh)

        hbox_layout1.addWidget(button_transfer)
        hbox_layout1.addWidget(self.button_slicemode)

        vbox = VBox(self, margin=(1, 1, 1, 1), spacing=2)
        vbox.addLayout(self.hbox_layout0)
        vbox.addLayout(hbox_layout1)
        vbox.addLayout(hbox_layout_ws)

    @staticmethod
    def _workspace_names():
        """Return the entries of the workspace chroot directory, sorted.

        ``os.listdir`` returns entries in arbitrary order; sorting keeps the
        combo order stable between launches and refreshes.
        """
        return sorted(os.listdir(DataModel.g.CHROOT))

    def refresh_workspaces(self):
        """Reload the workspace combo from disk and resize the slider."""
        self.workspaces_list.clear()
        for s in self._workspace_names():
            self.workspaces_list.addItem(key=s)
        self.slider.setMinimumWidth(cfg.base_dataset_shape[0])

    def workspaces_selected(self):
        """Ask the client to switch to the workspace chosen in the combo."""
        selected_workspace = self.workspaces_list.value()
        # Re-select without re-triggering this handler.
        self.workspaces_list.blockSignals(True)
        self.workspaces_list.select(selected_workspace)
        self.workspaces_list.blockSignals(False)

        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "set_workspace",
            "workspace": selected_workspace,
        })

    def sessions_selected(self):
        """Ask the client to switch session.

        NOTE(review): ``self.session_list`` is not created in ``__init__``;
        confirm it is assigned elsewhere before this handler is wired up.
        """
        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "set_session",
            "session": self.session_list.value(),
        })

    def button_setroi_clicked(self):
        """Request a new workspace cropped to the ROI start/end values.

        NOTE(review): ``self.roi_start`` / ``self.roi_end`` are not created in
        ``__init__``; confirm they are assigned elsewhere.
        """
        roi_start = self.roi_start.value()
        roi_end = self.roi_end.value()
        roi = [
            roi_start[0],
            roi_start[1],
            roi_start[2],
            roi_end[0],
            roi_end[1],
            roi_end[2],
        ]
        self.refresh_workspaces()
        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "make_roi_ws",
            "roi": roi
        })

    def button_refresh_clicked(self):
        """Reload the workspace list, then ask the client to clear and refresh the viewer."""
        self.refresh_workspaces()
        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "empty_viewer",
            "value": None
        })
        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "refresh",
            "value": None
        })

    def button_pause_save_clicked(self):
        """Toggle ``cfg.pause_save`` and update the pause button's label.

        NOTE(review): ``self.button_pause_save`` is not created in
        ``__init__``; confirm it is assigned elsewhere.
        """
        if cfg.pause_save:
            self.button_pause_save.setText("Pause Saving to Server")
        else:
            self.button_pause_save.setText("Resume saving to Server")

        cfg.pause_save = not cfg.pause_save

    def button_transfer_clicked(self):
        """Ask the client to transfer the current layer."""
        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "transfer_layer",
            "value": None
        })

    def button_slicemode_clicked(self):
        """Toggle between slice mode (slider shown) and volume mode."""
        if self.slice_mode:
            self.slider.hide()
            self.button_slicemode.setText("Slice Mode")
        else:
            logger.info(f"Slice mode {cfg.slice_max}")

            self.slider.vmax = cfg.slice_max - 1
            self.slider.setRange(0, cfg.slice_max - 1)
            self.slider.show()
            self.button_slicemode.setText("Volume Mode")
        self.slice_mode = not self.slice_mode

        cfg.ppw.clientEvent.emit({
            "source": "button_slicemode",
            "data": "slice_mode",
        })

    def _params_updated(self):
        """Ask the client to jump to the slice selected on the slider."""
        cfg.ppw.clientEvent.emit({
            "source": "slider",
            "data": "jump_to_slice",
            "frame": self.slider.value()
        })
コード例 #13
0
class PipelinesPlugin(Plugin):
    """Plugin panel listing segmentation pipelines as cards.

    The top combo offers the pipeline types reported by the server, grouped
    under disabled category headers. Choosing one creates a new pipeline on
    the server and appends a ``PipelineCard`` for it.
    """

    __icon__ = "fa.picture-o"
    __pname__ = "pipelines"
    __views__ = ["slice_viewer"]
    __tab__ = "pipelines"

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.pipeline_combo = ComboBox()
        self.vbox = VBox(self, spacing=4)
        self.vbox.addWidget(self.pipeline_combo)
        self.pipeline_combo.currentIndexChanged.connect(self.add_pipeline)
        # Pipeline id -> PipelineCard currently shown in the panel.
        self.existing_pipelines = dict()
        self._populate_pipelines()

    def _populate_pipelines(self):
        """(Re)fill the "Add segmentation" combo with the server's pipeline types.

        Also rebuilds ``self.pipeline_params`` (pipeline type -> parameter
        spec). If the server returns nothing, only a default spec for
        "superregion_segment" is registered.
        """
        self.pipeline_params = {}
        self.pipeline_combo.clear()
        self.pipeline_combo.addItem("Add segmentation")

        result = Launcher.g.run("pipelines", "available", workspace=True)

        if not result:
            self.pipeline_params["superregion_segment"] = {
                "sr_params": {
                    "type": "sr2",
                }
            }
            return

        for category in sorted({p["category"] for p in result}):
            # Take the header's row from the combo itself. Deriving it from
            # len(self.pipeline_params) goes wrong when two categories share a
            # pipeline name, because the dict then holds fewer entries than
            # there are combo rows.
            header_row = self.pipeline_combo.count()
            self.pipeline_combo.addItem(category)
            # Disabled so the category acts as a non-selectable header.
            self.pipeline_combo.model().item(header_row).setEnabled(False)

            for f in result:
                if f["category"] == category:
                    self.pipeline_params[f["name"]] = f["params"]
                    self.pipeline_combo.addItem(f["name"])

    def add_pipeline(self, idx):
        """Create a pipeline of the type chosen at combo row ``idx``.

        Row 0 (the "Add segmentation" prompt) and empty entries are ignored.
        The combo is reset to row 0 afterwards.
        """
        if idx <= 0:
            return
        if self.pipeline_combo.itemText(idx) == "":
            return

        logger.debug(f"Adding pipeline {self.pipeline_combo.itemText(idx)}")

        pipeline_type = self.pipeline_combo.itemText(idx)
        # Re-enters this slot with idx 0, which returns immediately above.
        self.pipeline_combo.setCurrentIndex(0)

        params = dict(pipeline_type=pipeline_type, workspace=True)
        result = Launcher.g.run("pipelines", "create", **params)

        if result:
            fid = result["id"]
            ftype = result["kind"]
            fname = result["name"]
            self._add_pipeline_widget(fid, ftype, fname, True)
            _PipelineNotifier.notify()

    def _add_pipeline_widget(self, fid, ftype, fname, expand=False):
        """Create a card for pipeline ``fid``, add it to the panel and register it."""
        widget = PipelineCard(fid, ftype, fname, self.pipeline_params[ftype])
        widget.showContent(expand)
        self.vbox.addWidget(widget)
        self.existing_pipelines[fid] = widget
        return widget

    def clear(self):
        """Detach every pipeline card from the panel."""
        for widget in self.existing_pipelines.values():
            widget.setParent(None)
        self.existing_pipelines = {}

    def setup(self):
        """Sync the cards with the pipelines reported by the server.

        Cards whose pipeline the server no longer reports are removed. Cards
        are added for newly reported pipelines and filled with their stored
        parameters.
        """
        self._populate_pipelines()
        params = dict(workspace=DataModel.g.current_session + "@" +
                      DataModel.g.current_workspace)
        result = Launcher.g.run("pipelines", "existing", **params)
        logger.debug(f"Pipeline result {result}")

        if not result:
            # NOTE(review): an empty reply leaves existing cards untouched;
            # confirm whether an empty dict should clear the panel instead.
            return

        # Remove pipelines that no longer exist in the server
        for pipeline in list(self.existing_pipelines):
            if pipeline not in result:
                self.existing_pipelines.pop(pipeline).setParent(None)

        # Populate with new pipelines if any
        for pipeline in sorted(result):
            if pipeline in self.existing_pipelines:
                continue
            params = result[pipeline]
            logger.debug(f"Pipeline params {params}")
            fid = params.pop("id", pipeline)
            ftype = params.pop("kind")
            fname = params.pop("name", pipeline)
            # _add_pipeline_widget already registers the card under fid.
            widget = self._add_pipeline_widget(fid, ftype, fname)
            widget.update_params(params)
コード例 #14
0
class PipelineCard(Card):
    def __init__(self, fid, ftype, fname, fparams, parent=None):
        """Create a card for one pipeline and build its controls.

        Args:
            fid: pipeline id (used for server calls and the viewer layer name).
            ftype: pipeline type; selects which source/option rows are added.
            fname: title shown on the card.
            fparams: mapping of parameter name -> keyword arguments for
                ``_add_param``. "src" and "dst" entries are skipped.
            parent: optional Qt parent widget.
        """
        super().__init__(fname,
                         removable=True,
                         editable=True,
                         collapsible=True,
                         parent=parent)
        self.pipeline_id = fid
        self.pipeline_type = ftype
        self.pipeline_name = fname

        self.params = fparams
        # Route diagnostics through the logger instead of printing to stdout.
        logger.debug(f"Pipeline card params {fparams}")
        # Parameter name -> editor widget, filled by _add_param.
        self.widgets = dict()

        if self.pipeline_type == "superregion_segment":
            logger.debug("Adding a superregion_segment pipeline")
            self._add_features_source()
            self._add_annotations_source()
            self._add_constrain_source()
            self._add_regions_source()

            self.ensembles = EnsembleWidget()
            self.ensembles.train_predict.connect(self.compute_pipeline)
            self.svm = SVMWidget()
            self.svm.predict.connect(self.compute_pipeline)

            self._add_classifier_choice()
            self._add_projection_choice()
            self._add_param("lam", type="FloatSlider", default=0.15)
            self._add_confidence_choice()

        elif self.pipeline_type == "rasterize_points":
            self._add_annotations_source()
            self._add_feature_source()
            self._add_objects_source()

        elif self.pipeline_type == "watershed":
            self._add_annotations_source()
            self._add_feature_source()

        elif self.pipeline_type == "predict_segmentation_fcn":
            self._add_annotations_source()
            self._add_feature_source()
            self._add_workflow_file()
            self._add_model_type()

        elif self.pipeline_type == "label_postprocess":
            self._add_annotations_source(label="Layer Over: ")
            self._add_annotations_source2(label="Layer Base: ")
            # NOTE(review): label_index is created but never added to the card.
            self.label_index = LineEdit(default=-1, parse=int)

        elif self.pipeline_type == "cleaning":
            self._add_feature_source()
            self._add_annotations_source()

        elif self.pipeline_type == "train_2d_unet":
            self._add_annotations_source()
            self._add_feature_source()
            self._add_unet_2d_training_params()

        elif self.pipeline_type == "predict_2d_unet":
            self._add_annotations_source()
            self._add_feature_source()
            self._add_unet_2d_prediction_params()

        else:
            logger.debug(f"Unsupported pipeline type {self.pipeline_type}.")

        for pname, params in fparams.items():
            if pname not in ("src", "dst"):
                self._add_param(pname, **params)

        self._add_compute_btn()
        self._add_view_btn()

    def _add_model_type(self):
        """Add a "Model type:" row offering the unet3d and fpn3d models."""
        self.model_type = ComboBox()
        for model_key in ("unet3d", "fpn3d"):
            self.model_type.addItem(key=model_key)
        row = HWidgets("Model type:", self.model_type, Spacing(35), stretch=0)
        self.add_row(row)

    def _add_patch_params(self):
        """Add a "Patch Size:" row with a 3D integer editor (default 64)."""
        patch_edit = LineEdit3D(default=64, parse=int)
        self.patch_size = patch_edit
        row = HWidgets("Patch Size:", patch_edit, Spacing(35), stretch=1)
        self.add_row(row)

    def _add_unet_2d_training_params(self):
        """Add the 2D U-Net training rows.

        Adds a header row, then one row with the frozen (default 8) and
        unfrozen (default 5) training-cycle counts.
        """
        header = HWidgets("Training Parameters:", Spacing(35), stretch=1)
        self.add_row(header)

        self.cycles_frozen = LineEdit(default=8, parse=int)
        self.cycles_unfrozen = LineEdit(default=5, parse=int)
        cycles_row = HWidgets(
            "No. Cycles Frozen:",
            self.cycles_frozen,
            "No. Cycles Unfrozen",
            self.cycles_unfrozen,
            stretch=1,
        )
        self.add_row(cycles_row)

    def _add_unet_2d_prediction_params(self):
        """Add the 2D U-Net prediction controls.

        Adds a model-path editor with a "Select Model" button, and an
        exclusive pair of radio buttons registered in ``self.radio_group``:
        id 1 is "Single plane" (checked by default), id 3 is "Three plane".
        """
        self.model_file_line_edit = LineEdit(default="Filepath", parse=str)
        select_model_btn = PushButton("Select Model", accent=True)
        select_model_btn.clicked.connect(self.get_model_path)

        plane_group = QtWidgets.QButtonGroup()
        plane_group.setExclusive(True)
        self.radio_group = plane_group

        single_plane = QRadioButton("Single plane")
        single_plane.setChecked(True)
        plane_group.addButton(single_plane, 1)
        triple_plane = QRadioButton("Three plane")
        plane_group.addButton(triple_plane, 3)

        for row in (
            HWidgets(self.model_file_line_edit, select_model_btn, Spacing(35)),
            HWidgets("Prediction Parameters:", Spacing(35), stretch=1),
            HWidgets(single_plane, triple_plane, stretch=1),
        ):
            self.add_row(row)

    def _add_workflow_file(self):
        """Add a ``*.pt`` file picker whose chosen path is passed to ``load_data``."""
        picker = FileWidget(extensions="*.pt", save=False)
        self.filewidget = picker
        self.add_row(picker)
        picker.path_updated.connect(self.load_data)

    def load_data(self, path):
        """Remember ``path`` as the selected model file and print it."""
        model_path = path
        self.model_fullname = model_path
        print(f"Setting model fullname: {model_path}")

    def _add_view_btn(self):
        """Add the output-action row.

        Buttons, left to right: "Load as image", "Load as annotation", "View".
        """
        created = []
        for text, handler in (
            ("View", self.view_pipeline),
            ("Load as annotation", self.load_as_annotation),
            ("Load as image", self.load_as_float),
        ):
            button = PushButton(text, accent=True)
            button.clicked.connect(handler)
            created.append(button)
        view_btn, annotation_btn, image_btn = created
        self.add_row(
            HWidgets(None, image_btn, annotation_btn, view_btn, Spacing(35)))

    def _add_refine_choice(self):
        """Add an "MRF Refinement:" checkbox row, checked by default."""
        checkbox = CheckBox(checked=True)
        self.refine_checkbox = checkbox
        row = HWidgets("MRF Refinement:", checkbox, Spacing(35), stretch=0)
        self.add_row(row)

    def _add_confidence_choice(self):
        """Add a "Confidence Map as Feature:" checkbox row, unchecked by default."""
        checkbox = CheckBox(checked=False)
        self.confidence_checkbox = checkbox
        row = HWidgets("Confidence Map as Feature:", checkbox, Spacing(35), stretch=0)
        self.add_row(row)

    def _add_objects_source(self):
        """Add an "Objects:" row with a filled object selector (max 250 px wide)."""
        combo = ObjectComboBox(full=True)
        self.objects_source = combo
        combo.fill()
        combo.setMaximumWidth(250)
        self.add_row(HWidgets("Objects:", combo, Spacing(35), stretch=1))

    def _add_classifier_choice(self):
        """Add the classifier selector row plus a container for its options.

        The container starts out holding ``self.ensembles``, which must
        already exist. ``_on_classifier_changed`` swaps the shown options
        panel when the selection changes.
        """
        self.classifier_type = ComboBox()
        self.classifier_type.addItem(key="Ensemble")
        self.classifier_type.addItem(key="SVM")
        widget = HWidgets("Classifier:",
                          self.classifier_type,
                          Spacing(35),
                          stretch=0)

        self.classifier_type.currentIndexChanged.connect(
            self._on_classifier_changed)

        self.clf_container = QtWidgets.QWidget()
        # Parent the layout to the container it is installed on. This assumes
        # VBox forwards its first argument to QVBoxLayout, as its other uses in
        # this file suggest. With the card (``self``) as parent, Qt's
        # setLayout rejects a layout that already has another widget parent,
        # which can leave ``clf_container.layout()`` as None.
        clf_vbox = VBox(self.clf_container, spacing=4)
        clf_vbox.setContentsMargins(0, 0, 0, 0)
        self.clf_container.setLayout(clf_vbox)

        self.add_row(widget)
        self.add_row(self.clf_container, max_height=500)
        self.clf_container.layout().addWidget(self.ensembles)

    def _on_classifier_changed(self, idx):
        """Show the options panel for the chosen classifier.

        Index 0 shows the ensemble options and index 1 the SVM options; the
        other panel is detached. Any other index is ignored.
        """
        if idx == 0:
            shown, hidden = self.ensembles, self.svm
        elif idx == 1:
            shown, hidden = self.svm, self.ensembles
        else:
            return
        self.clf_container.layout().addWidget(shown)
        hidden.setParent(None)

    def _add_projection_choice(self):
        """Add a "Projection:" row listing the available projection methods."""
        self.projection_type = ComboBox()
        for projection_key in ("None", "pca", "rbp", "rproj", "std"):
            self.projection_type.addItem(key=projection_key)
        row = HWidgets("Projection:", self.projection_type, Spacing(35), stretch=0)
        self.add_row(row)

    def _add_feature_source(self):
        """Add a "Feature:" row with a filled single-feature selector (max 250 px)."""
        combo = FeatureComboBox()
        self.feature_source = combo
        combo.fill()
        combo.setMaximumWidth(250)
        self.add_row(HWidgets("Feature:", combo, Spacing(35), stretch=1))

    def _add_features_source(self):
        """Add a "Features:" row with a multi-feature selector.

        The selector is also published as ``cfg.pipelines_features_source``.
        """
        combo = MultiSourceComboBox()
        self.features_source = combo
        combo.fill()
        combo.setMaximumWidth(250)
        cfg.pipelines_features_source = combo
        self.add_row(HWidgets("Features:", combo, Spacing(35), stretch=1))

    def _add_constrain_source(self):
        """Add a "Constrain mask:" row with an annotation selector.

        The selector offers a leading "None" entry. Reads
        ``self.annotations_source``, so ``_add_annotations_source`` must run
        first.
        """
        # Diagnostic output goes through the logger rather than stdout.
        logger.debug(
            f"Annotation source for constrain mask: {self.annotations_source.value()}")
        self.constrain_mask_source = AnnotationComboBox(header=(None, "None"),
                                                        full=True)
        self.constrain_mask_source.fill()
        self.constrain_mask_source.setMaximumWidth(250)

        widget = HWidgets("Constrain mask:",
                          self.constrain_mask_source,
                          Spacing(35),
                          stretch=1)
        self.add_row(widget)

    def _add_annotations_source(self, label="Annotation"):
        """Add a row, titled ``label``, with a filled annotation-level selector."""
        combo = LevelComboBox(full=True)
        self.annotations_source = combo
        combo.fill()
        combo.setMaximumWidth(250)
        self.add_row(HWidgets(label, combo, Spacing(35), stretch=1))

    def _add_annotations_source2(self, label="Annotation 2"):
        """Add a second annotation-level selector row, titled ``label``."""
        combo = LevelComboBox(full=True)
        self.annotations_source2 = combo
        combo.fill()
        combo.setMaximumWidth(250)
        self.add_row(HWidgets(label, combo, Spacing(35), stretch=1))

    def _add_pipelines_source(self):
        """Add a "Segmentation:" row with a filled pipeline-output selector."""
        combo = PipelinesComboBox()
        self.pipelines_source = combo
        combo.fill()
        combo.setMaximumWidth(250)
        self.add_row(HWidgets("Segmentation:", combo, Spacing(35), stretch=1))

    def _add_regions_source(self):
        """Add a "Superregions:" row with a filled region selector.

        The selector is also published as ``cfg.pipelines_regions_source``.
        """
        combo = RegionComboBox(full=True)
        self.regions_source = combo
        combo.fill()
        combo.setMaximumWidth(250)
        row = HWidgets("Superregions:", combo, Spacing(35), stretch=1)
        cfg.pipelines_regions_source = combo
        self.add_row(row)

    def _add_param(self, name, title=None, type="String", default=None):
        """Add a labelled editor row for one pipeline parameter.

        Args:
            name: parameter name; the editor is stored in ``self.widgets[name]``.
            title: row label. Defaults to ``name``. It is always replaced by
                "MRF Refinement Amount:" for ``FloatSlider`` parameters.
            type: editor kind: "Int", "Float", "FloatSlider", "FloatOrVector",
                "IntOrVector" or "SmartBoolean". Any other kind adds no row.
            default: initial value for "Int", "Float" and "FloatSlider" editors.
                When None these editors start at 0 / 0.0. Vector and boolean
                editors keep their fixed initial values.
        """
        if type == "Int":
            p = LineEdit(default=0 if default is None else default, parse=int)
        elif type == "FloatSlider":
            p = RealSlider(value=0.0 if default is None else default,
                           vmax=1,
                           vmin=0)
            title = "MRF Refinement Amount:"
        elif type == "Float":
            p = LineEdit(default=0.0 if default is None else default,
                         parse=float)
        elif type == "FloatOrVector":
            # NOTE(review): default not applied; confirm LineEdit3D accepts the
            # server's default (scalar or 3-vector) before wiring it in.
            p = LineEdit3D(default=0, parse=float)
        elif type == "IntOrVector":
            p = LineEdit3D(default=0, parse=int)
        elif type == "SmartBoolean":
            p = CheckBox(checked=True)
        else:
            p = None

        if title is None:
            title = name

        if p is not None:
            self.widgets[name] = p
            self.add_row(HWidgets(None, title, p, Spacing(35)))

    def _add_compute_btn(self):
        """Add a "Compute" button row that triggers ``compute_pipeline``."""
        button = PushButton("Compute", accent=True)
        button.clicked.connect(self.compute_pipeline)
        row = HWidgets(None, button, Spacing(35))
        self.add_row(row)

    def update_params(self, params):
        """Populate the card's widgets from a stored parameter mapping.

        Values whose key is in ``self.widgets`` are written to that editor
        directly. The source keys ``anno_id``, ``object_id``, ``feature_id``,
        ``feature_ids``, ``region_id`` and ``constrain_mask`` select the
        matching entry in the corresponding combo box.

        Args:
            params: parameter name -> stored value.
        """
        logger.debug(f"Pipeline update params {params}")
        for k, v in params.items():
            if k in self.widgets:
                self.widgets[k].setValue(v)

        if params.get("anno_id") is not None:
            self.annotations_source.select(
                os.path.join("annotations/", params["anno_id"]))
        if params.get("object_id") is not None:
            self.objects_source.select(
                os.path.join("objects/", params["object_id"]))
        if "feature_id" in params:
            feature_ids = params["feature_id"]
            # A single id string must be selected as a whole; iterating it
            # would select one bogus "features/<char>" entry per character.
            if feature_ids is None:
                feature_ids = ()
            elif isinstance(feature_ids, str):
                feature_ids = (feature_ids,)
            for source in feature_ids:
                self.feature_source.select(os.path.join("features/", source))
        if "feature_ids" in params:
            for source in params["feature_ids"]:
                self.features_source.select(os.path.join("features/", source))
        if params.get("region_id") is not None:
            self.regions_source.select(
                os.path.join("regions/", params["region_id"]))

        constrain_mask = params.get("constrain_mask")
        if constrain_mask is not None and constrain_mask != "None":
            import ast

            # Appears to be stored as the repr of a dict with "level" and
            # "idx" keys; accept an already-parsed dict as well.
            if isinstance(constrain_mask, str):
                constrain_mask_dict = ast.literal_eval(constrain_mask)
            else:
                constrain_mask_dict = constrain_mask
            logger.debug(f"Constrain mask dict {constrain_mask_dict}")

            constrain_mask_source = (constrain_mask_dict["level"] + ":" +
                                     str(constrain_mask_dict["idx"]))
            logger.debug(f"Constrain mask source {constrain_mask_source}")
            self.constrain_mask_source.select(constrain_mask_source)

    def card_deleted(self):
        """Remove this pipeline from the workspace and drop its viewer layer.

        On a successful removal the card detaches itself from its parent and
        notifies listeners. The ``remove_layer`` client event is emitted in all
        cases, including when the launcher returns no result; previously that
        case raised a TypeError on ``result["done"]``.
        """
        result = Launcher.g.run("pipelines", "remove",
                                pipeline_id=self.pipeline_id, workspace=True)
        if result and result["done"]:
            self.setParent(None)
            _PipelineNotifier.notify()

        cfg.ppw.clientEvent.emit({
            "source": "pipelines",
            "data": "remove_layer",
            "layer_name": self.pipeline_id,
        })

    def view_pipeline(self):
        """Request that the viewer display this pipeline's output.

        When an annotation source widget is present, the level id is taken from
        its current value (last path component), or '001_level' if nothing is
        selected. That level id is sent with a ``view_pipeline`` client event.
        """
        logger.debug(f"View pipeline_id {self.pipeline_id}")
        with progress(total=2) as pbar:
            pbar.set_description("Viewing feature")
            pbar.update(1)

            if self.annotations_source:
                selected = self.annotations_source.value()
                level_id = (str(selected.rsplit("/", 1)[-1])
                            if selected else '001_level')
                logger.debug(f"Assigning annotation level {level_id}")

                event = {
                    "source": "pipelines",
                    "data": "view_pipeline",
                    "pipeline_id": self.pipeline_id,
                    "level_id": level_id,
                }
                cfg.ppw.clientEvent.emit(event)

            pbar.update(1)

    def get_model_path(self):
        """Ask the user for a model ``.zip`` file, starting in the current workspace.

        The returned path is stored on ``self.model_path`` and shown in
        ``self.model_file_line_edit``.
        """
        start_dir = os.path.join(DataModel.g.CHROOT,
                                 DataModel.g.current_workspace)
        chosen_path, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
            self, "Select model", start_dir, "Model files (*.zip)")
        self.model_path = chosen_path
        self.model_file_line_edit.setValue(self.model_path)

    def load_as_float(self):
        """Copy this pipeline's output into a new float32 "raw" feature.

        The pipeline output is read with dtype uint32. A new feature is created
        through the launcher, the data is written into it as float32, and the
        workspace GUI is asked to refresh. If feature creation returns no
        result, nothing is written.
        """
        logger.debug(f"Loading prediction {self.pipeline_id} as float image.")

        pipeline_uri = DataModel.g.dataset_uri(self.pipeline_id,
                                               group="pipelines")
        with DatasetManager(pipeline_uri, out=None, dtype="uint32",
                            fillvalue=0) as DM:
            output = DM.sources[0][:]

        created = Launcher.g.run("features", "create",
                                 feature_type="raw", workspace=True)
        if not created:
            return

        fid, ftype, fname = created["id"], created["kind"], created["name"]
        logger.debug(
            f"Created new object in workspace {fid}, {ftype}, {fname}")

        feature_uri = DataModel.g.dataset_uri(fid, group="features")
        with DatasetManager(feature_uri, out=feature_uri, dtype="float32",
                            fillvalue=0) as DM:
            DM.out[:] = output

        cfg.ppw.clientEvent.emit({
            "source": "workspace_gui",
            "data": "refresh",
            "value": None
        })

    def load_as_annotation(self):
        """Copy this pipeline's output into a new annotation level.

        Steps:
          1. Read the pipeline output (dtype uint32) and collect its unique values.
          2. Create a new annotation level and add one label per unique value,
             named after the value, with colour "#11FF11".
          3. Best effort: for label indices present in both levels, copy the
             label colours from the level currently selected in
             ``self.annotations_source``. Any failure here, including no
             selected level, is logged at debug level and the default colours
             are kept.
          4. Write the pipeline output into the new level and ask the workspace
             GUI to refresh.

        Steps 2-4 are skipped if creating the level returns no result.
        """
        logger.debug(f"Loading prediction {self.pipeline_id} as annotation.")

        # get pipeline output
        src = DataModel.g.dataset_uri(self.pipeline_id, group="pipelines")
        with DatasetManager(src, out=None, dtype="uint32", fillvalue=0) as DM:
            src_arr = DM.sources[0][:]
        label_values = np.unique(src_arr)

        # create new level
        result = Launcher.g.run("annotations", "add_level", workspace=True)
        if not result:
            return
        level_id = result["id"]

        # create a blank label for each unique value in the pipeline output array
        for v in label_values:
            params = dict(
                level=level_id,
                idx=int(v),
                name=str(v),
                color="#11FF11",
                workspace=True,
            )
            Launcher.g.run("annotations", "add_label", **params)

        try:
            # Look up both levels inside the try: colour copying is optional
            # and must not abort the data copy below.
            params = dict(
                level=str(self.annotations_source.value().rsplit("/", 1)[-1]),
                workspace=True,
            )
            anno_result = Launcher.g.run("annotations", "get_levels",
                                         **params)[0]

            params = dict(level=str(level_id), workspace=True)
            level_result = Launcher.g.run("annotations", "get_levels",
                                          **params)[0]

            # set the new level color mapping to the mapping from the pipeline
            for v in level_result["labels"].keys():
                if v in anno_result["labels"]:
                    label = dict(
                        idx=int(v),
                        name=str(v),
                        color=anno_result["labels"][v]["color"],
                    )
                    params = dict(level=result["id"], workspace=True)
                    Launcher.g.run("annotations", "update_label", **params,
                                   **label)
        except Exception as err:
            logger.debug(f"Exception {err}")

        fid = result["id"]
        ftype = result["kind"]
        fname = result["name"]
        logger.debug(
            f"Created new object in workspace {fid}, {ftype}, {fname}")

        # set levels array to pipeline output array
        dst = DataModel.g.dataset_uri(fid, group="annotations")
        with DatasetManager(dst, out=dst, dtype="uint32", fillvalue=0) as DM:
            DM.out[:] = src_arr

        cfg.ppw.clientEvent.emit({
            "source": "workspace_gui",
            "data": "refresh",
            "value": None
        })

    def setup_params_superregion_segment(self, dst):
        """Collect launcher parameters for the ``superregion_segment`` pipeline.

        Args:
            dst: Dataset URI the pipeline output is written to.

        Returns:
            dict with the following entries:
              - the source annotation level as ``src`` and ``anno_id``;
              - the superregions id and the feature ids;
              - the constrain mask, or the string "None" when unset;
              - the ``refine``/``lam`` widget values;
              - the classifier/projection type and the confidence flag;
              - ``classifier_params`` from the ensemble or SVM panel.
        """
        feature_names_list = [
            n.rsplit("/", 1)[-1] for n in self.features_source.value()
        ]
        anno_id = str(self.annotations_source.value().rsplit("/", 1)[-1])
        src = DataModel.g.dataset_uri(anno_id, group="annotations")

        all_params = dict(src=src, modal=True)
        all_params["workspace"] = DataModel.g.current_workspace

        logger.info(f"Setting src to {self.annotations_source.value()} ")
        all_params["region_id"] = str(self.regions_source.value().rsplit(
            "/", 1)[-1])
        all_params["feature_ids"] = feature_names_list
        all_params["anno_id"] = anno_id

        constrain_mask = self.constrain_mask_source.value()
        # The server side expects the literal string "None" when no mask is set.
        all_params["constrain_mask"] = (constrain_mask
                                        if constrain_mask is not None else
                                        "None")

        all_params["dst"] = dst
        all_params["refine"] = self.widgets["refine"].value()
        all_params["lam"] = self.widgets["lam"].value()
        all_params["classifier_type"] = self.classifier_type.value()
        all_params["projection_type"] = self.projection_type.value()
        all_params["confidence"] = self.confidence_checkbox.value()

        if self.classifier_type.value() == "Ensemble":
            all_params["classifier_params"] = self.ensembles.get_params()
        else:
            all_params["classifier_params"] = self.svm.get_params()
        return all_params

    def setup_params_rasterize_points(self, dst):
        """Collect launcher parameters for the ``rasterize_points`` pipeline.

        Sends the selected feature (as both a dataset URI and an id), the
        object id, the ``acwe`` widget value and the destination. The
        annotation id and the object scale/offset are currently not sent.
        """
        feature_id = self.feature_source.value()
        return dict(
            src=DataModel.g.dataset_uri(feature_id, group="features"),
            modal=True,
            workspace=DataModel.g.current_workspace,
            feature_id=feature_id,
            object_id=str(self.objects_source.value()),
            acwe=self.widgets["acwe"].value(),
            dst=dst,
        )

    def setup_params_watershed(self, dst):
        """Collect launcher parameters for the ``watershed`` pipeline.

        Note that ``dst`` is not sent: the ``dst`` entry is this card's
        ``pipeline_id``.
        """
        feature_uri = DataModel.g.dataset_uri(self.feature_source.value(),
                                              group="features")
        return {
            "src": feature_uri,
            "dst": self.pipeline_id,
            "modal": True,
            "workspace": DataModel.g.current_workspace,
            "anno_id": str(self.annotations_source.value().rsplit("/", 1)[-1]),
        }

    def setup_params_predict_segmentation_fcn(self, dst):
        """Collect launcher parameters for ``predict_segmentation_fcn``.

        Includes the feature, annotation level, model file name and model type.
        As with watershed, the ``dst`` entry is this card's ``pipeline_id``
        rather than the ``dst`` argument.
        """
        feature_id = self.feature_source.value()
        feature_uri = DataModel.g.dataset_uri(feature_id, group="features")
        return {
            "src": feature_uri,
            "dst": self.pipeline_id,
            "modal": True,
            "workspace": DataModel.g.current_workspace,
            "anno_id": str(self.annotations_source.value().rsplit("/", 1)[-1]),
            "feature_id": self.feature_source.value(),
            "model_fullname": self.model_fullname,
            "model_type": self.model_type.value(),
        }

    def setup_params_label_postprocess(self, dst):
        """Collect launcher parameters for the ``label_postprocess`` pipeline.

        ``level_over`` is the level selected in ``annotations_source``, or the
        string "None" when nothing is selected. ``level_base`` comes from
        ``annotations_source2``. ``selected_label`` and ``offset`` are read
        from their widgets as ints.
        """
        all_params = dict(modal=True)
        all_params["workspace"] = DataModel.g.current_workspace

        overlay = self.annotations_source.value()
        logger.debug(f"Label postprocess overlay level: {overlay}")

        if overlay:
            all_params["level_over"] = str(overlay.rsplit("/", 1)[-1])
        else:
            all_params["level_over"] = "None"
        all_params["level_base"] = str(self.annotations_source2.value().rsplit(
            "/", 1)[-1])
        all_params["dst"] = dst

        all_params["selected_label"] = int(
            self.widgets["selected_label"].value())
        all_params["offset"] = int(self.widgets["offset"].value())
        return all_params

    def setup_params_cleaning(self, dst):
        """Collect launcher parameters for the ``cleaning`` pipeline (feature id only)."""
        return {
            "dst": dst,
            "modal": True,
            "workspace": DataModel.g.current_workspace,
            "feature_id": str(self.feature_source.value()),
        }

    def setup_params_train_2d_unet(self, dst):
        """Collect launcher parameters for ``train_2d_unet``.

        Training cycle counts come from the ``cycles_frozen`` and
        ``cycles_unfrozen`` widgets and are nested under ``unet_train_params``.
        """
        feature_id = self.feature_source.value()
        training = dict(cyc_frozen=self.cycles_frozen.value(),
                        cyc_unfrozen=self.cycles_unfrozen.value())
        return dict(
            src=DataModel.g.dataset_uri(feature_id, group="features"),
            dst=dst,
            modal=True,
            workspace=DataModel.g.current_workspace,
            feature_id=str(self.feature_source.value()),
            anno_id=str(self.annotations_source.value().rsplit("/", 1)[-1]),
            unet_train_params=training,
        )

    def setup_params_predict_2d_unet(self, dst):
        """Collect launcher parameters for ``predict_2d_unet``.

        Adds the model path from ``model_file_line_edit`` and the id of the
        checked button in ``radio_group`` as ``no_of_planes``.
        """
        feature_uri = DataModel.g.dataset_uri(self.feature_source.value(),
                                              group="features")
        return dict(
            src=feature_uri,
            dst=dst,
            modal=True,
            workspace=DataModel.g.current_workspace,
            feature_id=str(self.feature_source.value()),
            anno_id=str(self.annotations_source.value().rsplit("/", 1)[-1]),
            model_path=str(self.model_file_line_edit.value()),
            no_of_planes=self.radio_group.checkedId(),
        )

    def compute_pipeline(self):
        """Build the parameters for ``self.pipeline_type`` and run that pipeline.

        The matching ``setup_params_*`` method builds the base parameters for
        this pipeline's output dataset. The current value of every widget in
        ``self.widgets`` is then merged in, overwriting same-named keys, and
        the pipeline is launched via ``Launcher.g.run``.

        Unknown pipeline types are logged and skipped. Errors while building
        parameters or running the pipeline are logged instead of raised, so the
        GUI keeps running.
        """
        dst = DataModel.g.dataset_uri(self.pipeline_id, group="pipelines")
        param_builders = {
            "superregion_segment": self.setup_params_superregion_segment,
            "rasterize_points": self.setup_params_rasterize_points,
            "watershed": self.setup_params_watershed,
            "predict_segmentation_fcn":
            self.setup_params_predict_segmentation_fcn,
            "label_postprocess": self.setup_params_label_postprocess,
            "cleaning": self.setup_params_cleaning,
            "train_2d_unet": self.setup_params_train_2d_unet,
            "predict_2d_unet": self.setup_params_predict_2d_unet,
        }

        with progress(total=3) as pbar:
            pbar.set_description("Calculating pipeline")
            pbar.update(1)

            build_params = param_builders.get(self.pipeline_type)
            if build_params is None:
                # Without a builder there are no parameters to send.
                logger.warning(
                    f"No action exists for pipeline: {self.pipeline_type}")
                return

            try:
                all_params = build_params(dst)
                all_params.update(
                    {k: v.value()
                     for k, v in self.widgets.items()})
            except Exception:
                logger.exception(
                    f"Failed to set up parameters for pipeline {self.pipeline_type}")
                return

            logger.info(
                f"Computing pipelines {self.pipeline_type} {all_params}")
            pbar.update(1)

            result = None
            try:
                result = Launcher.g.run("pipelines", self.pipeline_type,
                                        **all_params)
                logger.debug(f"Pipeline {self.pipeline_type} result: {result}")
            except Exception:
                logger.exception(f"Pipeline {self.pipeline_type} failed")
            if result is not None:
                pbar.update(1)

    def card_title_edited(self, newtitle):
        """Rename this pipeline on the server after its card title was edited.

        Args:
            newtitle: The new pipeline name.

        Returns:
            The ``done`` flag from the server response, or False when the
            launcher returned no result.
        """
        result = Launcher.g.run("pipelines", "rename",
                                pipeline_id=self.pipeline_id,
                                new_name=newtitle,
                                workspace=True)
        if not result:
            return False

        done = result["done"]
        if done:
            _PipelineNotifier.notify()
        return done
コード例 #15
0
class EnsembleWidget(QtWidgets.QWidget):
    """Editor for ensemble-classifier settings (forests and boosting).

    ``get_params`` returns a dict with the selected ensemble type code and its
    hyper-parameters. ``on_train_predict_clicked`` emits the same dict through
    the ``train_predict`` signal.
    """

    train_predict = Signal(dict)

    # Type codes in the order of the ensemble entries in ``type_combo``.
    # "xgb" has no combo entry while the XGBoost item is commented out.
    _TYPE_CODES = ("rf", "erf", "ada", "gbf", "xgb")
    # Positions in _TYPE_CODES that get their own default hyper-parameters.
    _ADABOOST = 2
    _GRADIENT_BOOSTING = 3

    def __init__(self, parent=None):
        super(EnsembleWidget, self).__init__(parent=parent)

        vbox = QtWidgets.QVBoxLayout()
        vbox.setContentsMargins(0, 0, 0, 0)
        self.setLayout(vbox)

        self.type_combo = ComboBox()
        self.type_combo.addCategory("Ensemble Type:")
        self.type_combo.addItem("Random Forest")
        self.type_combo.addItem("ExtraRandom Forest")
        self.type_combo.addItem("AdaBoost")
        self.type_combo.addItem("GradientBoosting")
        # XGBoost entry currently disabled.

        self.type_combo.currentIndexChanged.connect(self.on_ensemble_changed)
        vbox.addWidget(self.type_combo)

        self.ntrees = LineEdit(default=100, parse=int)
        self.depth = LineEdit(default=15, parse=int)
        self.lrate = LineEdit(default=1.0, parse=float)
        self.subsample = LineEdit(default=1.0, parse=float)

        vbox.addWidget(
            HWidgets(
                QtWidgets.QLabel("# Trees:"),
                self.ntrees,
                QtWidgets.QLabel("Max Depth:"),
                self.depth,
                stretch=[0, 1, 0, 1],
            ))

        vbox.addWidget(
            HWidgets(
                QtWidgets.QLabel("Learn Rate:"),
                self.lrate,
                QtWidgets.QLabel("Subsample:"),
                self.subsample,
                stretch=[0, 1, 0, 1],
            ))

        self.n_jobs = LineEdit(default=10, parse=int)
        vbox.addWidget(HWidgets("Num Jobs", self.n_jobs))

    @staticmethod
    def _type_position(combo_index):
        """Map a ``type_combo`` row index to a position in ``_TYPE_CODES``.

        Like ``get_params`` always did (by subtracting one), this assumes row 0
        is the "Ensemble Type:" category header, so ensemble entries start at
        row 1. Indices below 1 map to position 0 (Random Forest) instead of
        wrapping round to the last code.
        """
        return max(combo_index - 1, 0)

    def on_ensemble_changed(self, idx):
        """Reset hyper-parameter defaults for the ensemble selected at row ``idx``.

        AdaBoost gets 50 trees. Gradient boosting gets learning rate 0.1 and
        depth 3. All other choices use 100 trees, learning rate 1.0 and
        depth 15.
        """
        position = self._type_position(idx)
        self.ntrees.setDefault(50 if position == self._ADABOOST else 100)

        if position == self._GRADIENT_BOOSTING:
            self.lrate.setDefault(0.1)
            self.depth.setDefault(3)
        else:
            self.lrate.setDefault(1.0)
            self.depth.setDefault(15)

    def on_train_predict_clicked(self):
        """Emit the current parameters through ``train_predict``."""
        self.train_predict.emit(self.get_params())

    def get_params(self):
        """Return the selected ensemble type code and its hyper-parameters."""
        current_index = self._type_position(self.type_combo.currentIndex())
        logger.debug(f"Ensemble type_combo index: {current_index}")
        params = {
            "clf": "ensemble",
            "type": self._TYPE_CODES[current_index],
            "n_estimators": self.ntrees.value(),
            "max_depth": self.depth.value(),
            "learning_rate": self.lrate.value(),
            "subsample": self.subsample.value(),
            "n_jobs": self.n_jobs.value(),
        }
        return params
コード例 #16
0
class SVMWidget(QtWidgets.QWidget):
    """Editor for support-vector-machine classifier settings.

    Offers a kernel choice plus the penalty ``C`` and ``gamma`` values.
    ``get_params`` returns them as a dict, and ``on_predict_clicked`` emits the
    same values through the ``predict`` signal.
    """

    predict = Signal(dict)

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        layout = QtWidgets.QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(layout)

        self.type_combo = ComboBox()
        self.type_combo.addCategory("Kernel Type:")
        for kernel in ("linear", "poly", "rbf", "sigmoid"):
            self.type_combo.addItem(kernel)
        layout.addWidget(self.type_combo)

        self.penaltyc = LineEdit(default=1.0, parse=float)
        self.gamma = LineEdit(default=1.0, parse=float)

        values_row = HWidgets(
            QtWidgets.QLabel("Penalty C:"),
            self.penaltyc,
            QtWidgets.QLabel("Gamma:"),
            self.gamma,
            stretch=[0, 1, 0, 1],
        )
        layout.addWidget(values_row)

    def _base_params(self):
        """Return the kernel, C and gamma settings shared by both outputs."""
        return {
            "clf": "svm",
            "kernel": self.type_combo.currentText(),
            "C": self.penaltyc.value(),
            "gamma": self.gamma.value(),
        }

    def on_predict_clicked(self):
        """Emit the current SVM settings through ``predict``."""
        self.predict.emit(self._base_params())

    def get_params(self):
        """Return the SVM settings plus the combo's ``value()`` under "type"."""
        params = self._base_params()
        params["type"] = self.type_combo.value()
        return params
コード例 #17
0
ファイル: server.py プロジェクト: DiamondLightSource/SuRVoS2
class ServerPlugin(Plugin):
    __icon__ = "fa.qrcode"
    __pname__ = "server"
    __views__ = ["slice_viewer"]
    __tab__ = "server"

    def __init__(self, parent=None):
        """Build the server/workspace settings panel.

        The constructor:
          - stores default run and workspace configurations;
          - copies the global ``cfg`` into ``self.pipeline_config``;
          - builds the "Setup and Start Survos" tab (workspace root,
            workspace creation and run controls);
          - shows the widget.

        Args:
            parent: Optional parent widget, passed to the base class.
        """
        super().__init__(parent=parent)

        # Defaults for the server address and optional SSH connection.
        run_config = {
            "server_ip": "127.0.0.1",
            "server_port": "8134",
            "workspace_name": "test_hunt_d4b",
            "use_ssh": False,
            "ssh_host": "ws168.diamond.ac.uk",
            "ssh_port": "22",
        }

        # Defaults used to pre-fill the "Create New Workspace" fields.
        workspace_config = {
            "dataset_name": "data",
            "datasets_dir": "/path/to/my/data/dir",
            "vol_fname": "myfile.h5",
            "workspace_name": "my_survos_workspace",
            "downsample_by": "1",
        }

        from survos2.server.config import cfg

        pipeline_config = dict(cfg)

        self.run_config = run_config
        self.workspace_config = workspace_config
        self.pipeline_config = pipeline_config

        # Process handles; None here (presumably assigned when started elsewhere).
        self.server_process = None
        self.client_process = None

        # NOTE(review): assigning self.layout shadows QWidget.layout(); confirm
        # nothing relies on the method.
        self.layout = QVBoxLayout()
        tabwidget = QTabWidget()
        tab1 = QWidget()
        # NOTE(review): tab2 is created but never added to tabwidget.
        tab2 = QWidget()

        tabwidget.addTab(tab1, "Setup and Start Survos")
        # Created before get_workspace_fields(), which places it in its grid.
        self.create_workspace_button = QPushButton("Create workspace")

        tab1.layout = QVBoxLayout()
        tab1.setLayout(tab1.layout)
        chroot_fields = self.get_chroot_fields()
        tab1.layout.addWidget(chroot_fields)
        workspace_fields = self.get_workspace_fields()
        tab1.layout.addWidget(workspace_fields)

        # Must run before get_run_fields(), which adds self.adv_run_fields.
        self.setup_adv_run_fields()
        self.adv_run_fields.hide()

        run_fields = self.get_run_fields()
        tab1.layout.addWidget(run_fields)
        # NOTE(review): this button is connected below but never added to a
        # layout, so it is not visible.
        output_config_button = QPushButton("Save config")

        self.create_workspace_button.clicked.connect(
            self.create_workspace_clicked)
        output_config_button.clicked.connect(self.output_config_clicked)
        self.layout.addWidget(tabwidget)

        self.setGeometry(300, 300, 600, 400)
        self.setWindowTitle("SuRVoS Settings Editor")
        current_fpth = os.path.dirname(os.path.abspath(__file__))
        self.setWindowIcon(
            QIcon(os.path.join(current_fpth, "resources", "logo.png")))
        self.setLayout(self.layout)
        self.show()

    def get_chroot_fields(self):
        """Build the group box for choosing the workspaces root directory.

        Returns:
            QGroupBox: Contains a line edit pre-filled with ``CHROOT`` and a
            "Set Workspaces Root" button connected to ``set_chroot``.
        """
        group = QGroupBox("Set Main Directory for Storing Workspaces:")
        group.setMaximumHeight(130)
        grid = QGridLayout()

        self.given_chroot_linedt = QLineEdit(CHROOT)
        grid.addWidget(self.given_chroot_linedt, 1, 0, 1, 2)

        apply_button = QPushButton("Set Workspaces Root")
        grid.addWidget(apply_button, 1, 2)

        group.setLayout(grid)
        apply_button.clicked.connect(self.set_chroot)
        return group

    def get_workspace_fields(self):
        """Build the "Create New Workspace:" group box.

        The box contains:
          - the data file path, with a "Select" button that opens the data loader;
          - the internal HDF5 dataset path;
          - the workspace name;
          - the downsample factor;
          - a hidden ROI summary;
          - the create-workspace button.

        Returns:
            QGroupBox: GroupBox with workspace fields.
        """
        select_button = QPushButton("Select")
        group = QGroupBox("Create New Workspace:")
        grid = QGridLayout()
        config = self.workspace_config

        # Data file
        grid.addWidget(QLabel("Data File Path:"), 0, 0)
        data_path = Path(config["datasets_dir"], config["vol_fname"])
        self.data_filepth_linedt = QLineEdit(str(data_path))
        grid.addWidget(self.data_filepth_linedt, 1, 0, 1, 2)
        grid.addWidget(select_button, 1, 2)

        # Internal HDF5 path, always shown with a leading "/"
        grid.addWidget(QLabel("HDF5 Internal Data Path:"), 2, 0, 1, 1)
        dataset_name = config["dataset_name"]
        if str(dataset_name).startswith("/"):
            h5_path = dataset_name
        else:
            h5_path = "/" + dataset_name
        self.h5_intpth_linedt = QLineEdit(h5_path)
        grid.addWidget(self.h5_intpth_linedt, 2, 1, 1, 1)

        # Workspace name
        grid.addWidget(QLabel("Workspace Name:"), 3, 0)
        self.ws_name_linedt_1 = QLineEdit(config["workspace_name"])
        grid.addWidget(self.ws_name_linedt_1, 3, 1)

        # Downsampling (1 is displayed as "None")
        grid.addWidget(QLabel("Downsample Factor:"), 4, 0)
        spinner = QSpinBox()
        self.downsample_spinner = spinner
        spinner.setRange(1, 10)
        spinner.setSpecialValueText("None")
        spinner.setMaximumWidth(60)
        spinner.setValue(int(config["downsample_by"]))
        grid.addWidget(spinner, 4, 1, 1, 1)

        # ROI summary, hidden until an ROI is chosen in the data loader
        self.setup_roi_fields()
        grid.addWidget(self.roi_fields, 4, 2, 1, 2)
        self.roi_fields.hide()

        grid.addWidget(self.create_workspace_button, 5, 0, 1, 3)
        group.setLayout(grid)
        select_button.clicked.connect(self.launch_data_loader)
        return group

    def setup_roi_fields(self):
        """Create ``self.roi_fields``, a one-row summary of the selected ROI.

        For each axis in the order z, y, x the row shows
        ``<axis>: <start> - <end>``. The start/end labels are stored as
        ``self.<axis>start_roi_val`` / ``self.<axis>end_roi_val`` and begin
        at "0".
        """
        self.roi_fields = QGroupBox("ROI:")
        row = QHBoxLayout()

        self.zstart_roi_val, self.zend_roi_val = QLabel("0"), QLabel("0")
        self.ystart_roi_val, self.yend_roi_val = QLabel("0"), QLabel("0")
        self.xstart_roi_val, self.xend_roi_val = QLabel("0"), QLabel("0")

        axes = (
            ("z:", self.zstart_roi_val, self.zend_roi_val),
            ("y:", self.ystart_roi_val, self.yend_roi_val),
            ("x:", self.xstart_roi_val, self.xend_roi_val),
        )
        for axis_label, start, end in axes:
            # QHBoxLayout.addWidget's second argument is a stretch factor, not
            # a position. Widgets are laid out in insertion order, so none is
            # passed; the previous 0..11 values gave later labels ever more
            # stretch.
            for widget in (QLabel(axis_label), start, QLabel("-"), end):
                row.addWidget(widget)

        self.roi_fields.setLayout(row)

    def setup_adv_run_fields(self):
        """Create ``self.adv_run_fields`` (server address/port and an SSH toggle).

        Also builds ``self.adv_ssh_fields`` with host, username and port
        editors. That group is not added to ``adv_run_fields``'s layout.
        """
        run_cfg = self.run_config
        self.adv_run_fields = QGroupBox("Advanced Run Settings:")
        grid = QGridLayout()

        # Server address
        grid.addWidget(QLabel("Server IP Address:"), 0, 0)
        self.server_ip_linedt = QLineEdit(run_cfg["server_ip"])
        grid.addWidget(self.server_ip_linedt, 0, 1)
        grid.addWidget(QLabel("Server Port:"), 1, 0)
        self.server_port_linedt = QLineEdit(run_cfg["server_port"])
        grid.addWidget(self.server_port_linedt, 1, 1)

        # SSH toggle, pre-checked from the run configuration
        self.ssh_button = QRadioButton("Use SSH")
        self.ssh_button.setAutoExclusive(False)
        grid.addWidget(self.ssh_button, 0, 2)
        if run_cfg.get("use_ssh", False):
            self.ssh_button.setChecked(True)
        self.ssh_button.toggled.connect(self.toggle_ssh)

        # SSH connection details
        self.adv_ssh_fields = QGroupBox("SSH Settings:")
        ssh_grid = QGridLayout()
        ssh_grid.setColumnStretch(2, 2)

        host_label = QLabel("Host")
        self.ssh_host_linedt = QLineEdit(run_cfg.get("ssh_host", ""))
        ssh_grid.addWidget(host_label, 0, 0)
        ssh_grid.addWidget(self.ssh_host_linedt, 0, 1, 1, 2)

        user_label = QLabel("Username")
        self.ssh_username_linedt = QLineEdit(self.get_login_username())
        ssh_grid.addWidget(user_label, 1, 0)
        ssh_grid.addWidget(self.ssh_username_linedt, 1, 1, 1, 2)

        port_label = QLabel("Port")
        self.ssh_port_linedt = QLineEdit(run_cfg.get("ssh_port", ""))
        ssh_grid.addWidget(port_label, 2, 0)
        ssh_grid.addWidget(self.ssh_port_linedt, 2, 1, 1, 2)

        self.adv_ssh_fields.setLayout(ssh_grid)
        # NOTE: adv_ssh_fields is deliberately left out of this grid.

        self.adv_run_fields.setLayout(grid)

    def get_run_fields(self):
        """Build the "Run SuRVoS:" group box.

        The box contains:
          - a workspace selector filled from the entries of ``CHROOT``, which
            uses ``ws_name_linedt_2`` as its line edit;
          - an "Advanced" toggle for ``self.adv_run_fields``;
          - Start / Stop / Use Existing Server buttons.

        Returns:
            QGroupBox: GroupBox with run fields.
        """
        self.run_button = QPushButton("Start Server")
        self.stop_button = QPushButton("Stop Server")
        self.existing_button = QPushButton("Use Existing Server")

        advanced_toggle = QRadioButton("Advanced")
        group = QGroupBox("Run SuRVoS:")
        grid = QGridLayout()

        entries = os.listdir(CHROOT)
        self.workspaces_list = ComboBox()
        for entry in entries:
            self.workspaces_list.addItem(key=entry)

        grid.addWidget(QLabel("Workspace Name:"), 0, 0)
        name_edit = QLineEdit(self.workspace_config["workspace_name"])
        name_edit.setAlignment(Qt.AlignLeft)
        self.ws_name_linedt_2 = name_edit
        self.workspaces_list.setLineEdit(name_edit)

        grid.addWidget(self.workspaces_list, 0, 1)
        grid.addWidget(advanced_toggle, 1, 0)
        grid.addWidget(self.adv_run_fields, 2, 1)
        action_buttons = (self.run_button, self.stop_button,
                          self.existing_button)
        for row, button in enumerate(action_buttons, start=3):
            grid.addWidget(button, row, 0, 1, 3)
        group.setLayout(grid)

        advanced_toggle.toggled.connect(self.toggle_advanced)
        handlers = (self.run_clicked, self.stop_clicked, self.existing_clicked)
        for button, handler in zip(action_buttons, handlers):
            button.clicked.connect(handler)

        return group

    def get_login_username(self):
        """Return the current login name, or "" if it cannot be determined."""
        try:
            return getpass.getuser()
        except Exception:
            return ""

    def refresh_chroot(self):
        """Refill the workspace selector with the entries of ``DataModel.g.CHROOT``."""
        entries = os.listdir(DataModel.g.CHROOT)
        combo = self.workspaces_list
        combo.clear()
        for entry in entries:
            combo.addItem(key=entry)

    @pyqtSlot()
    def set_chroot(self):
        """Apply the root directory typed by the user and refresh the workspace list.

        The new root is written to the config under ``model.chroot`` and to
        ``DataModel.g.CHROOT``. The module-level ``CHROOT`` is not changed.
        """
        new_chroot = self.given_chroot_linedt.text()
        Config.update({"model": {"chroot": new_chroot}})
        logger.debug(f"Setting CHROOT to {new_chroot}")
        DataModel.g.CHROOT = new_chroot
        self.refresh_chroot()

    @pyqtSlot()
    def launch_data_loader(self):
        """Load the dialog box widget to select data with data preview window and ROI selection.

        If the dialog is accepted with both a data file path and an internal
        HDF5 path:
          - the paths and downsample factor are copied into the workspace fields;
          - if the ROI was changed, ``self.roi_limits`` holds the dialog's ROI
            limits (as strings) and the ROI summary is shown;
          - otherwise ``self.roi_limits`` is None and the summary is hidden.

        Cancelling, or accepting without both paths, leaves the current
        selection unchanged, including ``self.roi_limits`` and the ROI summary.
        Previously ``roi_limits`` was reset to None on cancel while a stale
        ROI summary stayed visible.
        """
        if not hasattr(self, "roi_limits"):
            # Make sure the attribute exists even if the first dialog is cancelled.
            self.roi_limits = None

        dialog = LoadDataDialog(self)
        if dialog.exec_() != QDialog.Accepted:
            return

        path = dialog.winput.path.text()
        int_h5_pth = dialog.int_h5_pth.text()
        if not (path and int_h5_pth):
            return
        down_factor = dialog.downsample_spinner.value()

        self.data_filepth_linedt.setText(path)
        self.h5_intpth_linedt.setText(int_h5_pth)
        self.downsample_spinner.setValue(down_factor)

        if dialog.roi_changed:
            self.roi_limits = tuple(map(str, dialog.get_roi_limits()))
            self.roi_fields.show()
            self.update_roi_fields_from_dialog()
        else:
            self.roi_limits = None
            self.roi_fields.hide()

    def update_roi_fields_from_dialog(self):
        """Copy ``self.roi_limits`` into the six ROI labels of the main window.

        ``roi_limits`` is unpacked as (x_start, x_end, y_start, y_end, z_start,
        z_end).
        """
        x_lo, x_hi, y_lo, y_hi, z_lo, z_hi = self.roi_limits
        for label, text in (
            (self.xstart_roi_val, x_lo),
            (self.xend_roi_val, x_hi),
            (self.ystart_roi_val, y_lo),
            (self.yend_roi_val, y_hi),
            (self.zstart_roi_val, z_lo),
            (self.zend_roi_val, z_hi),
        ):
            label.setText(text)

    @pyqtSlot()
    def toggle_advanced(self):
        """Show the advanced run fields while the sending radio button is checked, hide them otherwise."""
        # setVisible(True/False) is what show()/hide() call internally.
        self.adv_run_fields.setVisible(self.sender().isChecked())

    @pyqtSlot()
    def toggle_ssh(self):
        """Show the SSH settings while the sending radio button is checked, hide them otherwise."""
        # setVisible(True/False) is what show()/hide() call internally.
        self.adv_ssh_fields.setVisible(self.sender().isChecked())

    @pyqtSlot()
    def create_workspace_clicked(self):
        """Performs checks and coordinates workspace creation on button press.

        Copies the form values into ``self.workspace_config`` and calls
        ``init_ws`` with it. The values are the data file location, HDF5 dataset
        name, workspace name, downsample factor and, if one was selected, the ROI.
        The outcome is shown on ``self.create_workspace_button``. On success the
        new name is also copied into the 'Run' section. The workspace list is
        refreshed after every attempt that reaches ``init_ws``.
        """
        logger.debug("Creating workspace: ")
        # Set the path to the data file
        vol_path = Path(self.data_filepth_linedt.text())
        if not vol_path.is_file():
            err_str = f"No data file exists at {vol_path}!"
            logger.error(err_str)
            self.button_feedback_response(
                err_str, self.create_workspace_button, "maroon"
            )
            return

        self.workspace_config["datasets_dir"] = str(vol_path.parent)
        self.workspace_config["vol_fname"] = str(vol_path.name)
        # Stored without the leading/trailing "/" shown in the HDF5 path field.
        dataset_name = self.h5_intpth_linedt.text()
        self.workspace_config["dataset_name"] = str(dataset_name).strip("/")
        self.workspace_config["workspace_name"] = self.ws_name_linedt_1.text()
        self.workspace_config["downsample_by"] = self.downsample_spinner.value()
        # roi_limits is only assigned by launch_data_loader, so it may not exist
        # yet if the data path was typed in by hand.
        roi_limits = getattr(self, "roi_limits", None)
        if roi_limits:
            self.workspace_config["roi_limits"] = roi_limits
        try:
            # init_ws is unpacked as a (result, error) pair; error is falsy on success.
            _, error = init_ws(self.workspace_config)
        except WorkspaceException as e:
            logger.exception(e)
            self.button_feedback_response(
                str(e), self.create_workspace_button, "maroon"
            )
        else:
            if error:
                # Previously ignored: report a failed creation instead of staying silent.
                err_str = f"Workspace creation failed: {error}"
                logger.error(err_str)
                self.button_feedback_response(
                    err_str, self.create_workspace_button, "maroon"
                )
            else:
                self.button_feedback_response(
                    "Workspace created successfully",
                    self.create_workspace_button,
                    "green",
                )
                # Update the workspace name in the 'Run' section
                self.ws_name_linedt_2.setText(self.ws_name_linedt_1.text())
        self.refresh_chroot()

    def button_feedback_response(self, message, button, colour_str, timeout=2):
        """Changes button colour and displays feedback message for a limited time period.

        Args:
            message (str): Message to display in button.
            button (PyQt5.QtWidgets.QPushButton): The button to manipulate.
            colour_str (str): The standard CSS colour string or hex code describing the colour to change the button to.
            timeout (int | float): Seconds before ``reset_button`` restores the
                original text and colours. Defaults to 2.
        """
        # QTimer.singleShot needs an integer number of milliseconds; callers may
        # pass a float number of seconds.
        timeout_ms = int(timeout * 1000)
        msg_old = button.text()
        # Bound ``color`` methods, not QColor values: reset_button calls them.
        col_old = button.palette().button().color
        txt_col_old = button.palette().buttonText().color
        button.setText(message)
        button.setStyleSheet(f"background-color: {colour_str}; color: white")
        # singleShot is a static method, so no QTimer instance is needed.
        QTimer.singleShot(
            timeout_ms,
            lambda: self.reset_button(button, msg_old, col_old, txt_col_old),
        )

    @pyqtSlot()
    def reset_button(self, button, msg_old, col_old, txt_col_old):
        """Sets a button back to its original display settings.

        Args:
            button (PyQt5.QtWidgets.QPushButton): The button to manipulate.
            msg_old (str): Text to put back on the button.
            col_old (callable): Returns the original background QColor
                (``button_feedback_response`` passes the palette brush's bound
                ``color`` method).
            txt_col_old (callable): Returns the original text QColor.
        """
        # Each setStyleSheet call replaces the whole sheet, so both properties
        # must be set together or the background restore is lost.
        button.setStyleSheet(
            f"background-color: {col_old().name()}; color: {txt_col_old().name()}"
        )
        button.setText(msg_old)
        button.update()

    @pyqtSlot()
    def output_config_clicked(self):
        """Write ``self.pipeline_config`` as YAML to ``pipeline_cfg.yml``.

        The path is relative, so the file lands in the current working directory.
        Keys keep their insertion order and block style is used.
        """
        out_fname = "pipeline_cfg.yml"
        logger.debug(f"Outputting pipeline config: {out_fname}")
        dump_options = dict(default_flow_style=False, sort_keys=False)
        with open(out_fname, "w") as outfile:
            yaml.dump(self.pipeline_config, outfile, **dump_options)

    def get_ssh_params(self):
        """Read the SSH connection settings from the form.

        Returns:
            tuple: ``(host, username, port)``. ``port`` is an int. It is 0 when
            the port field is empty or not an integer, so callers can detect
            missing settings with ``all(params)`` instead of hitting a ValueError.
        """
        ssh_host = self.ssh_host_linedt.text()
        ssh_user = self.ssh_username_linedt.text()
        try:
            ssh_port = int(self.ssh_port_linedt.text())
        except ValueError:
            ssh_port = 0
        return ssh_host, ssh_user, ssh_port

    def start_server_over_ssh(self):
        """Start the SuRVoS server on a remote host via SSH in a worker thread.

        Reads host/user/port from the form and asks for the user's password. It
        then runs ``SSHWorker.start_server_over_ssh`` in a new QThread. When the
        worker emits ``finished``, ``start_client`` is called. Nothing is started
        if any SSH setting is missing or the password dialog is cancelled or left
        empty.
        """
        params = self.get_ssh_params()
        if not all(params):
            logger.error(
                "Not all SSH parameters given! Not connecting to SSH.")
            # Previously execution fell through and connected anyway.
            return
        ssh_host, ssh_user, _ = params
        # Pop up dialog to ask for password
        password, ok = QInputDialog.getText(None, "Login",
                                            f"Password for {ssh_user}@{ssh_host}",
                                            QLineEdit.Password)
        if not (ok and password):
            return
        self.ssh_worker = SSHWorker(params, password, self.run_config)
        self.ssh_thread = QThread(self)
        self.ssh_worker.moveToThread(self.ssh_thread)
        self.ssh_worker.button_message_signal.connect(
            self.send_msg_to_run_button)
        self.ssh_worker.error_signal.connect(self.on_ssh_error)
        self.ssh_worker.finished.connect(self.start_client)
        self.ssh_worker.update_ip_linedt_signal.connect(
            self.update_ip_linedt)
        self.ssh_thread.started.connect(
            self.ssh_worker.start_server_over_ssh)
        self.ssh_thread.start()

    def closeEvent(self, event):
        """Ask for confirmation before closing, and stop the local server if confirmed.

        The dialog tells the user the server will be stopped. So on "Yes", a
        still-running server subprocess started by ``run_clicked``
        (``self.server_process``) is killed before the event is accepted. A
        server started over SSH has no local process handle and is not affected.
        """
        reply = QMessageBox.question(
            self,
            "Quit",
            "Are you sure you want to quit? "
            "The server will be stopped.",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if reply != QMessageBox.Yes:
            event.ignore()
            return
        server_process = getattr(self, "server_process", None)
        if server_process is not None and server_process.poll() is None:
            logger.debug("Stopping server")
            server_process.kill()
        event.accept()

    @pyqtSlot()
    def on_ssh_error(self):
        """Record that the SSH worker reported an error.

        ``start_client`` checks this flag and skips launching the client when set.
        """
        self.ssh_error = True

    @pyqtSlot(str)
    def update_ip_linedt(self, ip):
        """Show the server IP sent by the SSH worker in the server IP field."""
        ip_field = self.server_ip_linedt
        ip_field.setText(ip)

    @pyqtSlot(list)
    def send_msg_to_run_button(self, param_list):
        """Display feedback from the SSH worker on the Run button.

        Args:
            param_list (list): ``[message, colour, timeout]``, forwarded to
                ``button_feedback_response``.
        """
        message, colour, timeout = param_list[0], param_list[1], param_list[2]
        self.button_feedback_response(message, self.run_button, colour, timeout)

    @pyqtSlot()
    def stop_clicked(self):
        """Kill the locally started server subprocess, if there is one.

        After ``kill()`` the process is waited on so it is reaped. The handle is
        then cleared so that a second click does nothing.
        """
        logger.debug("Stopping server")
        if self.server_process is None:
            return
        self.server_process.kill()
        try:
            self.server_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Keep the handle so the user can try again.
            logger.warning("Server process did not exit after kill()")
        else:
            self.server_process = None

    @pyqtSlot()
    def run_clicked(self):
        """Starts SuRVoS2 server and client as subprocesses when 'Run' button pressed.

        A 3-step progress bar tracks the following stages:
          1. Locate ``survos.py`` three directories above this file's directory
             and change the working directory there.
          2. Check that the workspace named in the run section exists under
             ``Config["model.chroot"]``. If not, report it on the Run button and
             stop.
          3. Either start the server over SSH (``start_server_over_ssh``), or:
             - launch ``survos.py start_server`` locally with the running Python
               interpreter;
             - set the client's remote to ``127.0.0.1:<port>``;
             - ask the GUI to load the workspace.

        Raises:
            Exception: If survos.py not found.
        """
        import sys

        # Every stage runs inside the progress context so that each
        # pbar.update() happens while the bar is still open.
        with progress(total=3) as pbar:
            pbar.set_description("Starting server...")
            pbar.update(1)

            # Flag which will be set to True if there is an SSH error
            self.ssh_error = False

            # Set current dir to survos root
            command_dir = os.path.abspath(os.path.dirname(__file__))
            command_dir = Path(
                command_dir).absolute().parent.parent.parent.resolve()
            os.chdir(command_dir)

            self.script_fullname = os.path.join(command_dir, "survos.py")
            if not os.path.isfile(self.script_fullname):
                raise Exception("{}: Script not found".format(
                    self.script_fullname))
            # Retrieve the parameters from the fields TODO: Put some error checking in
            self.run_config["workspace_name"] = self.ws_name_linedt_2.text()
            self.run_config["server_port"] = self.server_port_linedt.text()
            # Temporary measure to check whether the workspace exists or not
            full_ws_path = os.path.join(Config["model.chroot"],
                                        self.run_config["workspace_name"])
            if not os.path.isdir(full_ws_path):
                logger.error(
                    f"No workspace can be found at {full_ws_path}, Not starting SuRVoS."
                )
                self.button_feedback_response(
                    f"Workspace {self.run_config['workspace_name']} does not appear to exist!",
                    self.run_button,
                    "maroon",
                )
                return
            pbar.update(1)

            if self.ssh_button.isChecked():
                self.start_server_over_ssh()
            else:
                # Use the interpreter running this GUI rather than whatever
                # "python" resolves to on PATH, so the server sees the same environment.
                self.server_process = subprocess.Popen([
                    sys.executable,
                    self.script_fullname,
                    "start_server",
                    self.run_config["workspace_name"],
                    self.run_config["server_port"],
                    DataModel.g.CHROOT,
                ])
                try:
                    # Blocks for up to 10 s (presumably to let the server come up);
                    # returns earlier only if the server process exits.
                    outs, errs = self.server_process.communicate(timeout=10)
                    print(f"OUTS: {outs, errs}")
                except subprocess.TimeoutExpired:
                    pass

                logger.info(f"setting remote: {self.server_port_linedt.text()}")
                remote_ip_port = "127.0.0.1:" + self.server_port_linedt.text()
                logger.info(f"setting remote: {remote_ip_port}")
                resp = Launcher.g.set_remote(remote_ip_port)
                logger.info(f"Response from server to setting remote: {resp}")

                cfg.ppw.clientEvent.emit({
                    "source": "server_tab",
                    "data": "set_workspace",
                    "workspace": self.ws_name_linedt_2.text(),
                })
                cfg.ppw.clientEvent.emit({
                    "source": "panel_gui",
                    "data": "refresh",
                    "value": None
                })
            pbar.update(1)

    @pyqtSlot()
    def existing_clicked(self):
        """Connect the client to an already running server and load the workspace.

        The remote address is built from the server IP and port fields. The GUI
        is then asked to switch to the workspace named in the run section and to
        refresh its panels.
        """
        host = self.server_ip_linedt.text()
        port = self.server_port_linedt.text()
        remote_ip_port = f"{host}:{port}"
        logger.info(f"setting remote: {remote_ip_port}")
        resp = Launcher.g.set_remote(remote_ip_port)
        logger.info(f"Response from server to setting remote: {resp}")

        events = (
            {
                "source": "server_tab",
                "data": "set_workspace",
                "workspace": self.ws_name_linedt_2.text(),
            },
            {"source": "panel_gui", "data": "refresh", "value": None},
        )
        for event in events:
            cfg.ppw.clientEvent.emit(event)

    def start_client(self):
        """Launch the SuRVoS2 client GUI against the configured server.

        Does nothing if an SSH error was reported (``self.ssh_error``).
        Otherwise it copies the server IP from the form into ``run_config`` and
        starts ``survos.py nu_gui <workspace> <ip>:<port>`` with the Python
        interpreter running this GUI.
        """
        import sys

        if self.ssh_error:
            return
        self.button_feedback_response("Starting Client.", self.run_button,
                                      "green", 7)
        self.run_config["server_ip"] = self.server_ip_linedt.text()
        server_address = (
            f"{self.run_config['server_ip']}:{self.run_config['server_port']}"
        )
        # sys.executable instead of "python": same interpreter/environment as the GUI.
        self.client_process = subprocess.Popen([
            sys.executable,
            self.script_fullname,
            "nu_gui",
            self.run_config["workspace_name"],
            server_address,
        ])
コード例 #18
0
ファイル: runner.py プロジェクト: DiamondLightSource/SuRVoS2
class FrontEndRunner(QWidget):
    """Main FrontEnd Runner window for creating workspace and starting SuRVoS2."""

    def __init__(self, run_config, workspace_config, pipeline_config, *args, **kwargs):
        """Build the runner window (setup tab and pipeline tab) and show it.

        Args:
            run_config (dict): Settings used to start the server and client.
            workspace_config (dict): Settings used to create a workspace.
            pipeline_config (dict): Pipeline settings edited in the "Pipeline"
                tab and written out by "Save config".
            *args, **kwargs: Accepted but currently unused.
        """
        super().__init__()

        self.run_config = run_config
        self.workspace_config = workspace_config
        self.pipeline_config = pipeline_config

        self.server_process = None
        self.client_process = None
        # Set by launch_data_loader when an ROI is chosen. It is read by
        # create_workspace_clicked, so it must exist before the dialog is used.
        self.roi_limits = None

        pipeline_config_ptree = self.init_ptree(self.pipeline_config, name="Pipeline")

        self.layout = QVBoxLayout()
        tabwidget = QTabWidget()
        tab1 = QWidget()
        tab2 = QWidget()

        tabwidget.addTab(tab1, "Setup and Start Survos")
        tabwidget.addTab(tab2, "Pipeline")

        self.create_workspace_button = QPushButton("Create workspace")

        tab1.layout = QVBoxLayout()
        tab1.setLayout(tab1.layout)

        chroot_fields = self.get_chroot_fields()
        tab1.layout.addWidget(chroot_fields)

        self.setup_adv_run_fields()
        self.adv_run_fields.hide()

        run_fields = self.get_run_fields()
        tab1.layout.addWidget(run_fields)

        output_config_button = QPushButton("Save config")
        tab2.layout = QVBoxLayout()
        tab2.setLayout(tab2.layout)
        tab2.layout.addWidget(pipeline_config_ptree)
        tab2.layout.addWidget(output_config_button)

        self.create_workspace_button.clicked.connect(self.create_workspace_clicked)
        output_config_button.clicked.connect(self.output_config_clicked)
        self.layout.addWidget(tabwidget)

        self.setGeometry(300, 300, 600, 400)
        self.setWindowTitle("SuRVoS Settings Editor")
        current_fpth = os.path.dirname(os.path.abspath(__file__))
        self.setWindowIcon(QIcon(os.path.join(current_fpth, "resources", "logo.png")))
        self.setLayout(self.layout)
        self.show()

    def get_workspace_fields(self):
        """Gets the QGroupBox that contains all the fields for setting up the workspace.

        The fields are data file path (with a "Select" button opening the
        loader dialog), HDF5 internal path, workspace name, downsample factor,
        a (hidden) ROI display and the "Create workspace" button. All are
        pre-filled from ``self.workspace_config``.

        Returns:
            PyQt5.QWidgets.GroupBox: GroupBox with workspace fields.
        """

        select_data_button = QPushButton("Select")
        workspace_fields = QGroupBox("Create Workspace:")
        wf_layout = QGridLayout()
        wf_layout.addWidget(QLabel("Data File Path:"), 0, 0)
        current_data_path = Path(
            self.workspace_config["datasets_dir"], self.workspace_config["vol_fname"]
        )
        self.data_filepth_linedt = QLineEdit(str(current_data_path))
        wf_layout.addWidget(self.data_filepth_linedt, 1, 0, 1, 2)
        wf_layout.addWidget(select_data_button, 1, 2)
        wf_layout.addWidget(QLabel("HDF5 Internal Data Path:"), 2, 0, 1, 1)
        # Convert once so both the check and the concatenation work on a str,
        # and the line edit always receives a str.
        ws_dataset_name = str(self.workspace_config["dataset_name"])
        # The field shows the dataset as an absolute HDF5 path ("/name").
        internal_h5_path = (
            ws_dataset_name
            if ws_dataset_name.startswith("/")
            else "/" + ws_dataset_name
        )
        self.h5_intpth_linedt = QLineEdit(internal_h5_path)
        wf_layout.addWidget(self.h5_intpth_linedt, 2, 1, 1, 1)
        wf_layout.addWidget(QLabel("Workspace Name:"), 3, 0)
        self.ws_name_linedt_1 = QLineEdit(self.workspace_config["workspace_name"])
        wf_layout.addWidget(self.ws_name_linedt_1, 3, 1)
        wf_layout.addWidget(QLabel("Downsample Factor:"), 4, 0)
        self.downsample_spinner = QSpinBox()
        self.downsample_spinner.setRange(1, 10)
        self.downsample_spinner.setSpecialValueText("None")
        self.downsample_spinner.setMaximumWidth(60)
        self.downsample_spinner.setValue(int(self.workspace_config["downsample_by"]))
        wf_layout.addWidget(self.downsample_spinner, 4, 1, 1, 1)
        # ROI
        self.setup_roi_fields()
        wf_layout.addWidget(self.roi_fields, 4, 2, 1, 2)
        self.roi_fields.hide()

        wf_layout.addWidget(self.create_workspace_button, 5, 0, 1, 3)
        workspace_fields.setLayout(wf_layout)
        select_data_button.clicked.connect(self.launch_data_loader)
        return workspace_fields

    def setup_roi_fields(self):
        """Sets up the QGroupBox that displays the ROI dimensions, if selected.

        The box reads ``z: a - b  y: a - b  x: a - b``. The six value labels
        (``zstart_roi_val`` ... ``xend_roi_val``) start at "0" and are filled in
        by ``update_roi_fields_from_dialog``.
        """
        self.roi_fields = QGroupBox("ROI:")
        roi_fields_layout = QHBoxLayout()

        def add_range(axis):
            # QHBoxLayout.addWidget's second positional argument is a stretch
            # factor, not a position; widgets are laid out in insertion order.
            start_label, end_label = QLabel("0"), QLabel("0")
            for widget in (QLabel(f"{axis}:"), start_label, QLabel("-"), end_label):
                roi_fields_layout.addWidget(widget)
            return start_label, end_label

        self.zstart_roi_val, self.zend_roi_val = add_range("z")
        self.ystart_roi_val, self.yend_roi_val = add_range("y")
        self.xstart_roi_val, self.xend_roi_val = add_range("x")

        self.roi_fields.setLayout(roi_fields_layout)

    def setup_adv_run_fields(self):
        """Sets up the QGroupBox that displays the advanced options for starting SuRVoS2.

        ``self.adv_run_fields`` holds the server IP/port fields and a "Use SSH"
        toggle. It also holds a nested ``self.adv_ssh_fields`` group (host,
        username, port), which is visible only while "Use SSH" is checked.
        """
        self.adv_run_fields = QGroupBox("Advanced Run Settings:")
        adv_run_layout = QGridLayout()
        adv_run_layout.addWidget(QLabel("Server IP Address:"), 0, 0)
        # QLineEdit needs a str; config values such as a port may be parsed as ints.
        self.server_ip_linedt = QLineEdit(str(self.run_config["server_ip"]))
        adv_run_layout.addWidget(self.server_ip_linedt, 0, 1)
        adv_run_layout.addWidget(QLabel("Server Port:"), 1, 0)
        self.server_port_linedt = QLineEdit(str(self.run_config["server_port"]))
        adv_run_layout.addWidget(self.server_port_linedt, 1, 1)
        # SSH Info
        self.ssh_button = QRadioButton("Use SSH")
        self.ssh_button.setAutoExclusive(False)
        adv_run_layout.addWidget(self.ssh_button, 0, 2)
        ssh_flag = self.run_config.get("use_ssh", False)
        if ssh_flag:
            self.ssh_button.setChecked(True)
        self.ssh_button.toggled.connect(self.toggle_ssh)

        self.adv_ssh_fields = QGroupBox("SSH Settings:")
        adv_ssh_layout = QGridLayout()
        adv_ssh_layout.setColumnStretch(2, 2)
        ssh_host_label = QLabel("Host")
        self.ssh_host_linedt = QLineEdit(str(self.run_config.get("ssh_host") or ""))
        adv_ssh_layout.addWidget(ssh_host_label, 0, 0)
        adv_ssh_layout.addWidget(self.ssh_host_linedt, 0, 1, 1, 2)
        ssh_user_label = QLabel("Username")
        self.ssh_username_linedt = QLineEdit(self.get_login_username())
        adv_ssh_layout.addWidget(ssh_user_label, 1, 0)
        adv_ssh_layout.addWidget(self.ssh_username_linedt, 1, 1, 1, 2)
        ssh_port_label = QLabel("Port")
        self.ssh_port_linedt = QLineEdit(str(self.run_config.get("ssh_port") or ""))
        adv_ssh_layout.addWidget(ssh_port_label, 2, 0)
        adv_ssh_layout.addWidget(self.ssh_port_linedt, 2, 1, 1, 2)
        self.adv_ssh_fields.setLayout(adv_ssh_layout)
        adv_run_layout.addWidget(self.adv_ssh_fields, 1, 2, 2, 5)
        self.adv_run_fields.setLayout(adv_run_layout)
        # toggle_ssh only reacts to later toggles, so sync the initial state here.
        # (Hide only: show() on a not-yet-displayed child is handled by its parent.)
        if not self.ssh_button.isChecked():
            self.adv_ssh_fields.hide()

    def get_run_fields(self):
        """Gets the QGroupBox that contains the fields for starting SuRVoS.

        Row 0 holds:
          - the "Workspace Name:" label;
          - a combo box listing the entries of CHROOT;
          - an editable workspace-name field.

        Below these are the "Advanced" toggle, the advanced settings group and
        the Run button.

        Returns:
            PyQt5.QWidgets.GroupBox: GroupBox with run fields.
        """
        self.run_button = QPushButton("Run SuRVoS")
        advanced_button = QRadioButton("Advanced")

        run_fields = QGroupBox("Run SuRVoS:")
        run_layout = QGridLayout()
        run_layout.addWidget(QLabel("Workspace Name:"), 0, 0)

        try:
            workspaces = os.listdir(CHROOT)
        except OSError as err:
            # A missing/unreadable CHROOT should not stop the window from opening.
            logger.error(f"Could not list workspaces in CHROOT {CHROOT}: {err}")
            workspaces = []
        self.workspaces_list = ComboBox()
        for s in workspaces:
            self.workspaces_list.addItem(key=s)
        # Column 1: cell (0, 0) is already occupied by the "Workspace Name:" label.
        run_layout.addWidget(self.workspaces_list, 0, 1)

        self.ws_name_linedt_2 = QLineEdit(self.workspace_config["workspace_name"])
        self.ws_name_linedt_2.setAlignment(Qt.AlignLeft)
        run_layout.addWidget(self.ws_name_linedt_2, 0, 2)
        run_layout.addWidget(advanced_button, 1, 0)
        run_layout.addWidget(self.adv_run_fields, 2, 1)
        run_layout.addWidget(self.run_button, 3, 0, 1, 3)
        run_fields.setLayout(run_layout)

        advanced_button.toggled.connect(self.toggle_advanced)
        self.run_button.clicked.connect(self.run_clicked)

        return run_fields

    def get_login_username(self):
        """Return the OS login name, used to pre-fill the SSH username field.

        Any exception raised by ``getpass.getuser()`` results in an empty string.
        """
        try:
            return getpass.getuser()
        except Exception:
            return ""

    @pyqtSlot()
    def launch_data_loader(self):
        """Open the data-loading dialog and copy the user's choices into the form.

        ``self.roi_limits`` is reset to None whenever the dialog closes. If the
        dialog is accepted with both a file path and an HDF5 internal path, those
        values and the downsample factor are copied into the workspace fields.
        The ROI group box is then shown with the chosen limits, or hidden if the
        ROI was not changed.
        """
        dialog = LoadDataDialog(self)
        accepted = dialog.exec_() == QDialog.Accepted
        self.roi_limits = None
        if not accepted:
            return

        path = dialog.winput.path.text()
        int_h5_pth = dialog.int_h5_pth.text()
        down_factor = dialog.downsample_spinner.value()
        if not (path and int_h5_pth):
            return

        self.data_filepth_linedt.setText(path)
        self.h5_intpth_linedt.setText(int_h5_pth)
        self.downsample_spinner.setValue(down_factor)
        if not dialog.roi_changed:
            self.roi_fields.hide()
            return
        # Labels display text, so the limits are stored as strings.
        self.roi_limits = tuple(map(str, dialog.get_roi_limits()))
        self.roi_fields.show()
        self.update_roi_fields_from_dialog()

    def update_roi_fields_from_dialog(self):
        """Copy ``self.roi_limits`` into the six ROI labels of the main window.

        ``roi_limits`` is unpacked as (x_start, x_end, y_start, y_end, z_start,
        z_end).
        """
        x_lo, x_hi, y_lo, y_hi, z_lo, z_hi = self.roi_limits
        for label, text in (
            (self.xstart_roi_val, x_lo),
            (self.xend_roi_val, x_hi),
            (self.ystart_roi_val, y_lo),
            (self.yend_roi_val, y_hi),
            (self.zstart_roi_val, z_lo),
            (self.zend_roi_val, z_hi),
        ):
            label.setText(text)

    @pyqtSlot()
    def toggle_advanced(self):
        """Show the advanced run fields while the sending radio button is checked, hide them otherwise."""
        # setVisible(True/False) is what show()/hide() call internally.
        self.adv_run_fields.setVisible(self.sender().isChecked())

    @pyqtSlot()
    def toggle_ssh(self):
        """Show the SSH settings while the sending radio button is checked, hide them otherwise."""
        # setVisible(True/False) is what show()/hide() call internally.
        self.adv_ssh_fields.setVisible(self.sender().isChecked())

    def setup_ptree_params(self, p, config_dict):
        """Connect a parameter tree so that edited values are written into ``config_dict``.

        Args:
            p: Root pyqtgraph ``Parameter`` built from ``config_dict``.
            config_dict (dict): Configuration updated in place when values change.

        Returns:
            The same parameter ``p``.
        """

        def parameter_tree_change(_root, changes):
            for changed_param, change, data in changes:
                # sigTreeStateChanged also reports non-value changes (name,
                # options, parent, ...); only value edits belong in the config.
                if change != "value":
                    continue
                path = p.childPath(changed_param)
                # childPath returns None for a parameter outside p; fall back to its name.
                key = path[-1] if path else changed_param.name()
                # NOTE(review): only the leaf name is used, so values inside nested
                # groups are written at the top level of config_dict - confirm intended.
                config_dict[key] = data

        p.sigTreeStateChanged.connect(parameter_tree_change)
        return p

    def dict_to_params(self, param_dict, name="Group"):
        """Convert a (possibly nested) config dict into pyqtgraph Parameter specs.

        Values are mapped as follows:
          - str/int/float/bool values become typed parameters;
          - lists become "list" parameters whose ``values`` are the items;
          - nested dicts become sub-groups named ``<key>_<n>``, where ``n``
            counts the nested dicts seen so far at this level.

        Values of any other type (e.g. ``None``) are skipped with a warning.

        Args:
            param_dict (dict): Configuration to convert.
            name (str): Name of the enclosing group.

        Returns:
            list: One-element list holding
            ``{"name": name, "type": "group", "children": [...]}``.
        """
        # Exact-type lookup on purpose: bool subclasses int but must map to "bool".
        scalar_types = {str: "str", int: "int", float: "float", bool: "bool"}
        ptree_param_dicts = []
        ctr = 0
        for key, entry in param_dict.items():
            entry_type = type(entry)
            if entry_type in scalar_types:
                d = {"name": key, "type": scalar_types[entry_type], "value": entry}
            elif isinstance(entry, list):
                d = {"name": key, "type": "list", "values": entry}
            elif isinstance(entry, dict):
                d = self.dict_to_params(entry, name="Subgroup")[0]
                d["name"] = f"{key}_{ctr}"
                ctr += 1
            else:
                # Skip it; appending here would reuse the previous entry's spec.
                logger.warning(f"Can't handle type {entry_type} for key {key!r}; skipping")
                continue

            ptree_param_dicts.append(d)

        return [{"name": name, "type": "group", "children": ptree_param_dicts}]

    def init_ptree(self, config_dict, name="Group"):
        """Build a ParameterTree widget that edits ``config_dict`` in place.

        Args:
            config_dict (dict): Configuration to display and edit.
            name (str): Label of the top-level group in the tree.

        Returns:
            ParameterTree: The widget (its hidden root is named "ptree_init").
        """
        root = Parameter.create(
            name="ptree_init",
            type="group",
            children=self.dict_to_params(config_dict, name),
        )
        connected_root = self.setup_ptree_params(root, config_dict)
        tree_widget = ParameterTree()
        tree_widget.setParameters(connected_root, showTop=False)
        return tree_widget

    @pyqtSlot()
    def create_workspace_clicked(self):
        """Performs checks and coordinates workspace creation on button press.

        Copies the form values into ``self.workspace_config`` and calls
        ``init_ws`` with it. The values are the data file location, HDF5 dataset
        name, workspace name, downsample factor and, if one was selected, the
        ROI. The outcome is shown on ``self.create_workspace_button``. On success
        the new name is also copied into the 'Run' section.
        """
        logger.debug("Creating workspace: ")
        # Set the path to the data file
        vol_path = Path(self.data_filepth_linedt.text())
        if not vol_path.is_file():
            err_str = f"No data file exists at {vol_path}!"
            logger.error(err_str)
            self.button_feedback_response(
                err_str, self.create_workspace_button, "maroon"
            )
            return

        self.workspace_config["datasets_dir"] = str(vol_path.parent)
        self.workspace_config["vol_fname"] = str(vol_path.name)
        # Stored without the leading/trailing "/" shown in the HDF5 path field.
        dataset_name = self.h5_intpth_linedt.text()
        self.workspace_config["dataset_name"] = str(dataset_name).strip("/")
        self.workspace_config["workspace_name"] = self.ws_name_linedt_1.text()
        self.workspace_config["downsample_by"] = self.downsample_spinner.value()
        # roi_limits is only assigned by launch_data_loader, so it may not exist
        # yet if the data path was typed in by hand.
        roi_limits = getattr(self, "roi_limits", None)
        if roi_limits:
            self.workspace_config["roi_limits"] = roi_limits
        try:
            # init_ws is unpacked as a (result, error) pair; error is falsy on success.
            _, error = init_ws(self.workspace_config)
        except WorkspaceException as e:
            logger.exception(e)
            self.button_feedback_response(
                str(e), self.create_workspace_button, "maroon"
            )
            return
        if error:
            # Previously ignored: report a failed creation instead of staying silent.
            err_str = f"Workspace creation failed: {error}"
            logger.error(err_str)
            self.button_feedback_response(
                err_str, self.create_workspace_button, "maroon"
            )
            return
        self.button_feedback_response(
            "Workspace created successfully",
            self.create_workspace_button,
            "green",
        )
        # Update the workspace name in the 'Run' section
        self.ws_name_linedt_2.setText(self.ws_name_linedt_1.text())

    def button_feedback_response(self, message, button, colour_str, timeout=2):
        """Changes button colour and displays feedback message for a limited time period.

        Args:
            message (str): Message to display in button.
            button (PyQt5.QtWidgets.QPushButton): The button to manipulate.
            colour_str (str): The standard CSS colour string or hex code describing the colour to change the button to.
            timeout (int | float): Seconds before ``reset_button`` restores the
                original text and colours. Defaults to 2.
        """
        # QTimer.singleShot needs an integer number of milliseconds; callers may
        # pass a float number of seconds.
        timeout_ms = int(timeout * 1000)
        msg_old = button.text()
        # Bound ``color`` methods, not QColor values: reset_button calls them.
        col_old = button.palette().button().color
        txt_col_old = button.palette().buttonText().color
        button.setText(message)
        button.setStyleSheet(f"background-color: {colour_str}; color: white")
        # singleShot is a static method, so no QTimer instance is needed.
        QTimer.singleShot(
            timeout_ms,
            lambda: self.reset_button(button, msg_old, col_old, txt_col_old),
        )

    @pyqtSlot()
    def reset_button(self, button, msg_old, col_old, txt_col_old):
        """Sets a button back to its original display settings.

        Args:
            button (PyQt5.QtWidgets.QPushButton): The button to manipulate.
            msg_old (str): Text to put back on the button.
            col_old (callable): Returns the original background QColor
                (``button_feedback_response`` passes the palette brush's bound
                ``color`` method).
            txt_col_old (callable): Returns the original text QColor.
        """
        # Each setStyleSheet call replaces the whole sheet, so both properties
        # must be set together or the background restore is lost.
        button.setStyleSheet(
            f"background-color: {col_old().name()}; color: {txt_col_old().name()}"
        )
        button.setText(msg_old)
        button.update()

    @pyqtSlot()
    def output_config_clicked(self):
        """Write ``self.pipeline_config`` as YAML to ``pipeline_cfg.yml``.

        The path is relative, so the file lands in the current working directory.
        Keys keep their insertion order and block style is used.
        """
        out_fname = "pipeline_cfg.yml"
        logger.debug(f"Outputting pipeline config: {out_fname}")
        dump_options = dict(default_flow_style=False, sort_keys=False)
        with open(out_fname, "w") as outfile:
            yaml.dump(self.pipeline_config, outfile, **dump_options)

    def get_ssh_params(self):
        """Read the SSH connection settings from the form.

        Returns:
            tuple: ``(host, username, port)``. ``port`` is an int. It is 0 when
            the port field is empty (the default when ``ssh_port`` is not
            configured) or not an integer, so callers can detect missing settings
            with ``all(params)`` instead of hitting a ValueError.
        """
        ssh_host = self.ssh_host_linedt.text()
        ssh_user = self.ssh_username_linedt.text()
        try:
            ssh_port = int(self.ssh_port_linedt.text())
        except ValueError:
            ssh_port = 0
        return ssh_host, ssh_user, ssh_port

    def start_server_over_ssh(self):
        """Start the SuRVoS server on a remote host via SSH in a worker thread.

        Reads host/user/port from the form and asks for the user's password. It
        then runs ``SSHWorker.start_server_over_ssh`` in a new QThread. When the
        worker emits ``finished``, ``start_client`` is called. Nothing is started
        if any SSH setting is missing or the password dialog is cancelled or left
        empty.
        """
        params = self.get_ssh_params()
        if not all(params):
            logger.error("Not all SSH parameters given! Not connecting to SSH.")
            # Previously execution fell through and connected anyway.
            return
        ssh_host, ssh_user, _ = params
        # Pop up dialog to ask for password
        password, ok = QInputDialog.getText(
            None, "Login", f"Password for {ssh_user}@{ssh_host}", QLineEdit.Password
        )
        if not (ok and password):
            return
        self.ssh_worker = SSHWorker(params, password, self.run_config)
        self.ssh_thread = QThread(self)
        self.ssh_worker.moveToThread(self.ssh_thread)
        self.ssh_worker.button_message_signal.connect(self.send_msg_to_run_button)
        self.ssh_worker.error_signal.connect(self.on_ssh_error)
        self.ssh_worker.finished.connect(self.start_client)
        self.ssh_worker.update_ip_linedt_signal.connect(self.update_ip_linedt)
        self.ssh_thread.started.connect(self.ssh_worker.start_server_over_ssh)
        self.ssh_thread.start()

    def closeEvent(self, event):
        """Ask for confirmation before closing, and stop the local server if confirmed.

        The dialog tells the user the server will be stopped. So on "Yes", a
        still-running server subprocess started by ``run_clicked``
        (``self.server_process``) is killed before the event is accepted. A
        server started over SSH has no local process handle and is not affected.
        """
        reply = QMessageBox.question(
            self,
            "Quit",
            "Are you sure you want to quit? " "The server will be stopped.",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if reply != QMessageBox.Yes:
            event.ignore()
            return
        server_process = getattr(self, "server_process", None)
        if server_process is not None and server_process.poll() is None:
            logger.debug("Stopping server")
            server_process.kill()
        event.accept()

    @pyqtSlot()
    def on_ssh_error(self):
        """Record that the SSH worker reported an error.

        ``start_client`` checks this flag and skips launching the client when set.
        """
        self.ssh_error = True

    @pyqtSlot(str)
    def update_ip_linedt(self, ip):
        """Show the server IP sent by the SSH worker in the server IP field."""
        ip_field = self.server_ip_linedt
        ip_field.setText(ip)

    @pyqtSlot(list)
    def send_msg_to_run_button(self, param_list):
        """Display feedback from the SSH worker on the Run button.

        Args:
            param_list (list): ``[message, colour, timeout]``, forwarded to
                ``button_feedback_response``.
        """
        message, colour, timeout = param_list[0], param_list[1], param_list[2]
        self.button_feedback_response(message, self.run_button, colour, timeout)

    @pyqtSlot()
    def run_clicked(self):
        """Starts SuRVoS2 server and client as subprocesses when 'Run' button pressed.

        Looks for ``survos.py`` in the current working directory. It then checks
        that the workspace named in the run section exists under
        ``Config["model.chroot"]``, reporting on the Run button and stopping if
        not. Finally it either starts the server over SSH, or launches
        ``survos.py start_server`` locally with the running Python interpreter
        and then starts the client.

        Raises:
            Exception: If survos.py not found.
        """
        import sys

        # Flag which will be set to True if there is an SSH error
        self.ssh_error = False
        command_dir = os.getcwd()
        self.script_fullname = os.path.join(command_dir, "survos.py")
        if not os.path.isfile(self.script_fullname):
            raise Exception("{}: Script not found".format(self.script_fullname))
        # Retrieve the parameters from the fields TODO: Put some error checking in
        self.run_config["workspace_name"] = self.ws_name_linedt_2.text()
        self.run_config["server_port"] = self.server_port_linedt.text()
        # Temporary measure to check whether the workspace exists or not
        full_ws_path = os.path.join(
            Config["model.chroot"], self.run_config["workspace_name"]
        )
        if not os.path.isdir(full_ws_path):
            logger.error(
                f"No workspace can be found at {full_ws_path}, Not starting SuRVoS."
            )
            self.button_feedback_response(
                f"Workspace {self.run_config['workspace_name']} does not appear to exist!",
                self.run_button,
                "maroon",
            )
            return
        if self.ssh_button.isChecked():
            self.start_server_over_ssh()
            return

        # Use the interpreter running this GUI rather than whatever "python"
        # resolves to on PATH, so the server sees the same environment.
        self.server_process = subprocess.Popen(
            [
                sys.executable,
                self.script_fullname,
                "start_server",
                self.run_config["workspace_name"],
                self.run_config["server_port"],
            ]
        )
        try:
            # Blocks for up to 10 s (presumably to let the server come up);
            # returns earlier only if the server process exits.
            outs, errs = self.server_process.communicate(timeout=10)
            print(f"OUTS: {outs, errs}")
        except subprocess.TimeoutExpired:
            pass

        self.start_client()

    def start_client(self):
        """Launch the SuRVoS2 client GUI against the configured server.

        Does nothing if an SSH error was reported (``self.ssh_error``).
        Otherwise it copies the server IP from the form into ``run_config`` and
        starts ``survos.py nu_gui <workspace> <ip>:<port>`` with the Python
        interpreter running this GUI.
        """
        import sys

        if self.ssh_error:
            return
        self.button_feedback_response(
            "Starting Client.", self.run_button, "green", 7
        )
        self.run_config["server_ip"] = self.server_ip_linedt.text()
        server_address = (
            f"{self.run_config['server_ip']}:{self.run_config['server_port']}"
        )
        # sys.executable instead of "python": same interpreter/environment as the GUI.
        self.client_process = subprocess.Popen(
            [
                sys.executable,
                self.script_fullname,
                "nu_gui",
                self.run_config["workspace_name"],
                server_address,
            ]
        )