class SVMWidget(QtWidgets.QWidget):
    """Options panel for an SVM classifier: kernel type, penalty C and gamma.

    ``predict`` is emitted with the current SVM settings when
    ``on_predict_clicked`` is invoked.
    """

    predict = Signal(dict)

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        layout = QtWidgets.QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        self.setLayout(layout)

        # Kernel selector: a category header followed by one entry per kernel.
        self.type_combo = ComboBox()
        self.type_combo.addCategory("Kernel Type:")
        for kernel_name in ("linear", "poly", "rbf", "sigmoid"):
            self.type_combo.addItem(kernel_name)
        layout.addWidget(self.type_combo)

        # Numeric hyper-parameters, both parsed as floats and defaulting to 1.0.
        self.penaltyc = LineEdit(default=1.0, parse=float)
        self.gamma = LineEdit(default=1.0, parse=float)
        hyperparam_row = HWidgets(
            QtWidgets.QLabel("Penalty C:"),
            self.penaltyc,
            QtWidgets.QLabel("Gamma:"),
            self.gamma,
            stretch=[0, 1, 0, 1],
        )
        layout.addWidget(hyperparam_row)

    def on_predict_clicked(self):
        """Emit ``predict`` with the classifier name, kernel, C and gamma."""
        self.predict.emit({
            "clf": "svm",
            "kernel": self.type_combo.currentText(),
            "C": self.penaltyc.value(),
            "gamma": self.gamma.value(),
        })

    def get_params(self):
        """Return the current SVM settings, including the combo's ``value()`` as ``type``."""
        svm_params = dict(
            clf="svm",
            kernel=self.type_combo.currentText(),
            C=self.penaltyc.value(),
            gamma=self.gamma.value(),
        )
        svm_params["type"] = self.type_combo.value()
        return svm_params
class ServerPlugin(Plugin):
    """Plugin tab for creating workspaces and starting or connecting to a SuRVoS server.

    Holds three config dicts (``run_config``, ``workspace_config`` and
    ``pipeline_config``) that are edited through the widgets built here.
    It also keeps handles to the server/client subprocesses it launches.
    """

    __icon__ = "fa.qrcode"
    __pname__ = "server"
    __views__ = ["slice_viewer"]
    __tab__ = "server"

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        run_config = {
            "server_ip": "127.0.0.1",
            "server_port": "8134",
            "workspace_name": "test_hunt_d4b",
            "use_ssh": False,
            "ssh_host": "ws168.diamond.ac.uk",
            "ssh_port": "22",
        }
        workspace_config = {
            "dataset_name": "data",
            "datasets_dir": "/path/to/my/data/dir",
            "vol_fname": "myfile.h5",
            "workspace_name": "my_survos_workspace",
            "downsample_by": "1",
        }
        from survos2.server.config import cfg

        pipeline_config = dict(cfg)
        self.run_config = run_config
        self.workspace_config = workspace_config
        self.pipeline_config = pipeline_config
        self.server_process = None
        self.client_process = None
        # Only launch_data_loader assigns ROI limits. Initialise them here so
        # create_workspace_clicked works even if the loader was never opened.
        self.roi_limits = None
        # Set to True by on_ssh_error; read by start_client.
        self.ssh_error = False

        # NOTE(review): this attribute shadows QWidget.layout(); kept for
        # backward compatibility with code that may read ``self.layout``.
        self.layout = QVBoxLayout()
        tabwidget = QTabWidget()
        tab1 = QWidget()
        tabwidget.addTab(tab1, "Setup and Start Survos")
        # Created before get_workspace_fields, which places it in its grid.
        self.create_workspace_button = QPushButton("Create workspace")
        tab1.layout = QVBoxLayout()
        tab1.setLayout(tab1.layout)

        chroot_fields = self.get_chroot_fields()
        tab1.layout.addWidget(chroot_fields)
        workspace_fields = self.get_workspace_fields()
        tab1.layout.addWidget(workspace_fields)
        self.setup_adv_run_fields()
        self.adv_run_fields.hide()
        run_fields = self.get_run_fields()
        tab1.layout.addWidget(run_fields)

        # NOTE(review): this button is connected but never added to a layout.
        output_config_button = QPushButton("Save config")
        self.create_workspace_button.clicked.connect(self.create_workspace_clicked)
        output_config_button.clicked.connect(self.output_config_clicked)

        self.layout.addWidget(tabwidget)
        self.setGeometry(300, 300, 600, 400)
        self.setWindowTitle("SuRVoS Settings Editor")
        current_fpth = os.path.dirname(os.path.abspath(__file__))
        self.setWindowIcon(QIcon(os.path.join(current_fpth, "resources", "logo.png")))
        self.setLayout(self.layout)
        self.show()

    def get_chroot_fields(self):
        """Builds the QGroupBox for choosing the directory that stores workspaces.

        Returns:
            PyQt5.QtWidgets.QGroupBox: GroupBox with the workspace-root fields.
        """
        chroot_fields = QGroupBox("Set Main Directory for Storing Workspaces:")
        chroot_fields.setMaximumHeight(130)
        chroot_layout = QGridLayout()
        self.given_chroot_linedt = QLineEdit(CHROOT)
        chroot_layout.addWidget(self.given_chroot_linedt, 1, 0, 1, 2)
        set_chroot_button = QPushButton("Set Workspaces Root")
        chroot_layout.addWidget(set_chroot_button, 1, 2)
        chroot_fields.setLayout(chroot_layout)
        set_chroot_button.clicked.connect(self.set_chroot)
        return chroot_fields

    def get_workspace_fields(self):
        """Gets the QGroupBox that contains all the fields for setting up the workspace.

        Returns:
            PyQt5.QtWidgets.QGroupBox: GroupBox with workspace fields.
        """
        select_data_button = QPushButton("Select")
        workspace_fields = QGroupBox("Create New Workspace:")
        wf_layout = QGridLayout()

        wf_layout.addWidget(QLabel("Data File Path:"), 0, 0)
        current_data_path = Path(self.workspace_config["datasets_dir"],
                                 self.workspace_config["vol_fname"])
        self.data_filepth_linedt = QLineEdit(str(current_data_path))
        wf_layout.addWidget(self.data_filepth_linedt, 1, 0, 1, 2)
        wf_layout.addWidget(select_data_button, 1, 2)

        wf_layout.addWidget(QLabel("HDF5 Internal Data Path:"), 2, 0, 1, 1)
        ws_dataset_name = self.workspace_config["dataset_name"]
        # Display the internal HDF5 path with a leading slash.
        internal_h5_path = (ws_dataset_name if str(ws_dataset_name).startswith("/")
                            else "/" + ws_dataset_name)
        self.h5_intpth_linedt = QLineEdit(internal_h5_path)
        wf_layout.addWidget(self.h5_intpth_linedt, 2, 1, 1, 1)

        wf_layout.addWidget(QLabel("Workspace Name:"), 3, 0)
        self.ws_name_linedt_1 = QLineEdit(self.workspace_config["workspace_name"])
        wf_layout.addWidget(self.ws_name_linedt_1, 3, 1)

        wf_layout.addWidget(QLabel("Downsample Factor:"), 4, 0)
        self.downsample_spinner = QSpinBox()
        self.downsample_spinner.setRange(1, 10)
        self.downsample_spinner.setSpecialValueText("None")
        self.downsample_spinner.setMaximumWidth(60)
        self.downsample_spinner.setValue(int(self.workspace_config["downsample_by"]))
        wf_layout.addWidget(self.downsample_spinner, 4, 1, 1, 1)

        # ROI display; hidden until the data loader reports a changed ROI.
        self.setup_roi_fields()
        wf_layout.addWidget(self.roi_fields, 4, 2, 1, 2)
        self.roi_fields.hide()

        wf_layout.addWidget(self.create_workspace_button, 5, 0, 1, 3)
        workspace_fields.setLayout(wf_layout)
        select_data_button.clicked.connect(self.launch_data_loader)
        return workspace_fields

    def setup_roi_fields(self):
        """Sets up the QGroupBox that displays the ROI dimensions, if selected."""
        self.roi_fields = QGroupBox("ROI:")
        roi_fields_layout = QHBoxLayout()
        self.zstart_roi_val = QLabel("0")
        self.zend_roi_val = QLabel("0")
        self.ystart_roi_val = QLabel("0")
        self.yend_roi_val = QLabel("0")
        self.xstart_roi_val = QLabel("0")
        self.xend_roi_val = QLabel("0")
        # The second positional argument of QHBoxLayout.addWidget is a stretch
        # factor, not a position. Widgets are appended in insertion order, so
        # no index is passed here.
        for axis_label, start_val, end_val in (
            ("z:", self.zstart_roi_val, self.zend_roi_val),
            ("y:", self.ystart_roi_val, self.yend_roi_val),
            ("x:", self.xstart_roi_val, self.xend_roi_val),
        ):
            roi_fields_layout.addWidget(QLabel(axis_label))
            roi_fields_layout.addWidget(start_val)
            roi_fields_layout.addWidget(QLabel("-"))
            roi_fields_layout.addWidget(end_val)
        self.roi_fields.setLayout(roi_fields_layout)

    def setup_adv_run_fields(self):
        """Sets up the QGroupBox that displays the advanced options for starting SuRVoS2."""
        self.adv_run_fields = QGroupBox("Advanced Run Settings:")
        adv_run_layout = QGridLayout()
        adv_run_layout.addWidget(QLabel("Server IP Address:"), 0, 0)
        self.server_ip_linedt = QLineEdit(self.run_config["server_ip"])
        adv_run_layout.addWidget(self.server_ip_linedt, 0, 1)
        adv_run_layout.addWidget(QLabel("Server Port:"), 1, 0)
        self.server_port_linedt = QLineEdit(self.run_config["server_port"])
        adv_run_layout.addWidget(self.server_port_linedt, 1, 1)

        # SSH Info
        self.ssh_button = QRadioButton("Use SSH")
        self.ssh_button.setAutoExclusive(False)
        adv_run_layout.addWidget(self.ssh_button, 0, 2)
        ssh_flag = self.run_config.get("use_ssh", False)
        if ssh_flag:
            self.ssh_button.setChecked(True)
        self.ssh_button.toggled.connect(self.toggle_ssh)

        self.adv_ssh_fields = QGroupBox("SSH Settings:")
        adv_ssh_layout = QGridLayout()
        adv_ssh_layout.setColumnStretch(2, 2)
        ssh_host_label = QLabel("Host")
        self.ssh_host_linedt = QLineEdit(self.run_config.get("ssh_host", ""))
        adv_ssh_layout.addWidget(ssh_host_label, 0, 0)
        adv_ssh_layout.addWidget(self.ssh_host_linedt, 0, 1, 1, 2)
        ssh_user_label = QLabel("Username")
        self.ssh_username_linedt = QLineEdit(self.get_login_username())
        adv_ssh_layout.addWidget(ssh_user_label, 1, 0)
        adv_ssh_layout.addWidget(self.ssh_username_linedt, 1, 1, 1, 2)
        ssh_port_label = QLabel("Port")
        self.ssh_port_linedt = QLineEdit(self.run_config.get("ssh_port", ""))
        adv_ssh_layout.addWidget(ssh_port_label, 2, 0)
        adv_ssh_layout.addWidget(self.ssh_port_linedt, 2, 1, 1, 2)
        self.adv_ssh_fields.setLayout(adv_ssh_layout)
        # NOTE(review): adv_ssh_fields is not added to adv_run_layout (that
        # call was commented out upstream), so toggle_ssh shows it as a
        # separate top-level widget. Confirm this is intended.
        self.adv_run_fields.setLayout(adv_run_layout)

    def get_run_fields(self):
        """Gets the QGroupBox that contains the fields for starting SuRVoS.

        Returns:
            PyQt5.QtWidgets.QGroupBox: GroupBox with run fields.
        """
        self.run_button = QPushButton("Start Server")
        self.stop_button = QPushButton("Stop Server")
        self.existing_button = QPushButton("Use Existing Server")
        advanced_button = QRadioButton("Advanced")

        run_fields = QGroupBox("Run SuRVoS:")
        run_layout = QGridLayout()

        workspaces = os.listdir(CHROOT)
        self.workspaces_list = ComboBox()
        for s in workspaces:
            self.workspaces_list.addItem(key=s)

        run_layout.addWidget(QLabel("Workspace Name:"), 0, 0)
        self.ws_name_linedt_2 = QLineEdit(self.workspace_config["workspace_name"])
        self.ws_name_linedt_2.setAlignment(Qt.AlignLeft)
        # The workspace combo is editable through this line edit.
        self.workspaces_list.setLineEdit(self.ws_name_linedt_2)
        run_layout.addWidget(self.workspaces_list, 0, 1)
        run_layout.addWidget(advanced_button, 1, 0)
        run_layout.addWidget(self.adv_run_fields, 2, 1)
        run_layout.addWidget(self.run_button, 3, 0, 1, 3)
        run_layout.addWidget(self.stop_button, 4, 0, 1, 3)
        run_layout.addWidget(self.existing_button, 5, 0, 1, 3)
        run_fields.setLayout(run_layout)

        advanced_button.toggled.connect(self.toggle_advanced)
        self.run_button.clicked.connect(self.run_clicked)
        self.stop_button.clicked.connect(self.stop_clicked)
        self.existing_button.clicked.connect(self.existing_clicked)
        return run_fields

    def get_login_username(self):
        """Return the current login name, or "" if it cannot be determined."""
        try:
            user = getpass.getuser()
        except Exception:
            user = ""
        return user

    def refresh_chroot(self):
        """Repopulate the workspace combo from the current ``DataModel.g.CHROOT``."""
        workspaces = os.listdir(DataModel.g.CHROOT)
        self.workspaces_list.clear()
        for s in workspaces:
            self.workspaces_list.addItem(key=s)

    @pyqtSlot()
    def set_chroot(self):
        """Apply the workspace root typed by the user and refresh the workspace list."""
        chroot = self.given_chroot_linedt.text()
        Config.update({"model": {"chroot": chroot}})
        logger.debug(f"Setting CHROOT to {chroot}")
        DataModel.g.CHROOT = chroot
        self.refresh_chroot()

    @pyqtSlot()
    def launch_data_loader(self):
        """Load the dialog box widget to select data with data preview window and ROI selection."""
        dialog = LoadDataDialog(self)
        result = dialog.exec_()
        self.roi_limits = None
        if result != QDialog.Accepted:
            return
        path = dialog.winput.path.text()
        int_h5_pth = dialog.int_h5_pth.text()
        down_factor = dialog.downsample_spinner.value()
        if path and int_h5_pth:
            self.data_filepth_linedt.setText(path)
            self.h5_intpth_linedt.setText(int_h5_pth)
            self.downsample_spinner.setValue(down_factor)
            if dialog.roi_changed:
                self.roi_limits = tuple(map(str, dialog.get_roi_limits()))
                self.roi_fields.show()
                self.update_roi_fields_from_dialog()
            else:
                self.roi_fields.hide()

    def update_roi_fields_from_dialog(self):
        """Updates the ROI fields in the main window from ``self.roi_limits``.

        Assumes ``roi_limits`` is ordered (x_start, x_end, y_start, y_end,
        z_start, z_end), as unpacked below.
        """
        x_start, x_end, y_start, y_end, z_start, z_end = self.roi_limits
        self.xstart_roi_val.setText(x_start)
        self.xend_roi_val.setText(x_end)
        self.ystart_roi_val.setText(y_start)
        self.yend_roi_val.setText(y_end)
        self.zstart_roi_val.setText(z_start)
        self.zend_roi_val.setText(z_end)

    @pyqtSlot()
    def toggle_advanced(self):
        """Controls displaying/hiding the advanced run fields on radio button toggle."""
        rbutton = self.sender()
        self.adv_run_fields.setVisible(rbutton.isChecked())

    @pyqtSlot()
    def toggle_ssh(self):
        """Controls displaying/hiding the SSH fields on radio button toggle."""
        rbutton = self.sender()
        self.adv_ssh_fields.setVisible(rbutton.isChecked())

    @pyqtSlot()
    def create_workspace_clicked(self):
        """Performs checks and coordinates workspace creation on button press."""
        logger.debug("Creating workspace: ")
        # Set the path to the data file
        vol_path = Path(self.data_filepth_linedt.text())
        if not vol_path.is_file():
            err_str = f"No data file exists at {vol_path}!"
            logger.error(err_str)
            self.button_feedback_response(err_str, self.create_workspace_button, "maroon")
        else:
            self.workspace_config["datasets_dir"] = str(vol_path.parent)
            self.workspace_config["vol_fname"] = str(vol_path.name)
            dataset_name = self.h5_intpth_linedt.text()
            self.workspace_config["dataset_name"] = str(dataset_name).strip("/")
            # Set the workspace name
            ws_name = self.ws_name_linedt_1.text()
            self.workspace_config["workspace_name"] = ws_name
            # Set the downsample factor
            ds_factor = self.downsample_spinner.value()
            self.workspace_config["downsample_by"] = ds_factor
            # Set the ROI limits if they exist
            if self.roi_limits:
                self.workspace_config["roi_limits"] = self.roi_limits
            try:
                _, error = init_ws(self.workspace_config)
                if not error:
                    self.button_feedback_response(
                        "Workspace created successfully",
                        self.create_workspace_button,
                        "green",
                    )
                    # Update the workspace name in the 'Run' section
                    self.ws_name_linedt_2.setText(self.ws_name_linedt_1.text())
            except WorkspaceException as e:
                logger.exception(e)
                self.button_feedback_response(str(e), self.create_workspace_button, "maroon")
            self.refresh_chroot()

    def button_feedback_response(self, message, button, colour_str, timeout=2):
        """Changes button colour and displays feedback message for a limited time period.

        Args:
            message (str): Message to display in button.
            button (PyQt5.QtWidgets.QPushButton): The button to manipulate.
            colour_str (str): The standard CSS colour string or hex code describing
                the colour to change the button to.
            timeout (int or float): Seconds before the button is restored. Defaults to 2.
        """
        msg_old = button.text()
        # Bound ``color`` methods; reset_button calls them to get the QColor.
        col_old = button.palette().button().color
        txt_col_old = button.palette().buttonText().color
        button.setText(message)
        button.setStyleSheet(f"background-color: {colour_str}; color: white")
        # singleShot is static and requires an integer delay in milliseconds.
        QTimer.singleShot(
            int(timeout * 1000),
            lambda: self.reset_button(button, msg_old, col_old, txt_col_old),
        )

    @pyqtSlot()
    def reset_button(self, button, msg_old, col_old, txt_col_old):
        """Sets a button back to its original display settings.

        Args:
            button (PyQt5.QtWidgets.QPushButton): The button to manipulate.
            msg_old (str): Message to display in button.
            col_old (callable): Returns the QColor to restore as the background.
            txt_col_old (callable): Returns the QColor to restore as the text colour.
        """
        # A single stylesheet call: a second setStyleSheet would replace the
        # first and drop the background colour.
        button.setStyleSheet(
            f"background-color: {col_old().name()}; color: {txt_col_old().name()}")
        button.setText(msg_old)
        button.update()

    @pyqtSlot()
    def output_config_clicked(self):
        """Outputs pipeline config YAML file on button click."""
        out_fname = "pipeline_cfg.yml"
        logger.debug(f"Outputting pipeline config: {out_fname}")
        with open(out_fname, "w") as outfile:
            yaml.dump(self.pipeline_config, outfile, default_flow_style=False, sort_keys=False)

    def get_ssh_params(self):
        """Return ``(host, user, port)`` from the SSH fields.

        Raises:
            ValueError: If the port field is not an integer.
        """
        ssh_host = self.ssh_host_linedt.text()
        ssh_user = self.ssh_username_linedt.text()
        ssh_port = int(self.ssh_port_linedt.text())
        return ssh_host, ssh_user, ssh_port

    def start_server_over_ssh(self):
        """Prompt for a password and start the server on a background SSH worker thread.

        Returns early, logging an error and setting ``ssh_error``, if the SSH
        parameters are missing or the port is not a number.
        """
        try:
            params = self.get_ssh_params()
        except ValueError:
            logger.error("Invalid SSH port given! Not connecting to SSH.")
            self.ssh_error = True
            return
        if not all(params):
            logger.error("Not all SSH parameters given! Not connecting to SSH.")
            self.ssh_error = True
            return
        ssh_host, ssh_user, ssh_port = params
        # Pop up dialog to ask for password
        text, ok = QInputDialog.getText(None, "Login", f"Password for {ssh_user}@{ssh_host}",
                                        QLineEdit.Password)
        if ok and text:
            self.ssh_worker = SSHWorker(params, text, self.run_config)
            self.ssh_thread = QThread(self)
            self.ssh_worker.moveToThread(self.ssh_thread)
            self.ssh_worker.button_message_signal.connect(self.send_msg_to_run_button)
            self.ssh_worker.error_signal.connect(self.on_ssh_error)
            self.ssh_worker.finished.connect(self.start_client)
            self.ssh_worker.update_ip_linedt_signal.connect(self.update_ip_linedt)
            self.ssh_thread.started.connect(self.ssh_worker.start_server_over_ssh)
            self.ssh_thread.start()

    def closeEvent(self, event):
        """Ask the user to confirm before closing."""
        # NOTE(review): the prompt says the server will be stopped, but this
        # method does not stop self.server_process. Confirm where that happens.
        reply = QMessageBox.question(
            self,
            "Quit",
            "Are you sure you want to quit? "
            "The server will be stopped.",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

    @pyqtSlot()
    def on_ssh_error(self):
        """Record that the SSH worker reported an error (suppresses start_client)."""
        self.ssh_error = True

    @pyqtSlot(str)
    def update_ip_linedt(self, ip):
        """Show the server IP reported by the SSH worker."""
        self.server_ip_linedt.setText(ip)

    @pyqtSlot(list)
    def send_msg_to_run_button(self, param_list):
        """Forward ``[message, colour, timeout]`` to button_feedback_response on the run button."""
        self.button_feedback_response(param_list[0], self.run_button, param_list[1],
                                      param_list[2])

    @pyqtSlot()
    def stop_clicked(self):
        """Kill the locally started server subprocess, if any."""
        logger.debug("Stopping server")
        if self.server_process is not None:
            self.server_process.kill()

    @pyqtSlot()
    def run_clicked(self):
        """Starts SuRVoS2 server and client as subprocesses when 'Run' button pressed.

        Raises:
            Exception: If survos.py not found.
        """
        with progress(total=3) as pbar:
            pbar.set_description("Starting server...")
            pbar.update(1)
            self.ssh_error = False  # Flag which will be set to True if there is an SSH error

            # Set current dir to survos root (three levels above this file's directory)
            command_dir = Path(os.path.abspath(
                os.path.dirname(__file__))).parent.parent.parent.resolve()
            os.chdir(command_dir)
            self.script_fullname = os.path.join(command_dir, "survos.py")
            if not os.path.isfile(self.script_fullname):
                raise Exception("{}: Script not found".format(self.script_fullname))

            # Retrieve the parameters from the fields TODO: Put some error checking in
            self.run_config["workspace_name"] = self.ws_name_linedt_2.text()
            self.run_config["server_port"] = self.server_port_linedt.text()

            # Temporary measure to check whether the workspace exists or not
            full_ws_path = os.path.join(Config["model.chroot"],
                                        self.run_config["workspace_name"])
            if not os.path.isdir(full_ws_path):
                logger.error(
                    f"No workspace can be found at {full_ws_path}, Not starting SuRVoS.")
                self.button_feedback_response(
                    f"Workspace {self.run_config['workspace_name']} does not appear to exist!",
                    self.run_button,
                    "maroon",
                )
                return
            pbar.update(1)

            if self.ssh_button.isChecked():
                self.start_server_over_ssh()
            else:
                self.server_process = subprocess.Popen([
                    "python",
                    self.script_fullname,
                    "start_server",
                    self.run_config["workspace_name"],
                    self.run_config["server_port"],
                    DataModel.g.CHROOT,
                ])
                # Give the server up to 10 s. A timeout is the normal case
                # because the server keeps running.
                try:
                    outs, errs = self.server_process.communicate(timeout=10)
                    logger.debug(f"OUTS: {outs, errs}")
                except subprocess.TimeoutExpired:
                    pass

            self._set_remote_and_refresh("127.0.0.1:" + self.server_port_linedt.text())
            pbar.update(1)

    @pyqtSlot()
    def existing_clicked(self):
        """Connect to an already running server at the IP/port given in the fields."""
        remote_ip_port = self.server_ip_linedt.text() + ":" + self.server_port_linedt.text()
        self._set_remote_and_refresh(remote_ip_port)

    def _set_remote_and_refresh(self, remote_ip_port):
        """Point the Launcher at ``remote_ip_port``, then request workspace load and refresh."""
        logger.info(f"setting remote: {remote_ip_port}")
        resp = Launcher.g.set_remote(remote_ip_port)
        logger.info(f"Response from server to setting remote: {resp}")
        cfg.ppw.clientEvent.emit({
            "source": "server_tab",
            "data": "set_workspace",
            "workspace": self.ws_name_linedt_2.text(),
        })
        cfg.ppw.clientEvent.emit({"source": "panel_gui", "data": "refresh", "value": None})

    def start_client(self):
        """Launch the GUI client subprocess unless an SSH error was recorded."""
        if not self.ssh_error:
            self.button_feedback_response("Starting Client.", self.run_button, "green", 7)
            self.run_config["server_ip"] = self.server_ip_linedt.text()
            self.client_process = subprocess.Popen([
                "python",
                self.script_fullname,
                "nu_gui",
                self.run_config["workspace_name"],
                str(self.run_config["server_ip"]) + ":" + str(self.run_config["server_port"]),
            ])
class ButtonPanelWidget(QtWidgets.QWidget):
    """Panel with workspace switching, refresh, layer transfer and slice-mode controls.

    User actions are forwarded as event dicts through ``cfg.ppw.clientEvent``.
    """

    clientEvent = Signal(object)

    def __init__(self, *args, **kwargs):
        QtWidgets.QWidget.__init__(self, *args, **kwargs)

        # Slice slider: hidden until slice mode is switched on.
        self.slice_mode = False
        self.slider = Slider(value=0, vmax=cfg.slice_max - 1)
        self.slider.setMinimumWidth(150)
        self.slider.sliderReleased.connect(self._params_updated)

        button_refresh = QPushButton("Refresh", self)
        button_refresh.clicked.connect(self.button_refresh_clicked)
        button_transfer = QPushButton("Transfer Layer")
        button_transfer.clicked.connect(self.button_transfer_clicked)
        # Same capitalisation as the label restored in button_slicemode_clicked.
        self.button_slicemode = QPushButton("Slice Mode", self)
        self.button_slicemode.clicked.connect(self.button_slicemode_clicked)

        self.workspaces_list = ComboBox()
        self._fill_workspaces_list()
        workspaces_widget = HWidgets("Switch Workspaces:", self.workspaces_list)
        self.workspaces_list.setEditable(True)
        self.workspaces_list.activated[str].connect(self.workspaces_selected)

        self.hbox_layout0 = QtWidgets.QHBoxLayout()
        hbox_layout_ws = QtWidgets.QHBoxLayout()
        hbox_layout1 = QtWidgets.QHBoxLayout()
        self.hbox_layout0.addWidget(self.slider)
        self.slider.hide()
        hbox_layout_ws.addWidget(workspaces_widget)
        hbox_layout_ws.addWidget(button_refresh)
        hbox_layout1.addWidget(button_transfer)
        hbox_layout1.addWidget(self.button_slicemode)

        vbox = VBox(self, margin=(1, 1, 1, 1), spacing=2)
        vbox.addLayout(self.hbox_layout0)
        vbox.addLayout(hbox_layout1)
        vbox.addLayout(hbox_layout_ws)

    def _fill_workspaces_list(self):
        """Add one combo entry per directory entry of ``DataModel.g.CHROOT``."""
        for name in os.listdir(DataModel.g.CHROOT):
            self.workspaces_list.addItem(key=name)

    def refresh_workspaces(self):
        """Re-list the workspaces under the current CHROOT."""
        self.workspaces_list.clear()
        self._fill_workspaces_list()
        # NOTE(review): this sets the slider's minimum *pixel width* to the
        # first dataset dimension. Confirm this was not meant to be setRange.
        self.slider.setMinimumWidth(cfg.base_dataset_shape[0])

    def workspaces_selected(self):
        """Select the chosen workspace (without re-emitting) and ask the client to switch."""
        selected_workspace = self.workspaces_list.value()
        self.workspaces_list.blockSignals(True)
        self.workspaces_list.select(selected_workspace)
        self.workspaces_list.blockSignals(False)
        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "set_workspace",
            "workspace": selected_workspace,
        })

    def sessions_selected(self):
        """Ask the client to switch to the session chosen in ``self.session_list``."""
        # NOTE(review): session_list is not created in this class; this
        # method fails unless it is assigned elsewhere.
        cfg.ppw.clientEvent.emit({
            "source": "panel_gui",
            "data": "set_session",
            "session": self.session_list.value(),
        })

    def button_setroi_clicked(self):
        """Send a ``make_roi_ws`` event with [start0, start1, start2, end0, end1, end2]."""
        # NOTE(review): roi_start/roi_end are not created in this class.
        roi_start = self.roi_start.value()
        roi_end = self.roi_end.value()
        roi = [
            roi_start[0],
            roi_start[1],
            roi_start[2],
            roi_end[0],
            roi_end[1],
            roi_end[2],
        ]
        self.refresh_workspaces()
        cfg.ppw.clientEvent.emit({"source": "panel_gui", "data": "make_roi_ws", "roi": roi})

    def button_refresh_clicked(self):
        """Re-list workspaces, empty the viewer and request a full refresh."""
        self.refresh_workspaces()
        cfg.ppw.clientEvent.emit({"source": "panel_gui", "data": "empty_viewer", "value": None})
        cfg.ppw.clientEvent.emit({"source": "panel_gui", "data": "refresh", "value": None})

    def button_pause_save_clicked(self):
        """Toggle ``cfg.pause_save`` and update the pause button label."""
        # NOTE(review): button_pause_save is not created in this class.
        if cfg.pause_save:
            self.button_pause_save.setText("Pause Saving to Server")
        else:
            self.button_pause_save.setText("Resume saving to Server")
        cfg.pause_save = not cfg.pause_save

    def button_transfer_clicked(self):
        """Request a ``transfer_layer`` action from the client."""
        cfg.ppw.clientEvent.emit({"source": "panel_gui", "data": "transfer_layer", "value": None})

    def button_slicemode_clicked(self):
        """Toggle between slice mode (slider visible) and volume mode, then notify the client."""
        if self.slice_mode:
            self.slider.hide()
            self.button_slicemode.setText("Slice Mode")
        else:
            logger.info(f"Slice mode {cfg.slice_max}")
            self.slider.vmax = cfg.slice_max - 1
            self.slider.setRange(0, cfg.slice_max - 1)
            self.slider.show()
            self.button_slicemode.setText("Volume Mode")
        self.slice_mode = not self.slice_mode
        cfg.ppw.clientEvent.emit({
            "source": "button_slicemode",
            "data": "slice_mode",
        })

    def _params_updated(self):
        """On slider release, ask the client to jump to the selected slice."""
        cfg.ppw.clientEvent.emit({
            "source": "slider",
            "data": "jump_to_slice",
            "frame": self.slider.value()
        })
class PipelinesPlugin(Plugin):
    """Plugin tab that lists pipeline cards and lets the user add new pipelines.

    ``pipeline_combo`` lists available pipeline types grouped under disabled
    category headers. Choosing an entry creates the pipeline on the server
    and adds a PipelineCard.
    """

    __icon__ = "fa.picture-o"
    __pname__ = "pipelines"
    __views__ = ["slice_viewer"]
    __tab__ = "pipelines"

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        self.pipeline_combo = ComboBox()
        self.vbox = VBox(self, spacing=4)
        self.vbox.addWidget(self.pipeline_combo)
        self.pipeline_combo.currentIndexChanged.connect(self.add_pipeline)
        # Maps pipeline id -> PipelineCard widget currently shown.
        self.existing_pipelines = dict()
        self._populate_pipelines()

    def _populate_pipelines(self):
        """Fill the combo with the pipeline types the server reports as available.

        Also rebuilds ``self.pipeline_params`` (pipeline name -> parameter spec).
        """
        self.pipeline_params = {}
        self.pipeline_combo.clear()
        self.pipeline_combo.addItem("Add segmentation")
        result = Launcher.g.run("pipelines", "available", workspace=True)

        if not result:
            # No list from the server: register default params for the
            # superregion pipeline so cards of that type can still be built.
            self.pipeline_params["superregion_segment"] = {
                "sr_params": {
                    "type": "sr2",
                }
            }
            return

        all_categories = sorted(set(p["category"] for p in result))
        for category in all_categories:
            self.pipeline_combo.addItem(category)
            # Disable the header row just added. Using the model's row count
            # stays correct even when pipeline names repeat, which would make
            # len(pipeline_params) undercount the rows.
            model = self.pipeline_combo.model()
            model.item(model.rowCount() - 1).setEnabled(False)
            for f in [p for p in result if p["category"] == category]:
                self.pipeline_params[f["name"]] = f["params"]
                self.pipeline_combo.addItem(f["name"])

    def add_pipeline(self, idx):
        """Create a pipeline of the type at combo row ``idx`` and add its card.

        Row 0 (the "Add segmentation" prompt), negative indices and empty
        entries are ignored. The combo is reset to row 0 afterwards.
        """
        if idx <= 0:
            return
        pipeline_type = self.pipeline_combo.itemText(idx)
        if pipeline_type == "":
            return
        logger.debug(f"Adding pipeline {pipeline_type}")
        self.pipeline_combo.setCurrentIndex(0)

        params = dict(pipeline_type=pipeline_type, workspace=True)
        result = Launcher.g.run("pipelines", "create", **params)
        if result:
            fid = result["id"]
            ftype = result["kind"]
            fname = result["name"]
            self._add_pipeline_widget(fid, ftype, fname, True)
            _PipelineNotifier.notify()

    def _add_pipeline_widget(self, fid, ftype, fname, expand=False):
        """Create a PipelineCard, add it to the layout, register it and return it."""
        widget = PipelineCard(fid, ftype, fname, self.pipeline_params[ftype])
        widget.showContent(expand)
        self.vbox.addWidget(widget)
        self.existing_pipelines[fid] = widget
        return widget

    def clear(self):
        """Remove every pipeline card from the view."""
        for pipeline in list(self.existing_pipelines.keys()):
            self.existing_pipelines.pop(pipeline).setParent(None)
        self.existing_pipelines = {}

    def setup(self):
        """Sync the shown pipeline cards with the pipelines in the current workspace."""
        self._populate_pipelines()
        params = dict(workspace=DataModel.g.current_session + "@" +
                      DataModel.g.current_workspace)
        result = Launcher.g.run("pipelines", "existing", **params)
        logger.debug(f"Pipeline result {result}")

        if not result:
            return

        # Remove pipelines that no longer exist in the server
        for pipeline in list(self.existing_pipelines.keys()):
            if pipeline not in result:
                self.existing_pipelines.pop(pipeline).setParent(None)

        # Populate with new pipelines if any
        for pipeline in sorted(result):
            if pipeline in self.existing_pipelines:
                continue
            params = result[pipeline]
            logger.debug(f"Pipeline params {params}")
            fid = params.pop("id", pipeline)
            ftype = params.pop("kind")
            fname = params.pop("name", pipeline)
            # _add_pipeline_widget registers the card in existing_pipelines.
            widget = self._add_pipeline_widget(fid, ftype, fname)
            widget.update_params(params)
class PipelineCard(Card): def __init__(self, fid, ftype, fname, fparams, parent=None): super().__init__(fname, removable=True, editable=True, collapsible=True, parent=parent) self.pipeline_id = fid self.pipeline_type = ftype self.pipeline_name = fname #from qtpy.QtWidgets import QProgressBar #self.pbar = QProgressBar(self) #self.add_row(self.pbar) self.params = fparams print(fparams) self.widgets = dict() if self.pipeline_type == "superregion_segment": logger.debug("Adding a superregion_segment pipeline") self._add_features_source() self._add_annotations_source() self._add_constrain_source() self._add_regions_source() self.ensembles = EnsembleWidget() self.ensembles.train_predict.connect(self.compute_pipeline) self.svm = SVMWidget() self.svm.predict.connect(self.compute_pipeline) self._add_classifier_choice() self._add_projection_choice() self._add_param("lam", type="FloatSlider", default=0.15) self._add_confidence_choice() elif self.pipeline_type == "rasterize_points": self._add_annotations_source() self._add_feature_source() self._add_objects_source() elif self.pipeline_type == "watershed": self._add_annotations_source() self._add_feature_source() elif self.pipeline_type == "predict_segmentation_fcn": self._add_annotations_source() self._add_feature_source() self._add_workflow_file() self._add_model_type() # self._add_patch_params() elif self.pipeline_type == "label_postprocess": self._add_annotations_source(label="Layer Over: ") self._add_annotations_source2(label="Layer Base: ") self.label_index = LineEdit(default=-1, parse=int) #widget = HWidgets("Selected label:", self.label_index, Spacing(35), stretch=1) #self.add_row(widget) #self.offset = LineEdit(default=-1, parse=int) #widget2 = HWidgets("Offset:", self.offset, Spacing(35), stretch=1) #self.add_row(widget2) elif self.pipeline_type == "cleaning": # self._add_objects_source() self._add_feature_source() self._add_annotations_source() elif self.pipeline_type == "train_2d_unet": self._add_annotations_source() 
self._add_feature_source() self._add_unet_2d_training_params() elif self.pipeline_type == "predict_2d_unet": self._add_annotations_source() self._add_feature_source() self._add_unet_2d_prediction_params() else: logger.debug(f"Unsupported pipeline type {self.pipeline_type}.") for pname, params in fparams.items(): if pname not in ["src", "dst"]: self._add_param(pname, **params) self._add_compute_btn() self._add_view_btn() def _add_model_type(self): self.model_type = ComboBox() self.model_type.addItem(key="unet3d") self.model_type.addItem(key="fpn3d") widget = HWidgets("Model type:", self.model_type, Spacing(35), stretch=0) self.add_row(widget) def _add_patch_params(self): self.patch_size = LineEdit3D(default=64, parse=int) self.add_row( HWidgets("Patch Size:", self.patch_size, Spacing(35), stretch=1)) def _add_unet_2d_training_params(self): self.add_row(HWidgets("Training Parameters:", Spacing(35), stretch=1)) self.cycles_frozen = LineEdit(default=8, parse=int) self.cycles_unfrozen = LineEdit(default=5, parse=int) self.add_row( HWidgets("No. Cycles Frozen:", self.cycles_frozen, "No. 
Cycles Unfrozen", self.cycles_unfrozen, stretch=1)) def _add_unet_2d_prediction_params(self): self.model_file_line_edit = LineEdit(default="Filepath", parse=str) model_input_btn = PushButton("Select Model", accent=True) model_input_btn.clicked.connect(self.get_model_path) self.radio_group = QtWidgets.QButtonGroup() self.radio_group.setExclusive(True) single_pp_rb = QRadioButton("Single plane") single_pp_rb.setChecked(True) self.radio_group.addButton(single_pp_rb, 1) triple_pp_rb = QRadioButton("Three plane") self.radio_group.addButton(triple_pp_rb, 3) self.add_row( HWidgets(self.model_file_line_edit, model_input_btn, Spacing(35))) self.add_row(HWidgets("Prediction Parameters:", Spacing(35), stretch=1)) self.add_row(HWidgets(single_pp_rb, triple_pp_rb, stretch=1)) def _add_workflow_file(self): self.filewidget = FileWidget(extensions="*.pt", save=False) self.add_row(self.filewidget) self.filewidget.path_updated.connect(self.load_data) def load_data(self, path): self.model_fullname = path print(f"Setting model fullname: {self.model_fullname}") def _add_view_btn(self): view_btn = PushButton("View", accent=True) view_btn.clicked.connect(self.view_pipeline) load_as_annotation_btn = PushButton("Load as annotation", accent=True) load_as_annotation_btn.clicked.connect(self.load_as_annotation) load_as_float_btn = PushButton("Load as image", accent=True) load_as_float_btn.clicked.connect(self.load_as_float) self.add_row( HWidgets(None, load_as_float_btn, load_as_annotation_btn, view_btn, Spacing(35))) def _add_refine_choice(self): self.refine_checkbox = CheckBox(checked=True) self.add_row( HWidgets("MRF Refinement:", self.refine_checkbox, Spacing(35), stretch=0)) def _add_confidence_choice(self): self.confidence_checkbox = CheckBox(checked=False) self.add_row( HWidgets("Confidence Map as Feature:", self.confidence_checkbox, Spacing(35), stretch=0)) def _add_objects_source(self): self.objects_source = ObjectComboBox(full=True) self.objects_source.fill() 
self.objects_source.setMaximumWidth(250) widget = HWidgets("Objects:", self.objects_source, Spacing(35), stretch=1) self.add_row(widget) def _add_classifier_choice(self): self.classifier_type = ComboBox() self.classifier_type.addItem(key="Ensemble") self.classifier_type.addItem(key="SVM") widget = HWidgets("Classifier:", self.classifier_type, Spacing(35), stretch=0) self.classifier_type.currentIndexChanged.connect( self._on_classifier_changed) self.clf_container = QtWidgets.QWidget() clf_vbox = VBox(self, spacing=4) clf_vbox.setContentsMargins(0, 0, 0, 0) self.clf_container.setLayout(clf_vbox) self.add_row(widget) self.add_row(self.clf_container, max_height=500) self.clf_container.layout().addWidget(self.ensembles) def _on_classifier_changed(self, idx): if idx == 0: self.clf_container.layout().addWidget(self.ensembles) self.svm.setParent(None) elif idx == 1: self.clf_container.layout().addWidget(self.svm) self.ensembles.setParent(None) def _add_projection_choice(self): self.projection_type = ComboBox() self.projection_type.addItem(key="None") self.projection_type.addItem(key="pca") self.projection_type.addItem(key="rbp") self.projection_type.addItem(key="rproj") self.projection_type.addItem(key="std") widget = HWidgets("Projection:", self.projection_type, Spacing(35), stretch=0) self.add_row(widget) def _add_feature_source(self): self.feature_source = FeatureComboBox() self.feature_source.fill() self.feature_source.setMaximumWidth(250) widget = HWidgets("Feature:", self.feature_source, Spacing(35), stretch=1) self.add_row(widget) def _add_features_source(self): self.features_source = MultiSourceComboBox() self.features_source.fill() self.features_source.setMaximumWidth(250) cfg.pipelines_features_source = self.features_source widget = HWidgets("Features:", self.features_source, Spacing(35), stretch=1) self.add_row(widget) def _add_constrain_source(self): print(self.annotations_source.value()) self.constrain_mask_source = AnnotationComboBox(header=(None, "None"), 
full=True) self.constrain_mask_source.fill() self.constrain_mask_source.setMaximumWidth(250) widget = HWidgets("Constrain mask:", self.constrain_mask_source, Spacing(35), stretch=1) self.add_row(widget) def _add_annotations_source(self, label="Annotation"): self.annotations_source = LevelComboBox(full=True) self.annotations_source.fill() self.annotations_source.setMaximumWidth(250) widget = HWidgets(label, self.annotations_source, Spacing(35), stretch=1) self.add_row(widget) def _add_annotations_source2(self, label="Annotation 2"): self.annotations_source2 = LevelComboBox(full=True) self.annotations_source2.fill() self.annotations_source2.setMaximumWidth(250) widget = HWidgets(label, self.annotations_source2, Spacing(35), stretch=1) self.add_row(widget) def _add_pipelines_source(self): self.pipelines_source = PipelinesComboBox() self.pipelines_source.fill() self.pipelines_source.setMaximumWidth(250) widget = HWidgets("Segmentation:", self.pipelines_source, Spacing(35), stretch=1) self.add_row(widget) def _add_regions_source(self): self.regions_source = RegionComboBox(full=True) # SourceComboBox() self.regions_source.fill() self.regions_source.setMaximumWidth(250) widget = HWidgets("Superregions:", self.regions_source, Spacing(35), stretch=1) cfg.pipelines_regions_source = self.regions_source self.add_row(widget) def _add_param(self, name, title=None, type="String", default=None): if type == "Int": p = LineEdit(default=0, parse=int) elif type == "FloatSlider": p = RealSlider(value=0.0, vmax=1, vmin=0) title = "MRF Refinement Amount:" elif type == "Float": p = LineEdit(default=0.0, parse=float) title = title elif type == "FloatOrVector": p = LineEdit3D(default=0, parse=float) elif type == "IntOrVector": p = LineEdit3D(default=0, parse=int) elif type == "SmartBoolean": p = CheckBox(checked=True) else: p = None if title is None: title = name if p: self.widgets[name] = p self.add_row(HWidgets(None, title, p, Spacing(35))) def _add_compute_btn(self): compute_btn = 
PushButton("Compute", accent=True) compute_btn.clicked.connect(self.compute_pipeline) self.add_row(HWidgets(None, compute_btn, Spacing(35))) def update_params(self, params): logger.debug(f"Pipeline update params {params}") for k, v in params.items(): if k in self.widgets: self.widgets[k].setValue(v) if "anno_id" in params: if params["anno_id"] is not None: self.annotations_source.select( os.path.join("annotations/", params["anno_id"])) if "object_id" in params: if params["object_id"] is not None: self.objects_source.select( os.path.join("objects/", params["object_id"])) if "feature_id" in params: for source in params["feature_id"]: self.feature_source.select(os.path.join("features/", source)) if "feature_ids" in params: for source in params["feature_ids"]: self.features_source.select(os.path.join("features/", source)) if "region_id" in params: if params["region_id"] is not None: self.regions_source.select( os.path.join("regions/", params["region_id"])) if "constrain_mask" in params: if (params["constrain_mask"] is not None and params["constrain_mask"] != "None"): import ast constrain_mask_dict = ast.literal_eval( params["constrain_mask"]) print(constrain_mask_dict) constrain_mask_source = (constrain_mask_dict["level"] + ":" + str(constrain_mask_dict["idx"])) print(f"Constrain mask source {constrain_mask_source}") self.constrain_mask_source.select(constrain_mask_source) def card_deleted(self): params = dict(pipeline_id=self.pipeline_id, workspace=True) result = Launcher.g.run("pipelines", "remove", **params) if result["done"]: self.setParent(None) _PipelineNotifier.notify() cfg.ppw.clientEvent.emit({ "source": "pipelines", "data": "remove_layer", "layer_name": self.pipeline_id, }) def view_pipeline(self): logger.debug(f"View pipeline_id {self.pipeline_id}") with progress(total=2) as pbar: pbar.set_description("Viewing feature") pbar.update(1) if self.annotations_source: if self.annotations_source.value(): level_id = str(self.annotations_source.value().rsplit( "/", 
1)[-1]) else: level_id = '001_level' logger.debug(f"Assigning annotation level {level_id}") cfg.ppw.clientEvent.emit({ "source": "pipelines", "data": "view_pipeline", "pipeline_id": self.pipeline_id, "level_id": level_id, }) pbar.update(1) def get_model_path(self): workspace_path = os.path.join(DataModel.g.CHROOT, DataModel.g.current_workspace) self.model_path, _ = QtWidgets.QFileDialog.getOpenFileName( self, ("Select model"), workspace_path, ("Model files (*.zip)")) self.model_file_line_edit.setValue(self.model_path) def load_as_float(self): logger.debug(f"Loading prediction {self.pipeline_id} as float image.") # get pipeline output src = DataModel.g.dataset_uri(self.pipeline_id, group="pipelines") with DatasetManager(src, out=None, dtype="uint32", fillvalue=0) as DM: src_arr = DM.sources[0][:] # create new float image params = dict(feature_type="raw", workspace=True) result = Launcher.g.run("features", "create", **params) if result: fid = result["id"] ftype = result["kind"] fname = result["name"] logger.debug( f"Created new object in workspace {fid}, {ftype}, {fname}") dst = DataModel.g.dataset_uri(fid, group="features") with DatasetManager(dst, out=dst, dtype="float32", fillvalue=0) as DM: DM.out[:] = src_arr cfg.ppw.clientEvent.emit({ "source": "workspace_gui", "data": "refresh", "value": None }) def load_as_annotation(self): logger.debug(f"Loading prediction {self.pipeline_id} as annotation.") # get pipeline output src = DataModel.g.dataset_uri(self.pipeline_id, group="pipelines") with DatasetManager(src, out=None, dtype="uint32", fillvalue=0) as DM: src_arr = DM.sources[0][:] label_values = np.unique(src_arr) # create new level params = dict(level=self.pipeline_id, workspace=True) result = Launcher.g.run("annotations", "add_level", workspace=True) # create a blank label for each unique value in the pipeline output array if result: level_id = result["id"] for v in label_values: params = dict( level=level_id, idx=int(v), name=str(v), color="#11FF11", 
workspace=True, ) label_result = Launcher.g.run("annotations", "add_label", **params) params = dict( level=str(self.annotations_source.value().rsplit("/", 1)[-1]), workspace=True, ) anno_result = Launcher.g.run("annotations", "get_levels", **params)[0] params = dict(level=str(level_id), workspace=True) level_result = Launcher.g.run("annotations", "get_levels", **params)[0] try: # set the new level color mapping to the mapping from the pipeline for v in level_result["labels"].keys(): if v in anno_result["labels"]: label_hex = anno_result["labels"][v]["color"] label = dict( idx=int(v), name=str(v), color=label_hex, ) params = dict(level=result["id"], workspace=True) label_result = Launcher.g.run("annotations", "update_label", **params, **label) except Exception as err: logger.debug(f"Exception {err}") fid = result["id"] ftype = result["kind"] fname = result["name"] logger.debug( f"Created new object in workspace {fid}, {ftype}, {fname}") # set levels array to pipeline output array dst = DataModel.g.dataset_uri(fid, group="annotations") with DatasetManager(dst, out=dst, dtype="uint32", fillvalue=0) as DM: DM.out[:] = src_arr cfg.ppw.clientEvent.emit({ "source": "workspace_gui", "data": "refresh", "value": None }) def setup_params_superregion_segment(self, dst): feature_names_list = [ n.rsplit("/", 1)[-1] for n in self.features_source.value() ] src_grp = None if self.annotations_source.currentIndex( ) == 0 else "pipelines" src = DataModel.g.dataset_uri( self.annotations_source.value().rsplit("/", 1)[-1], group="annotations", ) all_params = dict(src=src, modal=True) all_params["workspace"] = DataModel.g.current_workspace logger.info(f"Setting src to {self.annotations_source.value()} ") all_params["region_id"] = str(self.regions_source.value().rsplit( "/", 1)[-1]) all_params["feature_ids"] = feature_names_list all_params["anno_id"] = str(self.annotations_source.value().rsplit( "/", 1)[-1]) if self.constrain_mask_source.value() != None: all_params["constrain_mask"] = 
self.constrain_mask_source.value( ) # .rsplit("/", 1)[-1] else: all_params["constrain_mask"] = "None" all_params["dst"] = dst all_params["refine"] = self.widgets["refine"].value() all_params["lam"] = self.widgets["lam"].value() all_params["classifier_type"] = self.classifier_type.value() all_params["projection_type"] = self.projection_type.value() all_params["confidence"] = self.confidence_checkbox.value() if self.classifier_type.value() == "Ensemble": all_params["classifier_params"] = self.ensembles.get_params() else: all_params["classifier_params"] = self.svm.get_params() return all_params def setup_params_rasterize_points(self, dst): src = DataModel.g.dataset_uri(self.feature_source.value(), group="features") all_params = dict(src=src, modal=True) all_params["workspace"] = DataModel.g.current_workspace # all_params["anno_id"] = str( # self.annotations_source.value().rsplit("/", 1)[-1] # ) all_params["feature_id"] = self.feature_source.value() all_params["object_id"] = str(self.objects_source.value()) all_params["acwe"] = self.widgets["acwe"].value() # all_params["object_scale"] = self.widgets["object_scale"].value() # all_params["object_offset"] = self.widgets["object_offset"].value() all_params["dst"] = dst return all_params def setup_params_watershed(self, dst): src = DataModel.g.dataset_uri(self.feature_source.value(), group="features") all_params = dict(src=src, dst=dst, modal=True) all_params["workspace"] = DataModel.g.current_workspace all_params["dst"] = self.pipeline_id all_params["anno_id"] = str(self.annotations_source.value().rsplit( "/", 1)[-1]) return all_params def setup_params_predict_segmentation_fcn(self, dst): src = DataModel.g.dataset_uri(self.feature_source.value(), group="features") all_params = dict(src=src, dst=dst, modal=True) all_params["workspace"] = DataModel.g.current_workspace all_params["anno_id"] = str(self.annotations_source.value().rsplit( "/", 1)[-1]) all_params["feature_id"] = self.feature_source.value() 
all_params["model_fullname"] = self.model_fullname all_params["model_type"] = self.model_type.value() all_params["dst"] = self.pipeline_id return all_params def setup_params_label_postprocess(self, dst): all_params = dict(modal=True) all_params["workspace"] = DataModel.g.current_workspace print(self.annotations_source.value()) if (self.annotations_source.value()): all_params["level_over"] = str( self.annotations_source.value().rsplit("/", 1)[-1]) else: all_params["level_over"] = "None" all_params["level_base"] = str(self.annotations_source2.value().rsplit( "/", 1)[-1]) all_params["dst"] = dst #all_params["selected_label"] = int(self.label_index.value()) #all_params["offset"] = int(self.offset.value()) all_params["selected_label"] = int( self.widgets["selected_label"].value()) all_params["offset"] = int(self.widgets["offset"].value()) return all_params def setup_params_cleaning(self, dst): all_params = dict(dst=dst, modal=True) all_params["workspace"] = DataModel.g.current_workspace all_params["feature_id"] = str(self.feature_source.value()) # all_params["object_id"] = str(self.objects_source.value()) return all_params def setup_params_train_2d_unet(self, dst): src = DataModel.g.dataset_uri(self.feature_source.value(), group="features") all_params = dict(src=src, dst=dst, modal=True) all_params["workspace"] = DataModel.g.current_workspace all_params["feature_id"] = str(self.feature_source.value()) all_params["anno_id"] = str(self.annotations_source.value().rsplit( "/", 1)[-1]) all_params["unet_train_params"] = dict( cyc_frozen=self.cycles_frozen.value(), cyc_unfrozen=self.cycles_unfrozen.value()) return all_params def setup_params_predict_2d_unet(self, dst): src = DataModel.g.dataset_uri(self.feature_source.value(), group="features") all_params = dict(src=src, dst=dst, modal=True) all_params["workspace"] = DataModel.g.current_workspace all_params["feature_id"] = str(self.feature_source.value()) all_params["anno_id"] = str(self.annotations_source.value().rsplit( "/", 
1)[-1]) all_params["model_path"] = str(self.model_file_line_edit.value()) all_params["no_of_planes"] = self.radio_group.checkedId() return all_params def compute_pipeline(self): dst = DataModel.g.dataset_uri(self.pipeline_id, group="pipelines") with progress(total=3) as pbar: pbar.set_description("Calculating pipeline") pbar.update(1) try: if self.pipeline_type == "superregion_segment": all_params = self.setup_params_superregion_segment(dst) elif self.pipeline_type == "rasterize_points": all_params = self.setup_params_rasterize_points(dst) elif self.pipeline_type == "watershed": all_params = self.setup_params_watershed(dst) elif self.pipeline_type == "predict_segmentation_fcn": all_params = self.setup_params_predict_segmentation_fcn( dst) elif self.pipeline_type == "label_postprocess": all_params = self.setup_params_label_postprocess(dst) elif self.pipeline_type == "cleaning": all_params = self.setup_params_cleaning(dst) elif self.pipeline_type == "train_2d_unet": all_params = self.setup_params_train_2d_unet(dst) elif self.pipeline_type == "predict_2d_unet": all_params = self.setup_params_predict_2d_unet(dst) else: logger.warning( f"No action exists for pipeline: {self.pipeline_type}") all_params.update( {k: v.value() for k, v in self.widgets.items()}) logger.info( f"Computing pipelines {self.pipeline_type} {all_params}") try: pbar.update(1) result = Launcher.g.run("pipelines", self.pipeline_type, **all_params) print(result) except Exception as err: print(err) if result is not None: pbar.update(1) except Exception as e: print(e) def card_title_edited(self, newtitle): params = dict(pipeline_id=self.pipeline_id, new_name=newtitle, workspace=True) result = Launcher.g.run("pipelines", "rename", **params) if result["done"]: _PipelineNotifier.notify() return result["done"]
class EnsembleWidget(QtWidgets.QWidget):
    """Parameter panel for the tree-ensemble classifiers.

    Holds a combo box selecting the ensemble flavour plus line edits for the
    number of trees, maximum depth, learning rate, subsample fraction and the
    number of parallel jobs. ``get_params`` returns the current values as an
    ``"clf": "ensemble"`` parameter dict.
    """

    train_predict = Signal(dict)

    # Short type keys, in the order the ensemble items are added to
    # ``type_combo`` in ``__init__``. "xgb" currently has no combo item (the
    # XGBoost entry is commented out), so it cannot be selected.
    _ENSEMBLE_TYPES = ("rf", "erf", "ada", "gbf", "xgb")
    # Positions in _ENSEMBLE_TYPES whose suggested defaults differ.
    _ADABOOST = 2
    _GRADIENT_BOOSTING = 3

    def __init__(self, parent=None):
        super(EnsembleWidget, self).__init__(parent=parent)
        vbox = QtWidgets.QVBoxLayout()
        vbox.setContentsMargins(0, 0, 0, 0)
        self.setLayout(vbox)

        self.type_combo = ComboBox()
        self.type_combo.addCategory("Ensemble Type:")
        self.type_combo.addItem("Random Forest")
        self.type_combo.addItem("ExtraRandom Forest")
        self.type_combo.addItem("AdaBoost")
        self.type_combo.addItem("GradientBoosting")
        # self.type_combo.addItem("XGBoost")
        self.type_combo.currentIndexChanged.connect(self.on_ensemble_changed)
        vbox.addWidget(self.type_combo)

        self.ntrees = LineEdit(default=100, parse=int)
        self.depth = LineEdit(default=15, parse=int)
        self.lrate = LineEdit(default=1.0, parse=float)
        self.subsample = LineEdit(default=1.0, parse=float)

        vbox.addWidget(
            HWidgets(
                QtWidgets.QLabel("# Trees:"),
                self.ntrees,
                QtWidgets.QLabel("Max Depth:"),
                self.depth,
                stretch=[0, 1, 0, 1],
            ))

        vbox.addWidget(
            HWidgets(
                QtWidgets.QLabel("Learn Rate:"),
                self.lrate,
                QtWidgets.QLabel("Subsample:"),
                self.subsample,
                stretch=[0, 1, 0, 1],
            ))

        # self.btn_train_predict = PushButton('Train & Predict')
        # self.btn_train_predict.clicked.connect(self.on_train_predict_clicked)
        self.n_jobs = LineEdit(default=10, parse=int)
        vbox.addWidget(HWidgets("Num Jobs", self.n_jobs))

    @staticmethod
    def _type_index(combo_index):
        """Map a raw ``type_combo`` row to an index into ``_ENSEMBLE_TYPES``.

        Row 0 of the combo is the "Ensemble Type:" category added first, so
        the ensemble items start at row 1. A row that is not an ensemble item
        (the category row, or -1 when nothing is selected) falls back to the
        first ensemble ("rf") instead of wrapping to the end of the list.
        """
        return max(combo_index - 1, 0)

    def on_ensemble_changed(self, idx):
        """Reset the suggested defaults when the ensemble type changes.

        Args:
            idx: raw ``type_combo`` row from ``currentIndexChanged``. It is
                shifted past the category row exactly as ``get_params`` does,
                so AdaBoost and GradientBoosting get their own defaults.
        """
        type_idx = self._type_index(idx)
        self.ntrees.setDefault(50 if type_idx == self._ADABOOST else 100)
        if type_idx == self._GRADIENT_BOOSTING:
            self.lrate.setDefault(0.1)
            self.depth.setDefault(3)
        else:
            self.lrate.setDefault(1.0)
            self.depth.setDefault(15)

    def on_train_predict_clicked(self):
        """Emit ``train_predict`` with the current parameters (see ``get_params``)."""
        self.train_predict.emit(self.get_params())

    def get_params(self):
        """Return the ensemble classifier parameters as a dict.

        Returns:
            dict with keys ``clf`` (always "ensemble"), ``type`` (one of
            ``_ENSEMBLE_TYPES``), ``n_estimators``, ``max_depth``,
            ``learning_rate``, ``subsample`` and ``n_jobs``, read from the
            corresponding line edits.
        """
        current_index = self._type_index(self.type_combo.currentIndex())
        logger.debug(f"Ensemble type_combo index: {current_index}")
        return {
            "clf": "ensemble",
            "type": self._ENSEMBLE_TYPES[current_index],
            "n_estimators": self.ntrees.value(),
            "max_depth": self.depth.value(),
            "learning_rate": self.lrate.value(),
            "subsample": self.subsample.value(),
            "n_jobs": self.n_jobs.value(),
        }
class FrontEndRunner(QWidget):
    """Main FrontEnd Runner window for creating workspace and starting SuRVoS2."""

    def __init__(self, run_config, workspace_config, pipeline_config, *args, **kwargs):
        super().__init__()

        self.run_config = run_config
        self.workspace_config = workspace_config
        self.pipeline_config = pipeline_config

        self.server_process = None
        self.client_process = None
        # Only launch_data_loader fills these in; initialise them so that
        # create_workspace_clicked / start_client never read unset attributes.
        self.roi_limits = None
        self.ssh_error = False

        pipeline_config_ptree = self.init_ptree(self.pipeline_config, name="Pipeline")

        self.layout = QVBoxLayout()
        tabwidget = QTabWidget()
        tab1 = QWidget()
        tab2 = QWidget()

        tabwidget.addTab(tab1, "Setup and Start Survos")
        tabwidget.addTab(tab2, "Pipeline")

        self.create_workspace_button = QPushButton("Create workspace")

        tab1.layout = QVBoxLayout()
        tab1.setLayout(tab1.layout)

        # NOTE(review): get_chroot_fields is not among the methods shown for
        # this class; confirm it is defined elsewhere in the class.
        chroot_fields = self.get_chroot_fields()
        tab1.layout.addWidget(chroot_fields)

        self.setup_adv_run_fields()
        self.adv_run_fields.hide()

        run_fields = self.get_run_fields()
        tab1.layout.addWidget(run_fields)

        output_config_button = QPushButton("Save config")

        tab2.layout = QVBoxLayout()
        tab2.setLayout(tab2.layout)
        tab2.layout.addWidget(pipeline_config_ptree)
        tab2.layout.addWidget(output_config_button)

        self.create_workspace_button.clicked.connect(self.create_workspace_clicked)
        output_config_button.clicked.connect(self.output_config_clicked)

        self.layout.addWidget(tabwidget)

        self.setGeometry(300, 300, 600, 400)
        self.setWindowTitle("SuRVoS Settings Editor")
        current_fpth = os.path.dirname(os.path.abspath(__file__))
        self.setWindowIcon(QIcon(os.path.join(current_fpth, "resources", "logo.png")))
        self.setLayout(self.layout)
        self.show()

    def get_workspace_fields(self):
        """Gets the QGroupBox that contains all the fields for setting up the workspace.

        Returns:
            PyQt5.QWidgets.GroupBox: GroupBox with workspace fields.
        """
        select_data_button = QPushButton("Select")
        workspace_fields = QGroupBox("Create Workspace:")
        wf_layout = QGridLayout()
        wf_layout.addWidget(QLabel("Data File Path:"), 0, 0)
        current_data_path = Path(
            self.workspace_config["datasets_dir"], self.workspace_config["vol_fname"]
        )
        self.data_filepth_linedt = QLineEdit(str(current_data_path))
        wf_layout.addWidget(self.data_filepth_linedt, 1, 0, 1, 2)
        wf_layout.addWidget(select_data_button, 1, 2)
        wf_layout.addWidget(QLabel("HDF5 Internal Data Path:"), 2, 0, 1, 1)
        # Show the internal HDF5 path with exactly one leading slash.
        ws_dataset_name = str(self.workspace_config["dataset_name"])
        internal_h5_path = (
            ws_dataset_name if ws_dataset_name.startswith("/") else "/" + ws_dataset_name
        )
        self.h5_intpth_linedt = QLineEdit(internal_h5_path)
        wf_layout.addWidget(self.h5_intpth_linedt, 2, 1, 1, 1)
        wf_layout.addWidget(QLabel("Workspace Name:"), 3, 0)
        self.ws_name_linedt_1 = QLineEdit(self.workspace_config["workspace_name"])
        wf_layout.addWidget(self.ws_name_linedt_1, 3, 1)
        wf_layout.addWidget(QLabel("Downsample Factor:"), 4, 0)
        self.downsample_spinner = QSpinBox()
        self.downsample_spinner.setRange(1, 10)
        # Qt shows the special value text when the spinner sits at its minimum (1).
        self.downsample_spinner.setSpecialValueText("None")
        self.downsample_spinner.setMaximumWidth(60)
        self.downsample_spinner.setValue(int(self.workspace_config["downsample_by"]))
        wf_layout.addWidget(self.downsample_spinner, 4, 1, 1, 1)
        # ROI
        self.setup_roi_fields()
        wf_layout.addWidget(self.roi_fields, 4, 2, 1, 2)
        self.roi_fields.hide()

        wf_layout.addWidget(self.create_workspace_button, 5, 0, 1, 3)
        workspace_fields.setLayout(wf_layout)
        select_data_button.clicked.connect(self.launch_data_loader)
        return workspace_fields

    def setup_roi_fields(self):
        """Sets up the QGroupBox that displays the ROI dimensions, if selected."""
        self.roi_fields = QGroupBox("ROI:")
        roi_fields_layout = QHBoxLayout()
        self.zstart_roi_val, self.zend_roi_val = QLabel("0"), QLabel("0")
        self.ystart_roi_val, self.yend_roi_val = QLabel("0"), QLabel("0")
        self.xstart_roi_val, self.xend_roi_val = QLabel("0"), QLabel("0")
        # QHBoxLayout.addWidget's second positional argument is a stretch
        # factor, not a position; widgets are simply appended in order.
        for axis, start_val, end_val in (
            ("z", self.zstart_roi_val, self.zend_roi_val),
            ("y", self.ystart_roi_val, self.yend_roi_val),
            ("x", self.xstart_roi_val, self.xend_roi_val),
        ):
            roi_fields_layout.addWidget(QLabel(f"{axis}:"))
            roi_fields_layout.addWidget(start_val)
            roi_fields_layout.addWidget(QLabel("-"))
            roi_fields_layout.addWidget(end_val)
        self.roi_fields.setLayout(roi_fields_layout)

    def setup_adv_run_fields(self):
        """Sets up the QGroupBox that displays the advanced options for starting SuRVoS2."""
        self.adv_run_fields = QGroupBox("Advanced Run Settings:")
        adv_run_layout = QGridLayout()
        adv_run_layout.addWidget(QLabel("Server IP Address:"), 0, 0)
        self.server_ip_linedt = QLineEdit(self.run_config["server_ip"])
        adv_run_layout.addWidget(self.server_ip_linedt, 0, 1)
        adv_run_layout.addWidget(QLabel("Server Port:"), 1, 0)
        self.server_port_linedt = QLineEdit(self.run_config["server_port"])
        adv_run_layout.addWidget(self.server_port_linedt, 1, 1)
        # SSH Info
        self.ssh_button = QRadioButton("Use SSH")
        self.ssh_button.setAutoExclusive(False)
        adv_run_layout.addWidget(self.ssh_button, 0, 2)
        ssh_flag = self.run_config.get("use_ssh", False)
        if ssh_flag:
            self.ssh_button.setChecked(True)
        self.ssh_button.toggled.connect(self.toggle_ssh)

        self.adv_ssh_fields = QGroupBox("SSH Settings:")
        adv_ssh_layout = QGridLayout()
        adv_ssh_layout.setColumnStretch(2, 2)
        ssh_host_label = QLabel("Host")
        self.ssh_host_linedt = QLineEdit(self.run_config.get("ssh_host", ""))
        adv_ssh_layout.addWidget(ssh_host_label, 0, 0)
        adv_ssh_layout.addWidget(self.ssh_host_linedt, 0, 1, 1, 2)
        ssh_user_label = QLabel("Username")
        self.ssh_username_linedt = QLineEdit(self.get_login_username())
        adv_ssh_layout.addWidget(ssh_user_label, 1, 0)
        adv_ssh_layout.addWidget(self.ssh_username_linedt, 1, 1, 1, 2)
        ssh_port_label = QLabel("Port")
        self.ssh_port_linedt = QLineEdit(self.run_config.get("ssh_port", ""))
        adv_ssh_layout.addWidget(ssh_port_label, 2, 0)
        adv_ssh_layout.addWidget(self.ssh_port_linedt, 2, 1, 1, 2)
        self.adv_ssh_fields.setLayout(adv_ssh_layout)
        adv_run_layout.addWidget(self.adv_ssh_fields, 1, 2, 2, 5)
        self.adv_run_fields.setLayout(adv_run_layout)

    def get_run_fields(self):
        """Gets the QGroupBox that contains the fields for starting SuRVoS.

        Returns:
            PyQt5.QWidgets.GroupBox: GroupBox with run fields.
        """
        self.run_button = QPushButton("Run SuRVoS")
        advanced_button = QRadioButton("Advanced")

        run_fields = QGroupBox("Run SuRVoS:")
        run_layout = QGridLayout()

        run_layout.addWidget(QLabel("Workspace Name:"), 0, 0)
        workspaces = os.listdir(CHROOT)
        self.workspaces_list = ComboBox()
        for s in workspaces:
            self.workspaces_list.addItem(key=s)
        # Column 1: column 0 holds the "Workspace Name:" label.
        run_layout.addWidget(self.workspaces_list, 0, 1)
        self.ws_name_linedt_2 = QLineEdit(self.workspace_config["workspace_name"])
        self.ws_name_linedt_2.setAlignment(Qt.AlignLeft)
        run_layout.addWidget(self.ws_name_linedt_2, 0, 2)
        run_layout.addWidget(advanced_button, 1, 0)
        run_layout.addWidget(self.adv_run_fields, 2, 1)
        run_layout.addWidget(self.run_button, 3, 0, 1, 3)
        run_fields.setLayout(run_layout)

        advanced_button.toggled.connect(self.toggle_advanced)
        self.run_button.clicked.connect(self.run_clicked)

        return run_fields

    def get_login_username(self):
        """Return the current login name, or "" if it cannot be determined."""
        try:
            user = getpass.getuser()
        except Exception:
            user = ""
        return user

    @pyqtSlot()
    def launch_data_loader(self):
        """Load the dialog box widget to select data with data preview window and ROI selection."""
        path = None
        int_h5_pth = None
        dialog = LoadDataDialog(self)
        result = dialog.exec_()
        self.roi_limits = None
        if result == QDialog.Accepted:
            path = dialog.winput.path.text()
            int_h5_pth = dialog.int_h5_pth.text()
            down_factor = dialog.downsample_spinner.value()
        if path and int_h5_pth:
            self.data_filepth_linedt.setText(path)
            self.h5_intpth_linedt.setText(int_h5_pth)
            self.downsample_spinner.setValue(down_factor)
            if dialog.roi_changed:
                self.roi_limits = tuple(map(str, dialog.get_roi_limits()))
                self.roi_fields.show()
                self.update_roi_fields_from_dialog()
            else:
                self.roi_fields.hide()

    def update_roi_fields_from_dialog(self):
        """Updates the ROI fields in the main window."""
        x_start, x_end, y_start, y_end, z_start, z_end = self.roi_limits
        self.xstart_roi_val.setText(x_start)
        self.xend_roi_val.setText(x_end)
        self.ystart_roi_val.setText(y_start)
        self.yend_roi_val.setText(y_end)
        self.zstart_roi_val.setText(z_start)
        self.zend_roi_val.setText(z_end)

    @pyqtSlot()
    def toggle_advanced(self):
        """Controls displaying/hiding the advanced run fields on radio button toggle."""
        rbutton = self.sender()
        if rbutton.isChecked():
            self.adv_run_fields.show()
        else:
            self.adv_run_fields.hide()

    @pyqtSlot()
    def toggle_ssh(self):
        """Controls displaying/hiding the SSH fields on radio button toggle."""
        rbutton = self.sender()
        if rbutton.isChecked():
            self.adv_ssh_fields.show()
        else:
            self.adv_ssh_fields.hide()

    def setup_ptree_params(self, p, config_dict):
        """Write parameter-tree edits under ``p`` back into ``config_dict``.

        Each change is stored under the last element of the changed
        parameter's path below ``p``. If no path is available, the parameter's
        own name is used instead.

        Returns:
            The same ``p``, with its change signals connected.
        """

        def parameter_tree_change(param, changes):
            for child_param, change, data in changes:
                path = p.childPath(child_param)
                # childPath returns None for parameters outside p; don't index it.
                key = path[-1] if path else child_param.name()
                config_dict[key] = data

        p.sigTreeStateChanged.connect(parameter_tree_change)

        def valueChanging(param, value):
            pass

        for child in p.children():
            child.sigValueChanging.connect(valueChanging)
            for ch2 in child.children():
                ch2.sigValueChanging.connect(valueChanging)
        return p

    def dict_to_params(self, param_dict, name="Group"):
        """Convert a (nested) config dict into pyqtgraph parameter definitions.

        Strings, ints, floats, bools and lists become leaf parameters. Nested
        dicts become sub-groups named ``<key>_<n>``. Entries of any other type
        are logged and skipped.

        Returns:
            A one-element list holding a ``"group"`` definition named ``name``.
        """
        ptree_param_dicts = []
        ctr = 0
        for key, entry in param_dict.items():
            # bool is tested before int because bool is a subclass of int.
            if isinstance(entry, bool):
                d = {"name": key, "type": "bool", "value": entry}
            elif isinstance(entry, str):
                d = {"name": key, "type": "str", "value": entry}
            elif isinstance(entry, int):
                d = {"name": key, "type": "int", "value": entry}
            elif isinstance(entry, float):
                d = {"name": key, "type": "float", "value": entry}
            elif isinstance(entry, list):
                d = {"name": key, "type": "list", "values": entry}
            elif isinstance(entry, dict):
                d = self.dict_to_params(entry, name="Subgroup")[0]
                d["name"] = f"{key}_{ctr}"
                ctr += 1
            else:
                # Skip rather than re-append the previous entry's definition.
                logger.warning(f"Can't handle type {type(entry)}")
                continue
            ptree_param_dicts.append(d)

        ptree_init = [{"name": name, "type": "group", "children": ptree_param_dicts}]
        return ptree_init

    def init_ptree(self, config_dict, name="Group"):
        """Build a ParameterTree widget that edits ``config_dict`` in place."""
        ptree_init = self.dict_to_params(config_dict, name)
        parameters = Parameter.create(
            name="ptree_init", type="group", children=ptree_init
        )
        params = self.setup_ptree_params(parameters, config_dict)
        ptree = ParameterTree()
        ptree.setParameters(params, showTop=False)
        return ptree

    @pyqtSlot()
    def create_workspace_clicked(self):
        """Performs checks and coordinates workspace creation on button press."""
        logger.debug("Creating workspace: ")
        # Set the path to the data file
        vol_path = Path(self.data_filepth_linedt.text())
        if not vol_path.is_file():
            err_str = f"No data file exists at {vol_path}!"
            logger.error(err_str)
            self.button_feedback_response(
                err_str, self.create_workspace_button, "maroon"
            )
        else:
            self.workspace_config["datasets_dir"] = str(vol_path.parent)
            self.workspace_config["vol_fname"] = str(vol_path.name)
            dataset_name = self.h5_intpth_linedt.text()
            self.workspace_config["dataset_name"] = str(dataset_name).strip("/")
            # Set the workspace name
            ws_name = self.ws_name_linedt_1.text()
            self.workspace_config["workspace_name"] = ws_name
            # Set the downsample factor
            ds_factor = self.downsample_spinner.value()
            self.workspace_config["downsample_by"] = ds_factor
            # Set the ROI limits if they exist
            if self.roi_limits:
                self.workspace_config["roi_limits"] = self.roi_limits
            try:
                response = init_ws(self.workspace_config)
                _, error = response
                if not error:
                    self.button_feedback_response(
                        "Workspace created sucessfully",
                        self.create_workspace_button,
                        "green",
                    )
                    # Update the workspace name in the 'Run' section
                    self.ws_name_linedt_2.setText(self.ws_name_linedt_1.text())
                else:
                    # Previously a reported error produced no feedback at all.
                    err_str = f"Workspace creation failed: {error}"
                    logger.error(err_str)
                    self.button_feedback_response(
                        err_str, self.create_workspace_button, "maroon"
                    )
            except WorkspaceException as e:
                logger.exception(e)
                self.button_feedback_response(
                    str(e), self.create_workspace_button, "maroon"
                )

    def button_feedback_response(self, message, button, colour_str, timeout=2):
        """Changes button colour and displays feedback message for a limited time period.

        Args:
            message (str): Message to display in button.
            button (PyQt5.QWidgets.QPushButton): The button to manipulate.
            colour_str (str): The standard CSS colour string or hex code describing the colour to change the button to.
            timeout (int | float): Seconds before the button is restored.
        """
        msg_old = button.text()
        style_old = button.styleSheet()
        col_old = button.palette().button().color
        txt_col_old = button.palette().buttonText().color
        button.setText(message)
        button.setStyleSheet(f"background-color: {colour_str}; color: white")
        # singleShot is static and needs an integer number of milliseconds.
        QTimer.singleShot(
            int(timeout * 1000),
            lambda: self.reset_button(
                button, msg_old, col_old, txt_col_old, style_old=style_old
            ),
        )

    @pyqtSlot()
    def reset_button(self, button, msg_old, col_old, txt_col_old, style_old=None):
        """Sets a button back to its original display settings.

        Args:
            button (PyQt5.QWidgets.QPushButton): The button to manipulate.
            msg_old (str): Message to display in button.
            col_old (callable): Returns the QColor to restore as the background.
            txt_col_old (callable): Returns the QColor to restore as the text colour.
            style_old (str, optional): Stylesheet to restore verbatim. When given, it
                takes precedence over ``col_old``/``txt_col_old``.
        """
        if style_old is not None:
            button.setStyleSheet(style_old)
        else:
            # One call: a second setStyleSheet would replace the first.
            button.setStyleSheet(
                f"background-color: {col_old().name()}; color: {txt_col_old().name()}"
            )
        button.setText(msg_old)
        button.update()

    @pyqtSlot()
    def output_config_clicked(self):
        """Outputs pipeline config YAML file on button click."""
        out_fname = "pipeline_cfg.yml"
        logger.debug(f"Outputting pipeline config: {out_fname}")
        with open(out_fname, "w") as outfile:
            yaml.dump(
                self.pipeline_config, outfile, default_flow_style=False, sort_keys=False
            )

    def get_ssh_params(self):
        """Return ``(host, user, port)`` from the SSH fields.

        The port is None when its field is not a valid integer.
        """
        ssh_host = self.ssh_host_linedt.text()
        ssh_user = self.ssh_username_linedt.text()
        try:
            ssh_port = int(self.ssh_port_linedt.text())
        except ValueError:
            ssh_port = None
        return ssh_host, ssh_user, ssh_port

    def start_server_over_ssh(self):
        """Ask for the SSH password and start the server in an SSHWorker thread."""
        params = self.get_ssh_params()
        if not all(params):
            err_str = "Not all SSH parameters given! Not connecting to SSH."
            logger.error(err_str)
            self.button_feedback_response(err_str, self.run_button, "maroon")
            return
        ssh_host, ssh_user, ssh_port = params
        # Pop up dialog to ask for password
        text, ok = QInputDialog.getText(
            None, "Login", f"Password for {ssh_user}@{ssh_host}", QLineEdit.Password
        )
        if ok and text:
            self.ssh_worker = SSHWorker(params, text, self.run_config)
            self.ssh_thread = QThread(self)
            self.ssh_worker.moveToThread(self.ssh_thread)
            self.ssh_worker.button_message_signal.connect(self.send_msg_to_run_button)
            self.ssh_worker.error_signal.connect(self.on_ssh_error)
            self.ssh_worker.finished.connect(self.start_client)
            self.ssh_worker.update_ip_linedt_signal.connect(self.update_ip_linedt)
            self.ssh_thread.started.connect(self.ssh_worker.start_server_over_ssh)
            self.ssh_thread.start()

    def closeEvent(self, event):
        """Confirm quitting. On confirmation, stop a locally started server that is still running."""
        reply = QMessageBox.question(
            self,
            "Quit",
            "Are you sure you want to quit? " "The server will be stopped.",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No,
        )
        if reply == QMessageBox.Yes:
            # The dialog promises the server is stopped; poll() is None while it runs.
            if self.server_process is not None and self.server_process.poll() is None:
                self.server_process.terminate()
            event.accept()
        else:
            event.ignore()

    @pyqtSlot()
    def on_ssh_error(self):
        self.ssh_error = True

    @pyqtSlot(str)
    def update_ip_linedt(self, ip):
        self.server_ip_linedt.setText(ip)

    @pyqtSlot(list)
    def send_msg_to_run_button(self, param_list):
        """Relay ``[message, colour, timeout]`` from the SSH worker to the run button."""
        self.button_feedback_response(
            param_list[0], self.run_button, param_list[1], param_list[2]
        )

    @pyqtSlot()
    def run_clicked(self):
        """Starts SuRVoS2 server and client as subprocesses when 'Run' button pressed.

        Raises:
            Exception: If survos.py not found.
        """
        self.ssh_error = (
            False  # Flag which will be set to True if there is an SSH error
        )
        command_dir = os.getcwd()
        self.script_fullname = os.path.join(command_dir, "survos.py")
        if not os.path.isfile(self.script_fullname):
            raise Exception("{}: Script not found".format(self.script_fullname))
        # Retrieve the parameters from the fields TODO: Put some error checking in
        self.run_config["workspace_name"] = self.ws_name_linedt_2.text()
        self.run_config["server_port"] = self.server_port_linedt.text()
        # Temporary measure to check whether the workspace exists or not
        full_ws_path = os.path.join(
            Config["model.chroot"], self.run_config["workspace_name"]
        )
        if not os.path.isdir(full_ws_path):
            logger.error(
                f"No workspace can be found at {full_ws_path}, Not starting SuRVoS."
            )
            self.button_feedback_response(
                f"Workspace {self.run_config['workspace_name']} does not appear to exist!",
                self.run_button,
                "maroon",
            )
            return
        # Try some fancy SSH stuff here
        if self.ssh_button.isChecked():
            self.start_server_over_ssh()
        else:
            self.server_process = subprocess.Popen(
                [
                    "python",
                    self.script_fullname,
                    "start_server",
                    self.run_config["workspace_name"],
                    self.run_config["server_port"],
                ]
            )
            # Wait up to 10 s so a server that exits immediately is noticed;
            # a server that keeps running hits the timeout, which is expected.
            try:
                outs, errs = self.server_process.communicate(timeout=10)
                logger.debug(f"OUTS: {outs, errs}")
            except subprocess.TimeoutExpired:
                pass

            self.start_client()

    def start_client(self):
        """Launch the GUI client subprocess unless an SSH error was flagged."""
        if not self.ssh_error:
            self.button_feedback_response(
                "Starting Client.", self.run_button, "green", 7
            )
            self.run_config["server_ip"] = self.server_ip_linedt.text()
            self.client_process = subprocess.Popen(
                [
                    "python",
                    self.script_fullname,
                    "nu_gui",
                    self.run_config["workspace_name"],
                    str(self.run_config["server_ip"])
                    + ":"
                    + str(self.run_config["server_port"]),
                ]
            )