def add_dataset_file(self, dataset_file):
    """Append one row describing ``dataset_file`` to ``self.dataset_file_table``.

    The new row contains, in columns 0-3: the file's base name, the
    display name of its dataset format (via ``Export.config('formats')``),
    a ``key=value, ...`` summary of its sample counts and the normalized
    file path.

    Args:
        dataset_file: Path of a dataset config file understood by
            ``Export.read_dataset_config``.
    """
    config = Export.read_dataset_config(dataset_file)
    samples = config.samples
    samples_text = ', '.join(
        '{}={}'.format(name, samples[name]) for name in samples)
    cell_texts = (
        os.path.basename(dataset_file),
        Export.config('formats')[config.format],
        samples_text,
        os.path.normpath(dataset_file),
    )

    row = self.dataset_file_table.rowCount()
    self.dataset_file_table.setRowCount(row + 1)
    for column, text in enumerate(cell_texts):
        cell = QtWidgets.QTableWidgetItem(text)
        # Flags are replaced by ItemIsEnabled alone, clearing the default
        # editable/selectable flags of the cell.
        cell.setFlags(Qt.ItemIsEnabled)
        self.dataset_file_table.setItem(row, column, cell)
def dataset_folder_browse_btn_clicked(self, mode='train'):
    """Ask the user for a dataset file or folder and check its format.

    If the selected format has a configured extension, a file dialog
    filtered on that extension is shown; otherwise a directory dialog is
    shown. Both start in the project's dataset folder.

    Args:
        mode: Not used by this method.

    Returns:
        The normalized selected path when ``Export.detectDatasetFormat``
        recognizes it. Returns ``False`` when the dialog is cancelled or
        the format cannot be detected (a warning box is shown in that case).
    """
    fmt = self.selected_format
    extension = Export.config('extensions')[fmt]
    format_name = Export.config('formats')[fmt]
    # The extensions config uses False for "no file extension / folder".
    ext_filter = False
    if extension != False:
        ext_filter = '{} {}({})'.format(format_name, _('files'), extension)

    project_folder = self.parent.settings.value('settings/project/folder', '')
    logger.debug(
        'Restored value "{}" for setting settings/project/folder'.format(
            project_folder))
    start_dir = os.path.join(
        project_folder, self.parent._config['project_dataset_folder'])

    if ext_filter:
        selection, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
            self, _('Select dataset file'), start_dir, ext_filter)
    else:
        selection = QtWidgets.QFileDialog.getExistingDirectory(
            self, _('Select dataset folder'), start_dir)

    if not selection:
        return False

    selection = os.path.normpath(selection)
    key = Export.detectDatasetFormat(selection)
    logger.debug('Detected dataset format {} for directory {}'.format(
        key, selection))
    if key is not None:
        return selection

    box = QtWidgets.QMessageBox()
    box.warning(self, _('Training'),
                _('Could not detect format of selected dataset'))
    return False
def __init__(self, parent=None):
    """Build the non-modal 'Import dataset' dialog.

    Top to bottom: a format combo box (first configured format
    preselected), a read-only dataset folder field with a Browse button,
    a read-only output folder field with a Browse button, and
    Import/Cancel buttons. ``self.selected_format`` starts as the first
    key of ``Export.config('formats')`` and follows the combo box through
    ``self.on_format_change``.
    """
    super().__init__(parent)
    self.setWindowTitle(_('Import dataset'))
    self.set_default_window_flags(self)
    self.setWindowModality(Qt.NonModal)

    layout = QtWidgets.QVBoxLayout()
    self.setLayout(layout)

    all_formats = Export.config('formats')
    self.formats = QtWidgets.QComboBox()
    for display_name in all_formats.values():
        self.formats.addItem(display_name)
    self.formats.setCurrentIndex(0)
    self.formats.currentTextChanged.connect(self.on_format_change)
    self.selected_format = list(all_formats.keys())[0]

    format_group = QtWidgets.QGroupBox()
    format_group.setTitle(_('Format'))
    format_group_layout = QtWidgets.QVBoxLayout()
    format_group.setLayout(format_group_layout)
    format_group_layout.addWidget(self.formats)
    layout.addWidget(format_group)

    def add_folder_row(title, on_browse):
        # Titled group holding a read-only path field and a Browse button;
        # returns the path field.
        path_edit = QtWidgets.QLineEdit()
        path_edit.setReadOnly(True)
        browse_btn = QtWidgets.QPushButton(_('Browse'))
        browse_btn.clicked.connect(on_browse)
        group = QtWidgets.QGroupBox()
        group.setTitle(title)
        group_layout = QtWidgets.QHBoxLayout()
        group.setLayout(group_layout)
        group_layout.addWidget(path_edit)
        group_layout.addWidget(browse_btn)
        layout.addWidget(group)
        return path_edit

    self.data_folder = add_folder_row(
        _('Dataset folder'), self.data_browse_btn_clicked)
    self.output_folder = add_folder_row(
        _('Output folder'), self.output_browse_btn_clicked)

    layout.addStretch()

    button_box = QtWidgets.QDialogButtonBox()
    import_btn = button_box.addButton(
        _('Import'), QtWidgets.QDialogButtonBox.AcceptRole)
    import_btn.clicked.connect(self.import_btn_clicked)
    cancel_btn = button_box.addButton(
        _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
    cancel_btn.clicked.connect(self.cancel_btn_clicked)
    layout.addWidget(button_box)
def on_format_change(self, value):
    """Map the combo box text ``value`` back to a format key.

    ``value`` is a display name from ``Export.config('formats')``. When it
    is known, ``self.selected_format`` is set to the matching key;
    otherwise the selection is left unchanged and only a debug message is
    logged.
    """
    key_by_name = Export.invertDict(Export.config('formats'))
    if value not in key_by_name:
        logger.debug('Dataset format not found: {}'.format(value))
        return
    self.selected_format = key_by_name[value]
    logger.debug('Selected dataset format: {}'.format(self.selected_format))
def export_before_training(self):
    """Export the currently open labelled folder as a dataset, then train.

    The dataset format comes from the 'training_defaults' config. The data
    folder is ``self.parent.lastOpenDir`` and the labels are all entries
    of ``self.parent.uniqLabelList``. A warning is shown (and nothing is
    exported) when no folder is open or no labels exist.

    The dataset is written to
    ``<project folder>/<project_dataset_folder>/<dataset name>``. That
    location, the format and the validation ratio are remembered in
    ``self.dataset_export_data``. The export runs in a thread via
    ``ExportExecutor``, with ``self.start_training`` as the follow-up
    callback.
    """
    training_defaults = self.parent._config['training_defaults']
    selected_format = training_defaults['dataset_format']
    dataset_name = replace_special_chars(self.dataset_name.text())

    data_folder = self.parent.lastOpenDir
    if data_folder is None:
        box = QtWidgets.QMessageBox()
        box.warning(self, _('Training'),
                    _('Please open a folder with images first'))
        return

    label_list = self.parent.uniqLabelList
    all_labels = [label_list.item(i).text() for i in range(len(label_list))]
    if not all_labels:
        box = QtWidgets.QMessageBox()
        box.warning(self, _('Training'), _('No labels found in dataset'))
        return

    project_folder = self.parent.settings.value('settings/project/folder', '')
    project_dataset_folder = self.parent._config['project_dataset_folder']
    logger.debug(
        'Restored value "{}" for setting settings/project/folder'.format(
            project_folder))
    export_folder = os.path.join(project_folder, project_dataset_folder)
    validation_ratio = int(self.validation.text()) / 100.0

    self.dataset_export_data = {
        'dataset_name': dataset_name,
        'format': selected_format,
        'output_folder': os.path.join(export_folder, dataset_name),
        'validation_ratio': validation_ratio,
    }
    self.selected_format = selected_format

    data = {
        'data_folder': data_folder,
        'export_folder': export_folder,
        'selected_labels': all_labels,
        'validation_ratio': validation_ratio,
        'dataset_name': dataset_name,
        'max_num_labels': Export.config('limits')['max_num_labels'],
        'selected_format': Export.config('formats')[selected_format],
    }
    # Execution
    self.run_thread(ExportExecutor(data), self.start_training)
def update_config_info(self, config):
    """Show the values of a training ``config`` in the info labels.

    Only the labels whose source keys are present in ``config`` are
    updated. Network and dataset-format keys are translated to their
    display names when known. Train/validation dataset paths that lie
    inside the project folder are shown relative to it, prefixed with the
    translated text 'Project folder'.

    Args:
        config: Training configuration mapping. Top-level keys read here
            are 'network', 'last_epoch', 'dataset' and 'args'. 'args' is a
            mapping of training arguments and may be absent.
    """
    # 'args' may be missing from the config; treat that as "no arguments".
    args = config['args'] if 'args' in config else {}

    if 'network' in config:
        network_name = config['network']
        networks = Training.config('networks')
        if network_name in networks:
            network_name = networks[network_name]
        self.config_network_value.setText(network_name)
    if 'training_name' in args:
        self.config_training_name_value.setText(str(args['training_name']))
    if 'epochs' in args:
        self.config_epochs_value.setText(str(args['epochs']))
    if 'last_epoch' in config:
        self.config_trained_epochs_value.setText(str(config['last_epoch']))
    if 'early_stop_epochs' in args:
        self.config_early_stop_value.setText(str(args['early_stop_epochs']))
    if 'batch_size' in args:
        self.config_batch_size_value.setText(str(args['batch_size']))
    if 'learning_rate' in args:
        self.config_learning_rate_value.setText(str(args['learning_rate']))
    if 'dataset' in config:
        dataset_format = config['dataset']
        formats = Export.config('formats')
        if dataset_format in formats:
            dataset_format = formats[dataset_format]
        self.config_dataset_format_value.setText(dataset_format)

    project_folder = self.parent.parent.settings.value(
        'settings/project/folder', '')
    logger.debug(
        'Restored value "{}" for setting settings/project/folder'.format(
            project_folder))

    def display_path(path):
        # Replace a leading project folder with the translated marker. The
        # match must end on a path separator (or at the end of ``path``), so
        # '/a/proj2' is not treated as lying inside '/a/proj'. An unset
        # (empty) project folder never matches. A plain character-wise
        # common prefix would get both of these cases wrong.
        if not project_folder or not path.startswith(project_folder):
            return path
        rest = path[len(project_folder):]
        separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
        at_boundary = (not rest or project_folder.endswith(separators)
                       or rest.startswith(separators))
        if not at_boundary:
            return path
        return _('Project folder') + rest

    if 'train_dataset' in args:
        self.config_dataset_train_value.setText(
            display_path(args['train_dataset']))
    if 'validate_dataset' in args:
        self.config_dataset_val_value.setText(
            display_path(args['validate_dataset']))
def dataset_files_browse_btn_clicked(self):
    """Let the user pick dataset config files and add each to the table.

    Each chosen path is normalized and passed to ``self.add_dataset_file``.
    Nothing happens when the dialog is cancelled (no files chosen).
    """
    # TODO: Replace config_file_extension
    filters = _('Dataset files') + ' (*{})'.format(
        Export.config('config_file_extension'))
    chosen_files, _selected_filter = QtWidgets.QFileDialog.getOpenFileNames(
        self, _('Select dataset files'), '', filters)
    for path in chosen_files:
        self.add_dataset_file(os.path.normpath(path))
def update_label_checkboxes(self):
    """Rebuild the label selection list from ``self.label_checkboxes``.

    Every item is taken out of ``self.label_parent_widget``'s layout. Then
    each checkbox is re-added in order: the first
    ``Export.config('limits')['max_num_labels']`` are checked and the rest
    unchecked.

    Widgets removed from the layout are scheduled for deletion with
    ``deleteLater()``, except the checkboxes that are about to be re-added.
    Deleting those would destroy widgets that go straight back into the
    layout (e.g. when this runs twice with the same checkbox objects).
    """
    layout = self.label_parent_widget.layout()
    # Identity set of the checkboxes that must survive the clean-up below.
    keep = {id(checkbox) for checkbox in self.label_checkboxes}
    while layout.count():
        child = layout.takeAt(0)
        widget = child.widget()
        if widget and id(widget) not in keep:
            widget.deleteLater()

    max_num_labels = Export.config('limits')['max_num_labels']
    for i, checkbox in enumerate(self.label_checkboxes):
        checkbox.setChecked(i < max_num_labels)
        layout.addWidget(checkbox)
def classes(self):
    """Return the list of class names, loading it on first use.

    The names are read from ``<self._root>/<Export.config('labels_file')>``,
    one per line with surrounding whitespace stripped, and cached in
    ``self._classes``. Later calls return the cached list.
    """
    if self._classes is not None:
        return self._classes
    try:
        labels_path = os.path.join(self._root, Export.config('labels_file'))
        with open(labels_path, 'r') as handle:
            names = [line.strip() for line in handle]
        self._classes = names
    except AssertionError as e:
        # NOTE(review): nothing visible in this try body asserts, so this
        # handler looks unreachable; kept to preserve behavior.
        raise RuntimeError("Class names must not contain {}".format(e))
    return self._classes
def export_browse_btn_clicked(self):
    """Ask where to save the output dataset file and show the chosen path.

    The dialog starts in the folder stored in the 'merge/last_export_dir'
    setting. After a file is chosen, its normalized path is shown in
    ``self.export_file`` and its folder is written back to that setting.
    """
    settings = self.parent.settings
    last_dir = settings.value('merge/last_export_dir', '')
    logger.debug('Restored value "{}" for setting merge/last_export_dir'.format(last_dir))
    # TODO: Replace config_file_extension
    filters = _('Dataset file') + ' (*{})'.format(Export.config('config_file_extension'))
    chosen, _selected_filter = QtWidgets.QFileDialog.getSaveFileName(
        self, _('Save output file as'), last_dir, filters)
    if not chosen:
        return
    chosen = os.path.normpath(chosen)
    settings.setValue('merge/last_export_dir', os.path.dirname(chosen))
    self.export_file.setText(chosen)
def set_current_format(self, format_key):
    """Make the ``self.formats`` entry for ``format_key`` the current one.

    ``format_key`` is looked up in ``Export.config('formats')`` to get its
    display text. The first combo box item with exactly that text becomes
    current; if none matches, the selection is unchanged. Errors (e.g. an
    unknown key) are logged with a traceback, not raised.
    """
    try:
        format_text = Export.config('formats')[format_key]
        # Bound the loop by the same combo box whose items are read.
        for i in range(self.formats.count()):
            if self.formats.itemText(i) == format_text:
                self.formats.setCurrentIndex(i)
                break
    except Exception:
        logger.error(traceback.format_exc())
def data_browse_btn_clicked(self):
    """Let the user choose the dataset file or folder to import.

    A file dialog filtered on the selected format's extension is shown
    when one is configured; otherwise a directory dialog is shown. Both
    start in the project's dataset folder. The normalized choice is put
    into ``self.data_folder``, and its parent folder is stored in the
    'import/last_data_dir' setting.
    """
    fmt = self.selected_format
    extension = Export.config('extensions')[fmt]
    format_name = Export.config('formats')[fmt]
    # The extensions config uses False for "no file extension / folder".
    ext_filter = False
    if extension != False:
        ext_filter = '{} {}({})'.format(format_name, _('files'), extension)

    project_folder = self.parent.settings.value('settings/project/folder', '')
    logger.debug(
        'Restored value "{}" for setting settings/project/folder'.format(
            project_folder))
    start_dir = os.path.join(
        project_folder, self.parent._config['project_dataset_folder'])

    if ext_filter:
        chosen, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
            self, _('Select dataset file'), start_dir, ext_filter)
    else:
        chosen = QtWidgets.QFileDialog.getExistingDirectory(
            self, _('Select dataset folder'), start_dir)

    if not chosen:
        return
    chosen = os.path.normpath(chosen)
    self.parent.settings.setValue('import/last_data_dir',
                                  os.path.dirname(chosen))
    self.data_folder.setText(chosen)
def export_btn_clicked(self):
    """Collect the dialog's settings and run ``ExportExecutor`` on them.

    The dataset name keeps only ASCII letters, digits, spaces, '_' and
    '-'. The validation ratio is the spin box percentage divided by 100.
    ``self.finish_export`` is passed to ``run_thread`` as the callback.
    """
    checked_labels = [
        box.text() for box in self.label_checkboxes if box.isChecked()
    ]
    safe_name = re.sub(r'[^a-zA-Z0-9 _-]+', '', self.export_name.text())
    data = {
        'data_folder': self.data_folder.text(),
        'export_folder': self.export_folder.text(),
        'selected_labels': checked_labels,
        'validation_ratio': int(self.validation.value()) / 100.0,
        'dataset_name': safe_name,
        'max_num_labels': Export.config('limits')['max_num_labels'],
        'selected_format': self.formats.currentText(),
    }
    # Execution
    self.run_thread(ExportExecutor(data), self.finish_export)
def __init__(self, all_image_sets=False):
    """Initialise the VOC format handler.

    Args:
        all_image_sets: Stored unchanged as ``self.all_image_sets``.

    Also sets ``FormatVoc._files['labels']`` to the labels file name from
    ``Export.config``. This dict is reached through the class, so the
    change is class-level rather than per instance.
    """
    super().__init__()
    self.all_image_sets = all_image_sets
    self.intermediate = None
    self.dataset = None
    FormatVoc._files['labels'] = Export.config('labels_file')
def run(self):
    """Validate the export request in ``self.data`` and export the dataset.

    Keys read from ``self.data``: 'data_folder', 'export_folder',
    'selected_labels', 'max_num_labels', 'dataset_name', 'selected_format'
    (a display name from ``Export.config('formats')``) and
    'validation_ratio'. On success the matching format key is stored in
    ``self.data['format_name']``.

    Every failed check emits a warning through ``self.thread.message``,
    calls ``self.abort()`` and returns. Output goes to
    ``<export_folder>/<sanitized dataset_name>``. If that folder already
    has content, the user must confirm before it is deleted.
    """
    logger.debug('Prepare export')
    try:
        import ptvsd
        ptvsd.debug_this_thread()
    except Exception:
        # Attaching the debugger is optional (ptvsd may not be installed).
        pass

    def fail(message):
        # Report a validation problem to the UI and stop this executor.
        self.thread.message.emit(_('Export'), message, MessageType.Warning)
        self.abort()

    # Check emptiness before normalizing: normpath('') is '.', and
    # normpath(None) raises.
    data_folder = self.data['data_folder']
    if data_folder:
        data_folder = os.path.normpath(data_folder)
    if not data_folder or not os.path.isdir(data_folder):
        fail(_('Please enter a valid data folder'))
        return

    export_folder = self.data['export_folder']
    if export_folder:
        export_folder = os.path.normpath(export_folder)
    if not export_folder or not os.path.isdir(export_folder):
        fail(_('Please enter a valid export folder'))
        return

    selected_labels = self.data['selected_labels']
    num_selected_labels = len(selected_labels)
    limit = self.data['max_num_labels']
    if num_selected_labels > limit:
        fail(_('Please select a maximum of {} labels').format(limit))
        return
    elif num_selected_labels <= 0:
        fail(_('Please select at least 1 label'))
        return

    dataset_name = replace_special_chars(self.data['dataset_name'])
    if not dataset_name:
        fail(_('Please enter a valid dataset name'))
        return

    # Build the target from the validated values, not the raw inputs.
    export_dataset_folder = os.path.normpath(
        os.path.join(export_folder, dataset_name))
    if not os.path.isdir(export_dataset_folder):
        os.makedirs(export_dataset_folder)
    elif os.listdir(export_dataset_folder):
        msg = _(
            'The selected output directory "{}" is not empty. All containing files will be deleted. Are you sure to continue?'
        ).format(export_dataset_folder)
        if not self.doConfirm(_('Export'), msg, MessageType.Warning):
            self.abort()
            return
        deltree(export_dataset_folder)
        time.sleep(0.5)  # wait for deletion to be finished
        if not os.path.exists(export_dataset_folder):
            os.makedirs(export_dataset_folder)

    if not os.path.isdir(export_dataset_folder):
        fail(_('The selected output directory "{}" could not be created').
             format(export_dataset_folder))
        return

    selected_format = self.data['selected_format']
    inv_formats = Export.invertDict(Export.config('formats'))
    if selected_format not in inv_formats:
        fail(_('Export format {} could not be found').format(selected_format))
        return
    format_name = inv_formats[selected_format]
    self.data['format_name'] = format_name

    logger.debug('Start export')
    validation_ratio = self.data['validation_ratio']

    self.checkAborted()

    intermediate = IntermediateFormat()
    intermediate.setAbortable(self.abortable)
    intermediate.setThread(self.thread)
    intermediate.setIncludedLabels(selected_labels)
    intermediate.setValidationRatio(validation_ratio)
    intermediate.addFromLabelFiles(data_folder, shuffle=False)

    self.thread.update.emit(_('Loading data ...'), 0,
                            intermediate.getNumberOfSamples() + 5)

    args = Map({
        'validation_ratio': validation_ratio,
    })

    dataset_format = Export.config('objects')[format_name]()
    dataset_format.setAbortable(self.abortable)
    dataset_format.setThread(self.thread)
    dataset_format.setIntermediateFormat(intermediate)
    dataset_format.setInputFolderOrFile(data_folder)
    dataset_format.setOutputFolder(export_dataset_folder)
    dataset_format.setArgs(args)

    self.checkAborted()

    dataset_format.export()
def __init__(self, parent=None):
    """Build the non-modal 'Export dataset' dialog.

    Sections, top to bottom:
    - a format combo box;
    - a read-only data folder field;
    - a label-selection area (label plus scroll area, both hidden
      initially);
    - an export folder field with a Browse button, plus a dataset name
      field;
    - a validation-ratio spin box (0-90, shown as '% of dataset');
    - Export/Cancel buttons.

    When ``self.parent.lastOpenDir`` is set and ``self.get_labels()``
    returns at least one label, the data folder is pre-filled and the
    labels are passed to ``self.load_labels``. The export folder defaults
    to the project folder joined with the configured
    'project_dataset_folder'.
    """
    super().__init__(parent)
    self.setWindowTitle(_('Export dataset'))
    self.set_default_window_flags(self)
    self.setWindowModality(Qt.NonModal)
    layout = QtWidgets.QVBoxLayout()
    self.setLayout(layout)
    # Format selection: one combo entry per configured format display name.
    self.formats = QtWidgets.QComboBox()
    for key, val in Export.config('formats').items():
        self.formats.addItem(val)
    format_group = QtWidgets.QGroupBox()
    format_group.setTitle(_('Format'))
    format_group_layout = QtWidgets.QVBoxLayout()
    format_group.setLayout(format_group_layout)
    format_group_layout.addWidget(self.formats)
    layout.addWidget(format_group)
    # Data folder (read-only; no Browse button in this dialog).
    self.data_folder = QtWidgets.QLineEdit()
    self.data_folder.setReadOnly(True)
    data_folder_group = QtWidgets.QGroupBox()
    data_folder_group.setTitle(_('Data folder'))
    data_folder_group_layout = QtWidgets.QHBoxLayout()
    data_folder_group.setLayout(data_folder_group_layout)
    data_folder_group_layout.addWidget(self.data_folder)
    layout.addWidget(data_folder_group)
    # Label selection: checkboxes live in label_parent_widget's layout inside
    # a scroll area. Both the caption and the scroll area start hidden.
    self.label_checkboxes = []
    self.label_selection_label = QtWidgets.QLabel(_('Label selection'))
    self.label_selection_label.setVisible(False)
    layout.addWidget(self.label_selection_label)
    self.label_parent_widget = QtWidgets.QWidget()
    self.label_parent_widget.setLayout(QtWidgets.QVBoxLayout())
    self.label_selection_scroll = QtWidgets.QScrollArea()
    self.label_selection_scroll.setVisible(False)
    self.label_selection_scroll.setWidgetResizable(True)
    self.label_selection_scroll.setMinimumHeight(40)
    self.label_selection_scroll.setMaximumHeight(150)
    self.label_selection_scroll.setWidget(self.label_parent_widget)
    layout.addWidget(self.label_selection_scroll)
    # NOTE(review): self.parent is used as an attribute holding the main
    # window (not QObject.parent()); presumably set by the base class — confirm.
    labels = self.get_labels()
    if self.parent.lastOpenDir is not None and len(labels) > 0:
        self.data_folder.setText(self.parent.lastOpenDir)
        self.load_labels(labels)
    # Export folder defaults to <project folder>/<project_dataset_folder>.
    self.export_folder = QtWidgets.QLineEdit()
    project_folder = self.parent.settings.value('settings/project/folder', '')
    logger.debug(
        'Restored value "{}" for setting settings/project/folder'.format(
            project_folder))
    self.export_folder.setText(
        os.path.join(project_folder, self.parent._config['project_dataset_folder']))
    self.export_folder.setReadOnly(True)
    export_browse_btn = QtWidgets.QPushButton(_('Browse'))
    export_browse_btn.clicked.connect(self.export_browse_btn_clicked)
    export_name_label = QtWidgets.QLabel(_('Dataset name'))
    self.export_name = QtWidgets.QLineEdit()
    # Grid: row 0 = folder field (2 cols) + Browse; rows 1-2 = name label/field.
    export_folder_group = QtWidgets.QGroupBox()
    export_folder_group.setTitle(_('Export folder'))
    export_folder_group_layout = QtWidgets.QGridLayout()
    export_folder_group.setLayout(export_folder_group_layout)
    export_folder_group_layout.addWidget(self.export_folder, 0, 0, 1, 2)
    export_folder_group_layout.addWidget(export_browse_btn, 0, 2)
    export_folder_group_layout.addWidget(export_name_label, 1, 0, 1, 3)
    export_folder_group_layout.addWidget(self.export_name, 2, 0, 1, 3)
    layout.addWidget(export_folder_group)
    # Validation ratio in percent of the dataset, limited to 0-90.
    self.validation = QtWidgets.QSpinBox()
    self.validation.setValue(0)
    self.validation.setMinimum(0)
    self.validation.setMaximum(90)
    self.validation.setFixedWidth(50)
    validation_label = QtWidgets.QLabel(_('% of dataset'))
    validation_group = QtWidgets.QGroupBox()
    validation_group.setTitle(_('Validation ratio'))
    validation_group_layout = QtWidgets.QGridLayout()
    validation_group.setLayout(validation_group_layout)
    validation_group_layout.addWidget(self.validation, 0, 0)
    validation_group_layout.addWidget(validation_label, 0, 1)
    layout.addWidget(validation_group)
    layout.addStretch()
    button_box = QtWidgets.QDialogButtonBox()
    export_btn = button_box.addButton(
        _('Export'), QtWidgets.QDialogButtonBox.AcceptRole)
    export_btn.clicked.connect(self.export_btn_clicked)
    cancel_btn = button_box.addButton(
        _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
    cancel_btn.clicked.connect(self.cancel_btn_clicked)
    layout.addWidget(button_box)
def start_training(self):
    """Validate the training settings and start training in a progress window.

    Builds a ``data`` dict from the dialog widgets.

    If 'create_dataset' is checked, the train/val dataset paths are
    derived from ``self.dataset_export_data`` (presumably filled by
    ``export_before_training`` — confirm). The val path is used only when
    the export's validation ratio is > 0; otherwise it is False.

    Without 'create_dataset':
    - the entered training dataset must be an existing file or folder
      (otherwise a warning is shown and nothing starts);
    - an empty or non-existent validation dataset becomes False;
    - a validation dataset equal to the training dataset is rejected.

    The output folder is ``<output_folder>/<sanitized training_name>``.
    - If it is missing, it is created.
    - If it is non-empty and a resume file is selected, a warning is shown
      and nothing starts.
    - If it is non-empty and no resume file is selected, the user must
      confirm deleting its contents.

    Finally a ``TrainingProgressWindow`` is shown and started with
    ``data``, and this dialog is closed.
    """
    # NOTE(review): training_defaults is not used in this method.
    training_defaults = self.parent._config['training_defaults']
    network_key = self.get_current_network_key()
    epochs = self.args_epochs.value()
    resume_training_file = ''
    resume_epoch = 0
    if self.resume_training_checkbox.widget.isChecked() and self.resume_file.count() > 0:
        idx = self.resume_file.currentIndex()
        # Each resume entry's item data is a (file, epoch) pair.
        resume_training_file, resume_epoch = self.resume_file.itemData(idx)
        # Offset the epoch count by the resume epoch (presumably so it is
        # the total epoch count including already trained ones — confirm).
        epochs += resume_epoch
    # Data
    data = {
        'create_dataset': self.create_dataset_checkbox.widget.isChecked(),
        'resume_training': resume_training_file,
        'start_epoch': resume_epoch,
        'dataset_export_data': self.dataset_export_data,
        'train_dataset': self.train_dataset_folder.text(),
        'val_dataset': self.val_dataset_folder.text(),
        'output_folder': self.output_folder.text(),
        'selected_format': self.selected_format,
        'training_name': self.training_name.text(),
        'network': network_key,
        'gpu_checkboxes': self.gpu_checkboxes,
        'args_epochs': epochs,
        'args_batch_size': self.args_batch_size.value(),
        'args_learning_rate': self.args_learning_rate.value(),
        'args_early_stop_epochs': self.args_early_stop_epochs.value(),
    }
    # Preprocess data
    mb = QtWidgets.QMessageBox()
    create_dataset = data['create_dataset']
    if create_dataset:
        # The dataset does not exist yet; derive its file paths from the
        # export settings and the format's output file names.
        export_data = data['dataset_export_data']
        format_name = export_data['format']
        dataset_format = Export.config('objects')[format_name]()
        output_folder = export_data['output_folder']
        train_file = dataset_format.getOutputFileName('train')
        train_dataset = os.path.join(output_folder, train_file)
        data['train_dataset'] = train_dataset
        validation_ratio = export_data['validation_ratio']
        if validation_ratio > 0:
            val_file = dataset_format.getOutputFileName('val')
            val_dataset = os.path.join(output_folder, val_file)
        else:
            # Validation dataset is optional
            val_dataset = False
        data['val_dataset'] = val_dataset
    else:
        train_dataset = data['train_dataset']
        is_train_dataset_valid = True
        if not train_dataset:
            is_train_dataset_valid = False
        train_dataset = os.path.normpath(train_dataset)
        if not (os.path.isdir(train_dataset) or os.path.isfile(train_dataset)):
            is_train_dataset_valid = False
        if not is_train_dataset_valid:
            mb.warning(self, _('Training'), _('Please select a valid training dataset'))
            return
        data['train_dataset'] = train_dataset
        val_dataset = data['val_dataset']
        is_val_dataset_valid = True
        if not val_dataset:
            is_val_dataset_valid = False
        val_dataset = os.path.normpath(val_dataset)
        if not (os.path.isdir(val_dataset) or os.path.isfile(val_dataset)):
            is_val_dataset_valid = False
        if not is_val_dataset_valid:
            # Validation dataset is optional
            val_dataset = False
        data['val_dataset'] = val_dataset
        if val_dataset and val_dataset == train_dataset:
            mb.warning(
                self, _('Training'),
                _('Training and validation dataset are equal. Please use different datasets, as validation results are useless otherwise.'
                  ))
            return
    # NOTE(review): an empty output folder normalizes to '.', i.e. the
    # process working directory; it is not rejected here.
    output_folder = os.path.normpath(data['output_folder'])
    training_name = data['training_name']
    training_name = replace_special_chars(training_name)
    data['training_name'] = training_name
    if not training_name:
        mb.warning(self, _('Training'), _('Please enter a valid training name'))
        return
    output_folder = os.path.join(output_folder, training_name)
    data['output_folder'] = output_folder
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)
    elif len(os.listdir(output_folder)) > 0 and resume_training_file:
        mb.warning(
            self, _('Training'),
            _('The selected output directory "{}" is not empty. Please select a different directory.'
              ).format(output_folder))
        return
    elif len(os.listdir(output_folder)) > 0 and not resume_training_file:
        msg = _(
            'The selected output directory "{}" is not empty. All containing files will be deleted. Are you sure to continue?'
        ).format(output_folder)
        result = confirm(self, _('Training'), msg, MessageType.Warning)
        if result:
            deltree(output_folder)
            time.sleep(0.5)  # wait for deletion to be finished
            if not os.path.exists(output_folder):
                os.makedirs(output_folder)
        else:
            return
    if not os.path.isdir(output_folder):
        mb.warning(
            self, _('Training'),
            _('The selected output directory "{}" could not be created').
            format(output_folder))
        return
    # Open new window for training progress
    trainingWin = TrainingProgressWindow(self)
    trainingWin.show()
    trainingWin.start_training(data)
    self.close()
def __init__(self):
    """Initialise the ImageRecord format handler.

    Sets ``self.intermediate`` to None and ``self.num_samples`` to -1
    (presumably meaning "not counted yet" — confirm). Also sets
    ``FormatImageRecord._files['labels']`` to the labels file name from
    ``Export.config``. This dict is reached through the class, so the
    change is class-level rather than per instance.
    """
    super().__init__()
    self.num_samples = -1
    self.intermediate = None
    FormatImageRecord._files['labels'] = Export.config('labels_file')
def run(self):
    """Validate the import request in ``self.data`` and import the dataset.

    Keys read from ``self.data``: 'data_folder' (a dataset file or
    folder), 'output_folder' (must be an existing folder) and
    'selected_format' (a display name from ``Export.config('formats')``).
    On success the matching format key is stored in
    ``self.data['format_name']``.

    Every failed check emits a warning through ``self.thread.message``,
    calls ``self.abort()`` and returns. The format object must also accept
    the input via ``isValidFormat`` before ``importFolder()`` runs.
    """
    logger.debug('Prepare import')
    try:
        import ptvsd
        ptvsd.debug_this_thread()
    except Exception:
        # Attaching the debugger is optional (ptvsd may not be installed).
        pass

    def fail(message):
        # Report a validation problem to the UI and stop this executor.
        self.thread.message.emit(_('Import'), message, MessageType.Warning)
        self.abort()

    # Check emptiness before normalizing: normpath('') is '.', and
    # normpath(None) raises.
    data_folder_or_file = self.data['data_folder']
    if data_folder_or_file:
        data_folder_or_file = os.path.normpath(data_folder_or_file)
    if not data_folder_or_file or not (
            os.path.isdir(data_folder_or_file)
            or os.path.isfile(data_folder_or_file)):
        fail(_('Please enter a valid dataset file or folder'))
        return

    output_folder = self.data['output_folder']
    if output_folder:
        output_folder = os.path.normpath(output_folder)
    if not output_folder or not os.path.isdir(output_folder):
        fail(_('Please enter a valid output folder'))
        return

    selected_format = self.data['selected_format']
    inv_formats = Export.invertDict(Export.config('formats'))
    if selected_format not in inv_formats:
        fail(_('Import format {} could not be found').format(selected_format))
        return
    format_name = inv_formats[selected_format]
    self.data['format_name'] = format_name

    # Dataset
    dataset_format = Export.config('objects')[format_name]()
    if not dataset_format.isValidFormat(data_folder_or_file):
        fail(_('Invalid dataset format'))
        return

    dataset_format.setAbortable(self.abortable)
    dataset_format.setThread(self.thread)
    dataset_format.setOutputFolder(output_folder)
    dataset_format.setInputFolderOrFile(data_folder_or_file)

    self.checkAborted()

    dataset_format.importFolder()
def __init__(self, parent=None):
    """Build the training dialog.

    Creates three tabs ("Dataset", "Network", "Resume training") and a
    button box with "Start Training" / "Cancel", then sizes the window to a
    width of 500 and its size-hint height.

    Args:
        parent: Passed to the base-class constructor. The body reads
            ``self.parent.settings``, ``self.parent._config`` (keys
            ``'training_defaults'`` and ``'project_training_folder'``) and
            ``self.parent.imageList`` — presumably the base class stores
            ``parent`` as the ``self.parent`` attribute; verify.
    """
    super().__init__(parent)
    self.setWindowTitle(_('Training'))
    self.set_default_window_flags(self)
    self.setWindowModality(Qt.NonModal)
    self.dataset_format_init = False

    # Project folder is only used below to pre-fill the output folder.
    project_folder = self.parent.settings.value('settings/project/folder', '')
    logger.debug(
        'Restored value "{}" for setting settings/project/folder'.format(
            project_folder))

    layout = QtWidgets.QVBoxLayout()
    self.setLayout(layout)

    self.tabs = QtWidgets.QTabWidget()
    layout.addWidget(self.tabs)

    # Create the three (initially empty) tab pages; they are filled below.
    tab_dataset = QtWidgets.QWidget()
    tab_dataset_layout = QtWidgets.QVBoxLayout()
    tab_dataset.setLayout(tab_dataset_layout)
    self.tabs.addTab(tab_dataset, _('Dataset'))

    tab_network = QtWidgets.QWidget()
    tab_network_layout = QtWidgets.QVBoxLayout()
    tab_network.setLayout(tab_network_layout)
    self.tabs.addTab(tab_network, _('Network'))

    tab_resume = QtWidgets.QWidget()
    tab_resume_layout = QtWidgets.QVBoxLayout()
    tab_resume.setLayout(tab_resume_layout)
    self.tabs.addTab(tab_resume, _('Resume training'))

    # Network Tab

    # One combo entry per configured network (display value only). The
    # change handler is connected after the items are added.
    self.networks = QtWidgets.QComboBox()
    for key, val in Training.config('networks').items():
        self.networks.addItem(val)
    self.networks.currentIndexChanged.connect(
        self.network_selection_changed)

    network_group = HelpGroupBox('Training_NetworkArchitecture', _('Network'))
    network_group_layout = QtWidgets.QVBoxLayout()
    network_group.widget.setLayout(network_group_layout)
    network_group_layout.addWidget(self.networks)
    tab_network_layout.addWidget(network_group)

    training_defaults = self.parent._config['training_defaults']
    # Presumably derived from the current self.networks selection, so it must
    # run after the combo box is populated; it feeds the batch-size and
    # learning-rate defaults below.
    network = self.get_current_network()

    args_epochs_label = HelpLabel('Training_SettingsEpochs', _('Epochs'))
    self.args_epochs = QtWidgets.QSpinBox()
    self.args_epochs.setMinimum(1)
    self.args_epochs.setMaximum(500)
    self.args_epochs.setValue(training_defaults['epochs'])

    default_batch_size = self.get_default_batch_size(network)
    args_batch_size_label = HelpLabel('Training_SettingsBatchSize', _('Batch size'))
    self.args_batch_size = QtWidgets.QSpinBox()
    self.args_batch_size.setMinimum(1)
    self.args_batch_size.setMaximum(100)
    self.args_batch_size.setValue(default_batch_size)

    default_learning_rate = self.get_default_learning_rate(network)
    args_learning_rate_label = HelpLabel('Training_SettingsLearningRate', _('Learning rate'))
    self.args_learning_rate = QtWidgets.QDoubleSpinBox()
    # Range set before the value so the default is not clamped by the
    # spin box's default range.
    self.args_learning_rate.setMinimum(1e-7)
    self.args_learning_rate.setMaximum(1.0)
    self.args_learning_rate.setSingleStep(1e-7)
    self.args_learning_rate.setDecimals(7)
    self.args_learning_rate.setValue(default_learning_rate)

    args_early_stop_epochs_label = HelpLabel('Training_SettingsEarlyStop', _('Early stop epochs'))
    self.args_early_stop_epochs = QtWidgets.QSpinBox()
    self.args_early_stop_epochs.setMinimum(0)
    self.args_early_stop_epochs.setMaximum(100)
    self.args_early_stop_epochs.setValue(
        training_defaults['early_stop_epochs'])

    # One checkbox per GPU reported by MXNet; only the entry equal to 0 is
    # pre-checked.
    self.gpu_label_text = _('GPU')
    args_gpus_label = QtWidgets.QLabel(_('GPUs'))
    no_gpus_available_label = QtWidgets.QLabel(_('No GPUs available'))
    # NOTE(review): the '['0']' note below suggests string ids, but `i == 0`
    # only pre-checks the first GPU if the ids are ints — confirm the
    # element type returned by list_gpus().
    self.gpus = mx.test_utils.list_gpus()  # ['0']
    self.gpu_checkboxes = []
    for i in self.gpus:
        checkbox = QtWidgets.QCheckBox('{} {}'.format(
            self.gpu_label_text, i))
        checkbox.setChecked(i == 0)
        self.gpu_checkboxes.append(checkbox)

    # Settings grid: label in column 0, editor in column 1, one row each.
    settings_group = QtWidgets.QGroupBox()
    settings_group.setTitle(_('Settings'))
    settings_group_layout = QtWidgets.QGridLayout()
    settings_group.setLayout(settings_group_layout)
    settings_group_layout.addWidget(args_epochs_label, 0, 0)
    settings_group_layout.addWidget(self.args_epochs, 0, 1)
    settings_group_layout.addWidget(args_batch_size_label, 1, 0)
    settings_group_layout.addWidget(self.args_batch_size, 1, 1)
    settings_group_layout.addWidget(args_learning_rate_label, 2, 0)
    settings_group_layout.addWidget(self.args_learning_rate, 2, 1)
    settings_group_layout.addWidget(args_early_stop_epochs_label, 3, 0)
    settings_group_layout.addWidget(self.args_early_stop_epochs, 3, 1)
    settings_group_layout.addWidget(args_gpus_label, 4, 0)
    if len(self.gpu_checkboxes) > 0:
        # GPU checkboxes stack in column 1 starting next to the "GPUs" label.
        row = 4
        for i, checkbox in enumerate(self.gpu_checkboxes):
            settings_group_layout.addWidget(checkbox, row, 1)
            row += 1
    else:
        settings_group_layout.addWidget(no_gpus_available_label, 4, 1)
    tab_network_layout.addWidget(settings_group)

    # Dataset Tab

    # "Create dataset from opened images" is only offered (and pre-checked)
    # when the parent has images open.
    image_list = self.parent.imageList
    show_dataset_create = len(image_list) > 0
    self.create_dataset_checkbox = HelpCheckbox(
        'Training_CreateDataset', _('Create dataset from opened images'))
    self.create_dataset_checkbox.widget.setChecked(show_dataset_create)
    tab_dataset_layout.addWidget(self.create_dataset_checkbox)

    # Validation share of a newly created dataset, in percent (0-90, default 10).
    validation_label = HelpLabel('Training_ValidationRatio', _('Validation ratio'))
    self.validation = QtWidgets.QSpinBox()
    self.validation.setValue(10)
    self.validation.setMinimum(0)
    self.validation.setMaximum(90)
    self.validation.setFixedWidth(50)
    validation_description_label = QtWidgets.QLabel(_('% of dataset'))

    dataset_name_label = QtWidgets.QLabel(_('Dataset name'))
    self.dataset_name = QtWidgets.QLineEdit()

    self.create_dataset_group = QtWidgets.QGroupBox()
    self.create_dataset_group.setTitle(_('Create dataset'))
    create_dataset_group_layout = QtWidgets.QGridLayout()
    self.create_dataset_group.setLayout(create_dataset_group_layout)
    create_dataset_group_layout.addWidget(dataset_name_label, 0, 0, 1, 2)
    create_dataset_group_layout.addWidget(self.dataset_name, 1, 0, 1, 2)
    create_dataset_group_layout.addWidget(validation_label, 2, 0, 1, 2)
    create_dataset_group_layout.addWidget(self.validation, 3, 0)
    create_dataset_group_layout.addWidget(validation_description_label, 3, 1)

    # Existing-dataset path: format picker plus training/validation inputs.
    # self.selected_format starts as the first key of the formats config,
    # matching the first combo entry selected here.
    formats_label = HelpLabel('Training_DatasetFormat', _('Format'))
    self.formats = QtWidgets.QComboBox()
    for key, val in Export.config('formats').items():
        self.formats.addItem(val)
    self.formats.setCurrentIndex(0)
    self.formats.currentTextChanged.connect(self.on_format_change)
    self.selected_format = list(Export.config('formats').keys())[0]

    train_dataset_label = HelpLabel('Training_TrainingDataset', _('Training dataset'))
    self.train_dataset_folder = QtWidgets.QLineEdit()
    train_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
    train_dataset_folder_browse_btn.clicked.connect(
        self.train_dataset_folder_browse_btn_clicked)

    val_label_text = '{} ({})'.format(_('Validation dataset'), _('optional'))
    val_dataset_label = HelpLabel('Training_ValidationDataset', val_label_text)
    self.val_dataset_folder = QtWidgets.QLineEdit()
    val_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
    val_dataset_folder_browse_btn.clicked.connect(
        self.val_dataset_folder_browse_btn_clicked)

    self.dataset_folder_group = QtWidgets.QGroupBox()
    self.dataset_folder_group.setTitle(_('Use dataset file(s)'))
    dataset_folder_group_layout = QtWidgets.QGridLayout()
    self.dataset_folder_group.setLayout(dataset_folder_group_layout)
    dataset_folder_group_layout.addWidget(formats_label, 0, 0, 1, 2)
    dataset_folder_group_layout.addWidget(self.formats, 1, 0, 1, 2)
    dataset_folder_group_layout.addWidget(train_dataset_label, 2, 0, 1, 2)
    dataset_folder_group_layout.addWidget(self.train_dataset_folder, 3, 0)
    dataset_folder_group_layout.addWidget(train_dataset_folder_browse_btn, 3, 1)
    dataset_folder_group_layout.addWidget(val_dataset_label, 4, 0, 1, 2)
    dataset_folder_group_layout.addWidget(self.val_dataset_folder, 5, 0)
    dataset_folder_group_layout.addWidget(val_dataset_folder_browse_btn, 5, 1)

    tab_dataset_layout.addWidget(self.create_dataset_group)
    tab_dataset_layout.addWidget(self.dataset_folder_group)

    # Only one of the two dataset groups is visible initially: "create"
    # when images are open, otherwise "use dataset file(s)" (and the
    # checkbox itself is hidden).
    if show_dataset_create:
        self.dataset_folder_group.hide()
    else:
        self.create_dataset_checkbox.hide()
        self.create_dataset_group.hide()

    # Toggling the checkbox swaps between the two dataset groups.
    self.create_dataset_checkbox.widget.toggled.connect(
        lambda: self.switch_visibility(self.create_dataset_group, self.
                                       dataset_folder_group))

    # Output folder defaults to <project folder>/<project_training_folder>.
    # The line edit is not added to any layout (see the commented-out
    # lines), so it is not visible; presumably its text is read when
    # training starts — verify.
    self.output_folder = QtWidgets.QLineEdit()
    self.output_folder.setText(
        os.path.join(project_folder, self.parent._config['project_training_folder']))
    # self.output_folder.setReadOnly(True)
    # output_browse_btn = QtWidgets.QPushButton(_('Browse'))
    # output_browse_btn.clicked.connect(self.output_browse_btn_clicked)

    training_name_label = HelpLabel('Training_TrainingName', _('Training name'))
    self.training_name = QtWidgets.QLineEdit()

    output_folder_group = QtWidgets.QGroupBox()
    output_folder_group.setTitle(_('Output folder'))
    output_folder_group_layout = QtWidgets.QGridLayout()
    output_folder_group.setLayout(output_folder_group_layout)
    # output_folder_group_layout.addWidget(self.output_folder, 0, 0, 1, 2)
    # output_folder_group_layout.addWidget(output_browse_btn, 0, 2)
    output_folder_group_layout.addWidget(training_name_label, 1, 0, 1, 3)
    output_folder_group_layout.addWidget(self.training_name, 2, 0, 1, 3)
    tab_dataset_layout.addWidget(output_folder_group)

    # Resume Tab

    # Resume options stay hidden until the checkbox is toggled (handled by
    # toggle_resume_training_checkbox).
    self.resume_training_checkbox = HelpCheckbox(
        'Training_Resume', _('Resume previous training'))
    self.resume_training_checkbox.widget.setChecked(False)
    tab_resume_layout.addWidget(self.resume_training_checkbox)

    self.resume_group = QtWidgets.QWidget()
    resume_group_layout = QtWidgets.QVBoxLayout()
    self.resume_group.setLayout(resume_group_layout)
    tab_resume_layout.addWidget(self.resume_group)
    self.resume_group.hide()

    self.resume_training_checkbox.widget.toggled.connect(
        self.toggle_resume_training_checkbox)

    # Read-only path; filled via the Browse button handler.
    self.resume_folder = QtWidgets.QLineEdit()
    self.resume_folder.setText('')
    self.resume_folder.setReadOnly(True)
    resume_browse_btn = QtWidgets.QPushButton(_('Browse'))
    resume_browse_btn.clicked.connect(self.resume_browse_btn_clicked)

    resume_folder_group = QtWidgets.QGroupBox()
    resume_folder_group.setTitle(_('Training directory'))
    resume_folder_group_layout = QtWidgets.QHBoxLayout()
    resume_folder_group.setLayout(resume_folder_group_layout)
    resume_folder_group_layout.addWidget(self.resume_folder)
    resume_folder_group_layout.addWidget(resume_browse_btn)
    resume_group_layout.addWidget(resume_folder_group)

    # "Resume after" selector; hidden initially, presumably shown once a
    # training directory has been chosen.
    self.resume_file = QtWidgets.QComboBox()
    self.resume_file_group = QtWidgets.QGroupBox()
    self.resume_file_group.setTitle(_('Resume after'))
    resume_file_group_layout = QtWidgets.QHBoxLayout()
    self.resume_file_group.setLayout(resume_file_group_layout)
    resume_file_group_layout.addWidget(self.resume_file)
    resume_group_layout.addWidget(self.resume_file_group)
    self.resume_file_group.hide()

    # Push tab contents to the top.
    tab_dataset_layout.addStretch()
    tab_network_layout.addStretch()
    tab_resume_layout.addStretch()

    button_box = QtWidgets.QDialogButtonBox()
    self.training_btn = button_box.addButton(
        _('Start Training'), QtWidgets.QDialogButtonBox.AcceptRole)
    self.training_btn.clicked.connect(self.training_btn_clicked)
    cancel_btn = button_box.addButton(
        _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
    cancel_btn.clicked.connect(self.cancel_btn_clicked)
    layout.addWidget(button_box)

    # Fixed initial width of 500; height follows the layout's size hint.
    h = self.sizeHint().height()
    self.resize(500, h)
def __init__(self):
    """Reset per-instance dataset state and register the labels file name.

    The ``'labels'`` entry of ``FormatCoco._files`` is written on the class
    itself, so every new instance rewrites this shared value from the export
    configuration.
    """
    super().__init__()
    self.intermediate = self.dataset = None
    # -1 presumably means "number of samples not determined yet".
    self.num_samples = -1
    labels_file_name = Export.config('labels_file')
    FormatCoco._files['labels'] = labels_file_name
def run(self):
    """Run the training job described by ``self.data`` on this worker thread.

    Keys read from ``self.data``: ``network``, ``gpu_checkboxes``,
    ``args_epochs``, ``args_batch_size``, ``args_learning_rate``,
    ``args_early_stop_epochs``, ``selected_format``, ``train_dataset``,
    ``val_dataset``, ``training_name``, ``start_epoch``,
    ``resume_training`` and ``output_folder``.

    An unknown network or dataset format is reported through
    ``self.thread.message`` (as an error) followed by ``self.abort()``
    instead of raising inside the worker thread.
    """
    logger.debug('Prepare training')
    # Optional hook for attaching the ptvsd debugger to this worker thread.
    # Best-effort: any failure (including a missing module) is ignored, but
    # SystemExit/KeyboardInterrupt are no longer swallowed.
    try:
        import ptvsd
        ptvsd.debug_this_thread()
    except Exception:
        pass

    network_key = self.data['network']
    network_objects = Training.config('objects')
    if network_key not in network_objects:
        self.thread.message.emit(
            _('Training'),
            _('Network {} could not be found').format(network_key),
            MessageType.Error)
        self.abort()
        return

    # Validate the dataset format the same way as the network, instead of
    # letting a KeyError escape from the worker thread.
    dataset_format = self.data['selected_format']
    dataset_objects = Export.config('objects')
    if dataset_format not in dataset_objects:
        self.thread.message.emit(
            _('Training'),
            _('Dataset format {} could not be found').format(dataset_format),
            MessageType.Error)
        self.abort()
        return

    # Training settings
    # The checkbox position is used as the GPU id; this assumes the
    # checkboxes are ordered by device id — verify against the dialog.
    gpus = ','.join(
        str(index)
        for index, checkbox in enumerate(self.data['gpu_checkboxes'])
        if checkbox.checkState() == Qt.Checked)
    epochs = int(self.data['args_epochs'])
    batch_size = int(self.data['args_batch_size'])

    # Dataset
    train_dataset_obj = dataset_objects[dataset_format]()
    train_dataset_obj.setInputFolderOrFile(self.data['train_dataset'])
    # The validation dataset is optional; None means "not provided".
    val_dataset_obj = None
    if self.data['val_dataset']:
        val_dataset_obj = dataset_objects[dataset_format]()
        val_dataset_obj.setInputFolderOrFile(self.data['val_dataset'])

    labels = train_dataset_obj.getLabels()
    num_train_samples = train_dataset_obj.getNumSamples()
    num_batches = int(math.ceil(num_train_samples / batch_size))

    args = Map({
        'network': self.data['network'],
        'train_dataset': self.data['train_dataset'],
        'validate_dataset': self.data['val_dataset'],
        'training_name': self.data['training_name'],
        'batch_size': batch_size,
        'learning_rate': float(self.data['args_learning_rate']),
        'gpus': gpus,
        'epochs': epochs,
        'early_stop_epochs': int(self.data['args_early_stop_epochs']),
        'start_epoch': self.data['start_epoch'],
        'resume': self.data['resume_training'],
    })

    # Progress maximum: one step per batch per epoch plus 5 extra steps
    # (presumably reserved for loading/setup phases).
    self.thread.update.emit(_('Loading data ...'), 0,
                            epochs * num_batches + 5)

    with network_objects[network_key]() as network:
        network.setAbortable(self.abortable)
        network.setThread(self.thread)
        network.setArgs(args)
        network.setOutputFolder(self.data['output_folder'])
        network.setTrainDataset(train_dataset_obj, dataset_format)
        network.setLabels(labels)

        if val_dataset_obj is not None:
            network.setValDataset(val_dataset_obj)

        self.checkAborted()

        network.training()