Beispiel #1
0
    def add_dataset_file(self, dataset_file):
        """Append one read-only row describing ``dataset_file`` to the table.

        Columns 0-3 of the new row hold, in order: the file's base name, the
        display name of its dataset format, a ``key=value`` summary of its
        ``samples`` mapping, and the normalized file path.

        Args:
            dataset_file: Path to a dataset config file readable by
                ``Export.read_dataset_config``.
        """
        config = Export.read_dataset_config(dataset_file)
        samples_text = ', '.join(
            '{}={}'.format(split, config.samples[split])
            for split in config.samples)
        row_values = (
            os.path.basename(dataset_file),
            Export.config('formats')[config.format],
            samples_text,
            os.path.normpath(dataset_file),
        )

        table = self.dataset_file_table
        row = table.rowCount()
        table.setRowCount(row + 1)
        for column, text in enumerate(row_values):
            cell = QtWidgets.QTableWidgetItem(text)
            # Only ItemIsEnabled is set, so cells are neither selectable
            # nor editable.
            cell.setFlags(Qt.ItemIsEnabled)
            table.setItem(row, column, cell)
 def dataset_folder_browse_btn_clicked(self, mode='train'):
     """Ask the user for a dataset file or folder of the selected format.

     Formats with a configured extension get a file-open dialog filtered on
     that extension. Formats whose extension is ``False`` get a directory
     picker. Both dialogs start in the project's dataset folder.

     Args:
         mode: Not used by this method.

     Returns:
         The normalized selected path when its dataset format could be
         detected; otherwise ``False`` (dialog cancelled, or detection
         failed, in which case a warning is shown).
     """
     fmt = self.selected_format
     extension = Export.config('extensions')[fmt]
     format_name = Export.config('formats')[fmt]
     ext_filter = False
     if extension != False:
         ext_filter = '{} {}({})'.format(format_name, _('files'), extension)

     project_folder = self.parent.settings.value('settings/project/folder', '')
     logger.debug(
         'Restored value "{}" for setting settings/project/folder'.format(
             project_folder))
     start_dir = os.path.join(project_folder,
                              self.parent._config['project_dataset_folder'])

     if not ext_filter:
         chosen = QtWidgets.QFileDialog.getExistingDirectory(
             self, _('Select dataset folder'), start_dir)
     else:
         chosen, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
             self, _('Select dataset file'), start_dir, ext_filter)

     if not chosen:
         return False

     chosen = os.path.normpath(chosen)
     key = Export.detectDatasetFormat(chosen)
     logger.debug('Detected dataset format {} for directory {}'.format(
         key, chosen))
     if key is None:
         QtWidgets.QMessageBox().warning(
             self, _('Training'),
             _('Could not detect format of selected dataset'))
         return False
     return chosen
Beispiel #3
0
    def __init__(self, parent=None):
        """Build the non-modal 'Import dataset' dialog.

        Layout, top to bottom: a format combo box, a read-only dataset folder
        field with a Browse button, a read-only output folder field with a
        Browse button, a stretch, and Import/Cancel buttons.

        Args:
            parent: Parent widget, forwarded to the base class constructor.
        """
        super().__init__(parent)
        self.setWindowTitle(_('Import dataset'))
        self.set_default_window_flags(self)
        self.setWindowModality(Qt.NonModal)

        layout = QtWidgets.QVBoxLayout()
        self.setLayout(layout)

        # The combo box lists format display names (the config values).
        # self.selected_format holds the matching config key and is updated
        # by on_format_change when the combo box text changes.
        self.formats = QtWidgets.QComboBox()
        for key, val in Export.config('formats').items():
            self.formats.addItem(val)
        self.formats.setCurrentIndex(0)
        # Connected only after the initial selection has been made.
        self.formats.currentTextChanged.connect(self.on_format_change)
        # Initial key is the first config entry, matching combo index 0
        # (assumes config('formats') iterates in the same order each call).
        self.selected_format = list(Export.config('formats').keys())[0]

        format_group = QtWidgets.QGroupBox()
        format_group.setTitle(_('Format'))
        format_group_layout = QtWidgets.QVBoxLayout()
        format_group.setLayout(format_group_layout)
        format_group_layout.addWidget(self.formats)
        layout.addWidget(format_group)

        # Read-only: only data_browse_btn_clicked fills it in.
        self.data_folder = QtWidgets.QLineEdit()
        self.data_folder.setReadOnly(True)
        data_browse_btn = QtWidgets.QPushButton(_('Browse'))
        data_browse_btn.clicked.connect(self.data_browse_btn_clicked)

        data_folder_group = QtWidgets.QGroupBox()
        data_folder_group.setTitle(_('Dataset folder'))
        data_folder_group_layout = QtWidgets.QHBoxLayout()
        data_folder_group.setLayout(data_folder_group_layout)
        data_folder_group_layout.addWidget(self.data_folder)
        data_folder_group_layout.addWidget(data_browse_btn)
        layout.addWidget(data_folder_group)

        self.output_folder = QtWidgets.QLineEdit()
        self.output_folder.setReadOnly(True)
        output_browse_btn = QtWidgets.QPushButton(_('Browse'))
        output_browse_btn.clicked.connect(self.output_browse_btn_clicked)

        output_folder_group = QtWidgets.QGroupBox()
        output_folder_group.setTitle(_('Output folder'))
        output_folder_group_layout = QtWidgets.QHBoxLayout()
        output_folder_group.setLayout(output_folder_group_layout)
        output_folder_group_layout.addWidget(self.output_folder)
        output_folder_group_layout.addWidget(output_browse_btn)
        layout.addWidget(output_folder_group)

        layout.addStretch()

        # NOTE: the local is named export_btn, but it is the Import button.
        button_box = QtWidgets.QDialogButtonBox()
        export_btn = button_box.addButton(
            _('Import'), QtWidgets.QDialogButtonBox.AcceptRole)
        export_btn.clicked.connect(self.import_btn_clicked)
        cancel_btn = button_box.addButton(
            _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
        cancel_btn.clicked.connect(self.cancel_btn_clicked)
        layout.addWidget(button_box)
 def on_format_change(self, value):
     """Map the combo box text ``value`` back to its format key.

     ``value`` is a format display name. It is looked up in the inverted
     ``formats`` config. On a match, ``self.selected_format`` is set to the
     key. An unknown name is only logged, and the previous selection is kept.
     """
     by_display_name = Export.invertDict(Export.config('formats'))
     if value not in by_display_name:
         logger.debug('Dataset format not found: {}'.format(value))
         return
     self.selected_format = by_display_name[value]
     logger.debug('Selected dataset format: {}'.format(self.selected_format))
    def export_before_training(self):
        """Export the currently opened image folder as a dataset, then train.

        The dataset format comes from ``training_defaults['dataset_format']``.
        A warning is shown, and nothing else happens, if no image folder is
        open or no labels exist. Otherwise the export parameters are stored
        in ``self.dataset_export_data`` and ``self.selected_format``. An
        ``ExportExecutor`` is then handed to ``run_thread`` together with
        ``self.start_training`` (presumably called once the export is done).
        """
        defaults = self.parent._config['training_defaults']
        selected_format = defaults['dataset_format']
        dataset_name = replace_special_chars(self.dataset_name.text())

        data_folder = self.parent.lastOpenDir
        if data_folder is None:
            QtWidgets.QMessageBox().warning(
                self, _('Training'),
                _('Please open a folder with images first'))
            return

        label_list = self.parent.uniqLabelList
        all_labels = [
            label_list.item(i).text() for i in range(len(label_list))
        ]
        if not all_labels:
            QtWidgets.QMessageBox().warning(
                self, _('Training'), _('No labels found in dataset'))
            return

        project_folder = self.parent.settings.value(
            'settings/project/folder', '')
        project_dataset_folder = self.parent._config['project_dataset_folder']
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))
        export_folder = os.path.join(project_folder, project_dataset_folder)

        # The widget holds a percentage; the executors expect a fraction.
        validation_ratio = int(self.validation.text()) / 100.0

        self.dataset_export_data = dict(
            dataset_name=dataset_name,
            format=selected_format,
            output_folder=os.path.join(export_folder, dataset_name),
            validation_ratio=validation_ratio,
        )
        self.selected_format = selected_format

        executor = ExportExecutor({
            'data_folder': data_folder,
            'export_folder': export_folder,
            'selected_labels': all_labels,
            'validation_ratio': validation_ratio,
            'dataset_name': dataset_name,
            'max_num_labels': Export.config('limits')['max_num_labels'],
            # ExportExecutor expects the display name, not the key.
            'selected_format': Export.config('formats')[selected_format],
        })
        self.run_thread(executor, self.start_training)
    def update_config_info(self, config):
        """Show the values of a training ``config`` dict in the info labels.

        Reads the optional top-level keys ``network``, ``last_epoch`` and
        ``dataset``, plus the optional ``args`` sub-dict (``training_name``,
        ``epochs``, ``early_stop_epochs``, ``batch_size``, ``learning_rate``,
        ``train_dataset``, ``validate_dataset``). Network and dataset format
        keys are replaced by their display names when known. Dataset paths
        inside the project folder are shown as ``Project folder<rest>``.
        Labels whose key is missing are left unchanged.

        Args:
            config: Mapping describing a training run.
        """
        # A config without 'args' previously raised KeyError; treat it as
        # having no arguments instead.
        args = config.get('args', {})

        if 'network' in config:
            network_name = config['network']
            networks = Training.config('networks')
            if network_name in networks:
                network_name = networks[network_name]
            self.config_network_value.setText(network_name)
        if 'training_name' in args:
            self.config_training_name_value.setText(
                str(args['training_name']))
        if 'epochs' in args:
            self.config_epochs_value.setText(str(args['epochs']))
        if 'last_epoch' in config:
            self.config_trained_epochs_value.setText(str(config['last_epoch']))
        if 'early_stop_epochs' in args:
            self.config_early_stop_value.setText(
                str(args['early_stop_epochs']))
        if 'batch_size' in args:
            self.config_batch_size_value.setText(str(args['batch_size']))
        if 'learning_rate' in args:
            self.config_learning_rate_value.setText(
                str(args['learning_rate']))
        if 'dataset' in config:
            dataset_format = config['dataset']
            formats = Export.config('formats')
            if dataset_format in formats:
                dataset_format = formats[dataset_format]
            self.config_dataset_format_value.setText(dataset_format)

        project_folder = self.parent.parent.settings.value(
            'settings/project/folder', '')
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))

        def display_path(path):
            # Show paths inside the project folder relative to it. Compare
            # whole path components: a character prefix (commonprefix)
            # treats "/data/project2" as inside "/data/project", and an
            # empty project folder would match every path.
            if not path:
                # e.g. an optional validation dataset stored as False
                return ''
            path = str(path)
            if not project_folder:
                return path
            root = os.path.normpath(project_folder).rstrip(os.sep)
            normalized = os.path.normpath(path)
            if normalized == root or normalized.startswith(root + os.sep):
                return _('Project folder') + normalized[len(root):]
            return path

        if 'train_dataset' in args:
            self.config_dataset_train_value.setText(
                display_path(args['train_dataset']))
        if 'validate_dataset' in args:
            self.config_dataset_val_value.setText(
                display_path(args['validate_dataset']))
Beispiel #7
0
 def dataset_files_browse_btn_clicked(self):
     """Let the user pick dataset config files and add each one to the table.

     The dialog filters on ``Export.config('config_file_extension')``. Every
     chosen path is normalized and passed to ``add_dataset_file``. Cancelling
     the dialog adds nothing.
     """
     # TODO: Replace config_file_extension
     file_filter = '{} (*{})'.format(
         _('Dataset files'), Export.config('config_file_extension'))
     chosen_files, _selected_filter = QtWidgets.QFileDialog.getOpenFileNames(
         self, _('Select dataset files'), '', file_filter)
     for path in chosen_files:
         self.add_dataset_file(os.path.normpath(path))
Beispiel #8
0
 def update_label_checkboxes(self):
     """Re-populate the label container from ``self.label_checkboxes``.

     Every item is removed from the container's layout. Widgets that are no
     longer listed in ``self.label_checkboxes`` are scheduled for deletion.
     The current checkboxes are then re-added in order, and only the first
     ``max_num_labels`` of them are checked.
     """
     layout = self.label_parent_widget.layout()
     while layout.count():
         child = layout.takeAt(0)
         widget = child.widget()
         # Do not delete checkboxes that are re-added below: deleteLater()
         # would destroy them once control returns to the event loop,
         # even though they are back in the layout.
         if widget is not None and not any(
                 widget is checkbox for checkbox in self.label_checkboxes):
             widget.deleteLater()
     max_num_labels = Export.config('limits')['max_num_labels']
     for index, checkbox in enumerate(self.label_checkboxes):
         checkbox.setChecked(index < max_num_labels)
         layout.addWidget(checkbox)
Beispiel #9
0
 def classes(self):
     """Return this dataset's class names, loading them on first access.

     The names are read once from the labels file
     (``Export.config('labels_file')``) under ``self._root``, one name per
     line with surrounding whitespace stripped, and cached in
     ``self._classes``.
     """
     if self._classes is not None:
         return self._classes
     try:
         labels_path = os.path.join(self._root, Export.config('labels_file'))
         with open(labels_path, 'r') as labels_fh:
             names = [line.strip() for line in labels_fh]
         self._classes = names
     except AssertionError as e:
         # NOTE(review): nothing in this try block raises AssertionError,
         # so this branch looks unreachable. It may be left over from a
         # removed class-name validation step; confirm before relying on it.
         raise RuntimeError("Class names must not contain {}".format(e))
     return self._classes
Beispiel #10
0
 def export_browse_btn_clicked(self):
     """Choose where to save the merged dataset file.

     The save dialog starts in the directory stored under
     ``merge/last_export_dir``. When the user confirms, that setting is
     updated to the chosen file's directory, and the normalized path is
     shown in ``self.export_file``.
     """
     settings = self.parent.settings
     last_dir = settings.value('merge/last_export_dir', '')
     logger.debug('Restored value "{}" for setting merge/last_export_dir'.format(last_dir))
     # TODO: Replace config_file_extension
     file_filter = '{} (*{})'.format(
         _('Dataset file'), Export.config('config_file_extension'))
     target, _selected_filter = QtWidgets.QFileDialog.getSaveFileName(
         self, _('Save output file as'), last_dir, file_filter)
     if not target:
         return
     target = os.path.normpath(target)
     settings.setValue('merge/last_export_dir', os.path.dirname(target))
     self.export_file.setText(target)
 def set_current_format(self, format_key):
     """Select the combo box entry showing the display name of ``format_key``.

     If no entry matches, the selection is left unchanged. Any error (for
     example an unknown ``format_key``) is logged with its traceback and
     otherwise ignored.

     Args:
         format_key: A key of ``Export.config('formats')``.
     """
     try:
         format_text = Export.config('formats')[format_key]
         # Bound the loop by the combo box being searched (self.formats).
         # The previous bound, self.networks.count(), belonged to a
         # different widget.
         for i in range(self.formats.count()):
             if self.formats.itemText(i) == format_text:
                 self.formats.setCurrentIndex(i)
                 break
     except Exception:
         logger.error(traceback.format_exc())
Beispiel #12
0
 def data_browse_btn_clicked(self):
     """Pick the dataset file or folder to import.

     Formats with a configured extension get a filtered file-open dialog.
     Formats whose extension is ``False`` get a directory picker. Both start
     in the project's dataset folder. When the user confirms, the parent
     directory of the selection is stored under ``import/last_data_dir``,
     and the normalized selection is shown in ``self.data_folder``.
     """
     extension = Export.config('extensions')[self.selected_format]
     format_name = Export.config('formats')[self.selected_format]
     if extension != False:
         ext_filter = '{} {}({})'.format(format_name, _('files'), extension)
     else:
         ext_filter = False

     settings = self.parent.settings
     project_folder = settings.value('settings/project/folder', '')
     logger.debug(
         'Restored value "{}" for setting settings/project/folder'.format(
             project_folder))
     start_dir = os.path.join(project_folder,
                              self.parent._config['project_dataset_folder'])

     if ext_filter:
         selection, _selected_filter = QtWidgets.QFileDialog.getOpenFileName(
             self, _('Select dataset file'), start_dir, ext_filter)
     else:
         selection = QtWidgets.QFileDialog.getExistingDirectory(
             self, _('Select dataset folder'), start_dir)

     if not selection:
         return
     selection = os.path.normpath(selection)
     settings.setValue('import/last_data_dir', os.path.dirname(selection))
     self.data_folder.setText(selection)
Beispiel #13
0
    def export_btn_clicked(self):
        """Collect the dialog's export settings and run an ``ExportExecutor``.

        The dataset name keeps only ASCII letters, digits, spaces, ``_`` and
        ``-``. The validation spin box percentage is divided by 100. The
        executor is passed to ``run_thread`` together with
        ``self.finish_export`` (presumably invoked when the export ends).
        """
        checked_labels = [
            box.text() for box in self.label_checkboxes if box.isChecked()
        ]
        data = {
            'data_folder': self.data_folder.text(),
            'export_folder': self.export_folder.text(),
            'selected_labels': checked_labels,
            'validation_ratio': int(self.validation.value()) / 100.0,
            'dataset_name': re.sub(r'[^a-zA-Z0-9 _-]+', '',
                                   self.export_name.text()),
            'max_num_labels': Export.config('limits')['max_num_labels'],
            'selected_format': self.formats.currentText(),
        }
        self.run_thread(ExportExecutor(data), self.finish_export)
Beispiel #14
0
 def __init__(self, all_image_sets=False):
     """Start with no intermediate data and no dataset loaded.

     Args:
         all_image_sets: Stored unchanged as ``self.all_image_sets``.

     Side effect: writes the configured labels file name into the
     class-level ``FormatVoc._files`` mapping (shared by the class, not
     per instance).
     """
     super().__init__()
     self.intermediate = self.dataset = None
     self.all_image_sets = all_image_sets
     FormatVoc._files['labels'] = Export.config('labels_file')
Beispiel #15
0
    def run(self):
        """Validate the export request in ``self.data`` and export the dataset.

        Reads ``data_folder``, ``export_folder``, ``selected_labels``,
        ``max_num_labels``, ``dataset_name``, ``selected_format`` (a format
        display name) and ``validation_ratio`` from ``self.data``, and stores
        the resolved format key as ``self.data['format_name']``.

        Every validation failure emits a warning via ``self.thread.message``,
        calls ``self.abort()`` and returns. All inputs, including the format,
        are validated before anything on disk is created or deleted. The
        dataset is written to ``<export_folder>/<sanitized dataset_name>``.
        If that folder is not empty, the user must confirm deleting its
        contents.
        """
        logger.debug('Prepare export')

        # Best-effort: attach the ptvsd debugger to this worker thread when
        # it is available. Failures are deliberately ignored, but unlike a
        # bare except this does not swallow KeyboardInterrupt/SystemExit.
        try:
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        def normalized(path):
            # normpath(path), or None for empty/None input: normpath('')
            # yields '.', and normpath(None) raises TypeError.
            return os.path.normpath(path) if path else None

        def make_dirs(path):
            # A failure is reported by the isdir() check further down
            # instead of escaping from the worker as an exception.
            try:
                os.makedirs(path)
            except OSError as e:
                logger.error('Could not create directory {}: {}'.format(
                    path, e))

        data_folder = normalized(self.data['data_folder'])
        if data_folder is None or not os.path.isdir(data_folder):
            self.thread.message.emit(_('Export'),
                                     _('Please enter a valid data folder'),
                                     MessageType.Warning)
            self.abort()
            return

        export_folder = normalized(self.data['export_folder'])
        if export_folder is None or not os.path.isdir(export_folder):
            self.thread.message.emit(_('Export'),
                                     _('Please enter a valid export folder'),
                                     MessageType.Warning)
            self.abort()
            return

        selected_labels = self.data['selected_labels']
        num_selected_labels = len(selected_labels)
        limit = self.data['max_num_labels']
        if num_selected_labels > limit:
            self.thread.message.emit(
                _('Export'),
                _('Please select a maximum of {} labels').format(limit),
                MessageType.Warning)
            self.abort()
            return
        elif num_selected_labels <= 0:
            self.thread.message.emit(_('Export'),
                                     _('Please select at least 1 label'),
                                     MessageType.Warning)
            self.abort()
            return

        dataset_name = replace_special_chars(self.data['dataset_name'])
        if not dataset_name:
            self.thread.message.emit(_('Export'),
                                     _('Please enter a valid dataset name'),
                                     MessageType.Warning)
            self.abort()
            return

        # Resolve the format key before touching the file system, so an
        # unknown format cannot cost the user an existing folder's contents.
        selected_format = self.data['selected_format']
        inv_formats = Export.invertDict(Export.config('formats'))
        if selected_format not in inv_formats:
            self.thread.message.emit(
                _('Export'),
                _('Export format {} could not be found').format(
                    selected_format), MessageType.Warning)
            self.abort()
            return
        format_name = inv_formats[selected_format]
        self.data['format_name'] = format_name

        # Build the output path from the sanitized name validated above,
        # not from the raw self.data['dataset_name'].
        export_dataset_folder = os.path.normpath(
            os.path.join(export_folder, dataset_name))
        if not os.path.isdir(export_dataset_folder):
            make_dirs(export_dataset_folder)
        elif os.listdir(export_dataset_folder):
            msg = _(
                'The selected output directory "{}" is not empty. All containing files will be deleted. Are you sure to continue?'
            ).format(export_dataset_folder)
            if not self.doConfirm(_('Export'), msg, MessageType.Warning):
                self.abort()
                return
            deltree(export_dataset_folder)
            time.sleep(0.5)  # wait for deletion to be finished
            if not os.path.exists(export_dataset_folder):
                make_dirs(export_dataset_folder)

        if not os.path.isdir(export_dataset_folder):
            self.thread.message.emit(
                _('Export'),
                _('The selected output directory "{}" could not be created').
                format(export_dataset_folder), MessageType.Warning)
            self.abort()
            return

        logger.debug('Start export')

        validation_ratio = self.data['validation_ratio']

        self.checkAborted()

        intermediate = IntermediateFormat()
        intermediate.setAbortable(self.abortable)
        intermediate.setThread(self.thread)
        intermediate.setIncludedLabels(selected_labels)
        intermediate.setValidationRatio(validation_ratio)
        intermediate.addFromLabelFiles(data_folder, shuffle=False)

        self.thread.update.emit(_('Loading data ...'), 0,
                                intermediate.getNumberOfSamples() + 5)

        args = Map({
            'validation_ratio': validation_ratio,
        })

        dataset_format = Export.config('objects')[format_name]()
        dataset_format.setAbortable(self.abortable)
        dataset_format.setThread(self.thread)
        dataset_format.setIntermediateFormat(intermediate)
        dataset_format.setInputFolderOrFile(data_folder)
        dataset_format.setOutputFolder(export_dataset_folder)
        dataset_format.setArgs(args)

        self.checkAborted()

        dataset_format.export()
Beispiel #16
0
    def __init__(self, parent=None):
        """Build the non-modal 'Export dataset' dialog.

        Layout, top to bottom: a format combo box, a read-only data folder
        field, a label-selection caption and scroll area (both initially
        hidden), an export folder field with a Browse button plus a dataset
        name field, a validation-ratio spin box, a stretch, and Export/Cancel
        buttons.

        Args:
            parent: Parent widget, forwarded to the base class constructor.
        """
        super().__init__(parent)
        self.setWindowTitle(_('Export dataset'))
        self.set_default_window_flags(self)
        self.setWindowModality(Qt.NonModal)

        layout = QtWidgets.QVBoxLayout()
        self.setLayout(layout)

        # Combo box entries are the format display names (config values).
        self.formats = QtWidgets.QComboBox()
        for key, val in Export.config('formats').items():
            self.formats.addItem(val)

        format_group = QtWidgets.QGroupBox()
        format_group.setTitle(_('Format'))
        format_group_layout = QtWidgets.QVBoxLayout()
        format_group.setLayout(format_group_layout)
        format_group_layout.addWidget(self.formats)
        layout.addWidget(format_group)

        # Read-only and no Browse button: only filled in below from the
        # parent's last opened directory.
        self.data_folder = QtWidgets.QLineEdit()
        self.data_folder.setReadOnly(True)

        data_folder_group = QtWidgets.QGroupBox()
        data_folder_group.setTitle(_('Data folder'))
        data_folder_group_layout = QtWidgets.QHBoxLayout()
        data_folder_group.setLayout(data_folder_group_layout)
        data_folder_group_layout.addWidget(self.data_folder)
        layout.addWidget(data_folder_group)

        # Label checkboxes live in a scrollable container. The caption and
        # scroll area start hidden; presumably load_labels fills and shows
        # them (not visible here).
        self.label_checkboxes = []
        self.label_selection_label = QtWidgets.QLabel(_('Label selection'))
        self.label_selection_label.setVisible(False)
        layout.addWidget(self.label_selection_label)

        self.label_parent_widget = QtWidgets.QWidget()
        self.label_parent_widget.setLayout(QtWidgets.QVBoxLayout())

        self.label_selection_scroll = QtWidgets.QScrollArea()
        self.label_selection_scroll.setVisible(False)
        self.label_selection_scroll.setWidgetResizable(True)
        self.label_selection_scroll.setMinimumHeight(40)
        self.label_selection_scroll.setMaximumHeight(150)
        self.label_selection_scroll.setWidget(self.label_parent_widget)
        layout.addWidget(self.label_selection_scroll)

        # Pre-fill data folder and labels only when a directory is open and
        # it has labels.
        labels = self.get_labels()
        if self.parent.lastOpenDir is not None and len(labels) > 0:
            self.data_folder.setText(self.parent.lastOpenDir)
            self.load_labels(labels)

        # Default export location: <project folder>/<project_dataset_folder>.
        self.export_folder = QtWidgets.QLineEdit()
        project_folder = self.parent.settings.value('settings/project/folder',
                                                    '')
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))
        self.export_folder.setText(
            os.path.join(project_folder,
                         self.parent._config['project_dataset_folder']))
        self.export_folder.setReadOnly(True)
        export_browse_btn = QtWidgets.QPushButton(_('Browse'))
        export_browse_btn.clicked.connect(self.export_browse_btn_clicked)

        export_name_label = QtWidgets.QLabel(_('Dataset name'))
        self.export_name = QtWidgets.QLineEdit()

        # Grid: row 0 = folder field (2 columns) + Browse button, row 1 =
        # name caption, row 2 = name field (both spanning 3 columns).
        export_folder_group = QtWidgets.QGroupBox()
        export_folder_group.setTitle(_('Export folder'))
        export_folder_group_layout = QtWidgets.QGridLayout()
        export_folder_group.setLayout(export_folder_group_layout)
        export_folder_group_layout.addWidget(self.export_folder, 0, 0, 1, 2)
        export_folder_group_layout.addWidget(export_browse_btn, 0, 2)
        export_folder_group_layout.addWidget(export_name_label, 1, 0, 1, 3)
        export_folder_group_layout.addWidget(self.export_name, 2, 0, 1, 3)
        layout.addWidget(export_folder_group)

        # Validation share as a percentage, 0-90; export_btn_clicked
        # divides it by 100.
        self.validation = QtWidgets.QSpinBox()
        self.validation.setValue(0)
        self.validation.setMinimum(0)
        self.validation.setMaximum(90)
        self.validation.setFixedWidth(50)
        validation_label = QtWidgets.QLabel(_('% of dataset'))

        validation_group = QtWidgets.QGroupBox()
        validation_group.setTitle(_('Validation ratio'))
        validation_group_layout = QtWidgets.QGridLayout()
        validation_group.setLayout(validation_group_layout)
        validation_group_layout.addWidget(self.validation, 0, 0)
        validation_group_layout.addWidget(validation_label, 0, 1)
        layout.addWidget(validation_group)

        layout.addStretch()

        button_box = QtWidgets.QDialogButtonBox()
        export_btn = button_box.addButton(
            _('Export'), QtWidgets.QDialogButtonBox.AcceptRole)
        export_btn.clicked.connect(self.export_btn_clicked)
        cancel_btn = button_box.addButton(
            _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
        cancel_btn.clicked.connect(self.cancel_btn_clicked)
        layout.addWidget(button_box)
    def start_training(self):
        """Validate the training settings and open the training progress window.

        Collects the dialog's settings into a ``data`` dict.

        - If "create dataset" is checked, the train/validation file paths are
          derived from ``self.dataset_export_data``.
        - Otherwise the selected datasets are validated. The validation
          dataset is optional and becomes ``False`` when missing or invalid.

        The output folder ``<output_folder>/<sanitized training_name>`` is
        created if needed. If it is not empty, a resumed training is refused;
        for a new training, the folder is emptied after user confirmation.
        Any validation failure shows a warning and returns early. On success,
        a ``TrainingProgressWindow`` is started with ``data`` and this window
        is closed.
        """
        network_key = self.get_current_network_key()
        epochs = self.args_epochs.value()

        resume_training_file = ''
        resume_epoch = 0
        if (self.resume_training_checkbox.widget.isChecked()
                and self.resume_file.count() > 0):
            idx = self.resume_file.currentIndex()
            resume_training_file, resume_epoch = self.resume_file.itemData(idx)
            # The requested epochs are counted on top of the resumed epoch.
            epochs += resume_epoch

        # Data
        data = {
            'create_dataset': self.create_dataset_checkbox.widget.isChecked(),
            'resume_training': resume_training_file,
            'start_epoch': resume_epoch,
            'dataset_export_data': self.dataset_export_data,
            'train_dataset': self.train_dataset_folder.text(),
            'val_dataset': self.val_dataset_folder.text(),
            'output_folder': self.output_folder.text(),
            'selected_format': self.selected_format,
            'training_name': self.training_name.text(),
            'network': network_key,
            'gpu_checkboxes': self.gpu_checkboxes,
            'args_epochs': epochs,
            'args_batch_size': self.args_batch_size.value(),
            'args_learning_rate': self.args_learning_rate.value(),
            'args_early_stop_epochs': self.args_early_stop_epochs.value(),
        }

        def make_dirs(path):
            # A failure is reported by the isdir() check further down
            # instead of escaping this Qt slot as an exception.
            try:
                os.makedirs(path)
            except OSError as e:
                logger.error('Could not create directory {}: {}'.format(
                    path, e))

        # Preprocess data
        mb = QtWidgets.QMessageBox()
        create_dataset = data['create_dataset']
        if create_dataset:
            # Derive the dataset file paths from the export parameters.
            export_data = data['dataset_export_data']
            format_name = export_data['format']
            dataset_format = Export.config('objects')[format_name]()
            output_folder = export_data['output_folder']
            train_file = dataset_format.getOutputFileName('train')
            train_dataset = os.path.join(output_folder, train_file)
            data['train_dataset'] = train_dataset
            validation_ratio = export_data['validation_ratio']
            if validation_ratio > 0:
                val_file = dataset_format.getOutputFileName('val')
                val_dataset = os.path.join(output_folder, val_file)
            else:
                # Validation dataset is optional
                val_dataset = False
            data['val_dataset'] = val_dataset

        else:
            train_dataset = data['train_dataset']
            is_train_dataset_valid = True
            if not train_dataset:
                is_train_dataset_valid = False
            train_dataset = os.path.normpath(train_dataset)
            if not (os.path.isdir(train_dataset)
                    or os.path.isfile(train_dataset)):
                is_train_dataset_valid = False
            if not is_train_dataset_valid:
                mb.warning(self, _('Training'),
                           _('Please select a valid training dataset'))
                return
            data['train_dataset'] = train_dataset

            val_dataset = data['val_dataset']
            is_val_dataset_valid = True
            if not val_dataset:
                is_val_dataset_valid = False
            val_dataset = os.path.normpath(val_dataset)
            if not (os.path.isdir(val_dataset) or os.path.isfile(val_dataset)):
                is_val_dataset_valid = False
            if not is_val_dataset_valid:
                # Validation dataset is optional
                val_dataset = False
            data['val_dataset'] = val_dataset

        if val_dataset and val_dataset == train_dataset:
            mb.warning(
                self, _('Training'),
                _('Training and validation dataset are equal. Please use different datasets, as validation results are useless otherwise.'
                  ))
            return

        output_folder = os.path.normpath(data['output_folder'])
        training_name = replace_special_chars(data['training_name'])
        data['training_name'] = training_name

        if not training_name:
            mb.warning(self, _('Training'),
                       _('Please enter a valid training name'))
            return

        output_folder = os.path.join(output_folder, training_name)
        data['output_folder'] = output_folder
        if not os.path.isdir(output_folder):
            make_dirs(output_folder)
        elif os.listdir(output_folder):
            if resume_training_file:
                # Resumed trainings require an empty output folder; a
                # non-empty one is never wiped in that case.
                mb.warning(
                    self, _('Training'),
                    _('The selected output directory "{}" is not empty. Please select a different directory.'
                      ).format(output_folder))
                return
            msg = _(
                'The selected output directory "{}" is not empty. All containing files will be deleted. Are you sure to continue?'
            ).format(output_folder)
            if not confirm(self, _('Training'), msg, MessageType.Warning):
                return
            deltree(output_folder)
            time.sleep(0.5)  # wait for deletion to be finished
            if not os.path.exists(output_folder):
                make_dirs(output_folder)

        if not os.path.isdir(output_folder):
            mb.warning(
                self, _('Training'),
                _('The selected output directory "{}" could not be created').
                format(output_folder))
            return

        # Open new window for training progress
        trainingWin = TrainingProgressWindow(self)
        trainingWin.show()
        trainingWin.start_training(data)
        self.close()
 def __init__(self):
     """Start with no intermediate data and an unknown sample count (-1).

     Side effect: writes the configured labels file name into the
     class-level ``FormatImageRecord._files`` mapping.
     """
     super().__init__()
     self.intermediate, self.num_samples = None, -1
     FormatImageRecord._files['labels'] = Export.config('labels_file')
Beispiel #19
0
    def run(self):
        """Validate the import request in ``self.data`` and import the dataset.

        Reads ``data_folder`` (a dataset file or folder), ``output_folder``
        and ``selected_format`` (a format display name) from ``self.data``,
        and stores the resolved format key as ``self.data['format_name']``.
        Every validation failure, including an input that the chosen format
        does not recognize, emits a warning via ``self.thread.message``,
        calls ``self.abort()`` and returns.
        """
        logger.debug('Prepare import')

        # Best-effort: attach the ptvsd debugger to this worker thread when
        # it is available. Failures are deliberately ignored, but unlike a
        # bare except this does not swallow KeyboardInterrupt/SystemExit.
        try:
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        def normalized(path):
            # normpath(path), or None for empty/None input: normpath('')
            # yields '.', and normpath(None) raises TypeError.
            return os.path.normpath(path) if path else None

        data_folder_or_file = normalized(self.data['data_folder'])
        if data_folder_or_file is None or not (
                os.path.isdir(data_folder_or_file)
                or os.path.isfile(data_folder_or_file)):
            self.thread.message.emit(
                _('Import'), _('Please enter a valid dataset file or folder'),
                MessageType.Warning)
            self.abort()
            return

        output_folder = normalized(self.data['output_folder'])
        if output_folder is None or not os.path.isdir(output_folder):
            self.thread.message.emit(_('Import'),
                                     _('Please enter a valid output folder'),
                                     MessageType.Warning)
            self.abort()
            return

        selected_format = self.data['selected_format']
        inv_formats = Export.invertDict(Export.config('formats'))
        if selected_format not in inv_formats:
            self.thread.message.emit(
                _('Import'),
                _('Import format {} could not be found').format(
                    selected_format), MessageType.Warning)
            self.abort()
            return
        format_name = inv_formats[selected_format]
        self.data['format_name'] = format_name

        # Dataset
        dataset_format = Export.config('objects')[format_name]()
        if not dataset_format.isValidFormat(data_folder_or_file):
            self.thread.message.emit(_('Import'), _('Invalid dataset format'),
                                     MessageType.Warning)
            self.abort()
            return

        dataset_format.setAbortable(self.abortable)
        dataset_format.setThread(self.thread)
        dataset_format.setOutputFolder(output_folder)
        dataset_format.setInputFolderOrFile(data_folder_or_file)

        self.checkAborted()

        dataset_format.importFolder()
    def __init__(self, parent=None):
        """Build the non-modal "Training" configuration window.

        The window holds three tabs ("Dataset", "Network", "Resume training")
        and a button box with "Start Training" and "Cancel". Initial values
        come from ``self.parent``:

        * the project folder from ``self.parent.settings``;
        * training defaults and the project training folder name from
          ``self.parent._config``;
        * the opened images from ``self.parent.imageList``.

        Args:
            parent: Passed unchanged to the base-class constructor.
                NOTE(review): the body dereferences ``self.parent.settings``,
                so ``self.parent`` is assumed to be the application's main
                window object (set up by the base class), not Qt's
                ``QObject.parent`` method. Confirm against the base class.
        """
        super().__init__(parent)
        self.setWindowTitle(_('Training'))
        self.set_default_window_flags(self)
        self.setWindowModality(Qt.NonModal)

        # NOTE(review): not read in this constructor; presumably used by other
        # methods of the class -- confirm.
        self.dataset_format_init = False
        # Project folder read from the parent's settings ('' when unset); it
        # is used below to build the default output folder.
        project_folder = self.parent.settings.value('settings/project/folder',
                                                    '')
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))

        layout = QtWidgets.QVBoxLayout()
        self.setLayout(layout)

        self.tabs = QtWidgets.QTabWidget()
        layout.addWidget(self.tabs)

        # Create the three tab pages in display order; their contents are
        # filled in further below.
        tab_dataset = QtWidgets.QWidget()
        tab_dataset_layout = QtWidgets.QVBoxLayout()
        tab_dataset.setLayout(tab_dataset_layout)
        self.tabs.addTab(tab_dataset, _('Dataset'))

        tab_network = QtWidgets.QWidget()
        tab_network_layout = QtWidgets.QVBoxLayout()
        tab_network.setLayout(tab_network_layout)
        self.tabs.addTab(tab_network, _('Network'))

        tab_resume = QtWidgets.QWidget()
        tab_resume_layout = QtWidgets.QVBoxLayout()
        tab_resume.setLayout(tab_resume_layout)
        self.tabs.addTab(tab_resume, _('Resume training'))

        # Network Tab

        # One combo-box entry per value of Training.config('networks'). The
        # change handler is connected only after the entries are added, so
        # populating the combo box does not invoke it.
        self.networks = QtWidgets.QComboBox()
        for key, val in Training.config('networks').items():
            self.networks.addItem(val)
        self.networks.currentIndexChanged.connect(
            self.network_selection_changed)

        network_group = HelpGroupBox('Training_NetworkArchitecture',
                                     _('Network'))
        network_group_layout = QtWidgets.QVBoxLayout()
        network_group.widget.setLayout(network_group_layout)
        network_group_layout.addWidget(self.networks)
        tab_network_layout.addWidget(network_group)

        training_defaults = self.parent._config['training_defaults']
        # Network used below to look up the network-specific defaults for
        # batch size and learning rate.
        network = self.get_current_network()

        # Epochs: 1..500, default taken from training_defaults.
        args_epochs_label = HelpLabel('Training_SettingsEpochs', _('Epochs'))
        self.args_epochs = QtWidgets.QSpinBox()
        self.args_epochs.setMinimum(1)
        self.args_epochs.setMaximum(500)
        self.args_epochs.setValue(training_defaults['epochs'])

        # Batch size: 1..100, default depends on the network.
        default_batch_size = self.get_default_batch_size(network)
        args_batch_size_label = HelpLabel('Training_SettingsBatchSize',
                                          _('Batch size'))
        self.args_batch_size = QtWidgets.QSpinBox()
        self.args_batch_size.setMinimum(1)
        self.args_batch_size.setMaximum(100)
        self.args_batch_size.setValue(default_batch_size)

        # Learning rate: 1e-7..1.0 with step 1e-7. Seven decimals are shown so
        # that the smallest step is representable in the spin box.
        default_learning_rate = self.get_default_learning_rate(network)
        args_learning_rate_label = HelpLabel('Training_SettingsLearningRate',
                                             _('Learning rate'))
        self.args_learning_rate = QtWidgets.QDoubleSpinBox()
        self.args_learning_rate.setMinimum(1e-7)
        self.args_learning_rate.setMaximum(1.0)
        self.args_learning_rate.setSingleStep(1e-7)
        self.args_learning_rate.setDecimals(7)
        self.args_learning_rate.setValue(default_learning_rate)

        # Early stop epochs: 0..100, default taken from training_defaults.
        # NOTE(review): 0 presumably disables early stopping -- confirm.
        args_early_stop_epochs_label = HelpLabel('Training_SettingsEarlyStop',
                                                 _('Early stop epochs'))
        self.args_early_stop_epochs = QtWidgets.QSpinBox()
        self.args_early_stop_epochs.setMinimum(0)
        self.args_early_stop_epochs.setMaximum(100)
        self.args_early_stop_epochs.setValue(
            training_defaults['early_stop_epochs'])

        self.gpu_label_text = _('GPU')
        args_gpus_label = QtWidgets.QLabel(_('GPUs'))
        no_gpus_available_label = QtWidgets.QLabel(_('No GPUs available'))
        self.gpus = mx.test_utils.list_gpus()  # ['0']
        self.gpu_checkboxes = []
        # One checkbox per detected GPU; only the entry equal to 0 is
        # pre-checked. NOTE(review): this relies on list_gpus() yielding
        # integer ids. If it yielded strings such as '0' (as the inline
        # comment above suggests), no box would be pre-checked.
        for i in self.gpus:
            checkbox = QtWidgets.QCheckBox('{} {}'.format(
                self.gpu_label_text, i))
            checkbox.setChecked(i == 0)
            self.gpu_checkboxes.append(checkbox)

        # Settings grid: label in column 0, editor in column 1, rows 0-3.
        settings_group = QtWidgets.QGroupBox()
        settings_group.setTitle(_('Settings'))
        settings_group_layout = QtWidgets.QGridLayout()
        settings_group.setLayout(settings_group_layout)
        settings_group_layout.addWidget(args_epochs_label, 0, 0)
        settings_group_layout.addWidget(self.args_epochs, 0, 1)
        settings_group_layout.addWidget(args_batch_size_label, 1, 0)
        settings_group_layout.addWidget(self.args_batch_size, 1, 1)
        settings_group_layout.addWidget(args_learning_rate_label, 2, 0)
        settings_group_layout.addWidget(self.args_learning_rate, 2, 1)
        settings_group_layout.addWidget(args_early_stop_epochs_label, 3, 0)
        settings_group_layout.addWidget(self.args_early_stop_epochs, 3, 1)

        # GPU checkboxes are stacked in column 1 starting at row 4, one row
        # each. Without GPUs, a "No GPUs available" label takes their place.
        settings_group_layout.addWidget(args_gpus_label, 4, 0)
        if len(self.gpu_checkboxes) > 0:
            row = 4
            for i, checkbox in enumerate(self.gpu_checkboxes):
                settings_group_layout.addWidget(checkbox, row, 1)
                row += 1
        else:
            settings_group_layout.addWidget(no_gpus_available_label, 4, 1)
        tab_network_layout.addWidget(settings_group)

        # Dataset Tab

        # Creating a dataset from the opened images is offered only when at
        # least one image is open; in that case the option starts checked.
        image_list = self.parent.imageList
        show_dataset_create = len(image_list) > 0
        self.create_dataset_checkbox = HelpCheckbox(
            'Training_CreateDataset', _('Create dataset from opened images'))
        self.create_dataset_checkbox.widget.setChecked(show_dataset_create)
        tab_dataset_layout.addWidget(self.create_dataset_checkbox)

        # Validation ratio in percent (0..90). The value is set before the
        # range, but 10 lies inside QSpinBox's default range (0..99), so it is
        # kept when the range is narrowed.
        validation_label = HelpLabel('Training_ValidationRatio',
                                     _('Validation ratio'))
        self.validation = QtWidgets.QSpinBox()
        self.validation.setValue(10)
        self.validation.setMinimum(0)
        self.validation.setMaximum(90)
        self.validation.setFixedWidth(50)
        validation_description_label = QtWidgets.QLabel(_('% of dataset'))

        dataset_name_label = QtWidgets.QLabel(_('Dataset name'))
        self.dataset_name = QtWidgets.QLineEdit()

        # "Create dataset" group: dataset name and validation ratio.
        self.create_dataset_group = QtWidgets.QGroupBox()
        self.create_dataset_group.setTitle(_('Create dataset'))
        create_dataset_group_layout = QtWidgets.QGridLayout()
        self.create_dataset_group.setLayout(create_dataset_group_layout)
        create_dataset_group_layout.addWidget(dataset_name_label, 0, 0, 1, 2)
        create_dataset_group_layout.addWidget(self.dataset_name, 1, 0, 1, 2)
        create_dataset_group_layout.addWidget(validation_label, 2, 0, 1, 2)
        create_dataset_group_layout.addWidget(self.validation, 3, 0)
        create_dataset_group_layout.addWidget(validation_description_label, 3,
                                              1)

        # The format combo box shows the values of Export.config('formats').
        # self.selected_format holds the key of the first entry, which is the
        # one selected at index 0. The change handler is connected after
        # setCurrentIndex(0), so this initial selection does not invoke it.
        formats_label = HelpLabel('Training_DatasetFormat', _('Format'))
        self.formats = QtWidgets.QComboBox()
        for key, val in Export.config('formats').items():
            self.formats.addItem(val)
        self.formats.setCurrentIndex(0)
        self.formats.currentTextChanged.connect(self.on_format_change)
        self.selected_format = list(Export.config('formats').keys())[0]

        train_dataset_label = HelpLabel('Training_TrainingDataset',
                                        _('Training dataset'))
        self.train_dataset_folder = QtWidgets.QLineEdit()
        train_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
        train_dataset_folder_browse_btn.clicked.connect(
            self.train_dataset_folder_browse_btn_clicked)

        val_label_text = '{} ({})'.format(_('Validation dataset'),
                                          _('optional'))
        val_dataset_label = HelpLabel('Training_ValidationDataset',
                                      val_label_text)
        self.val_dataset_folder = QtWidgets.QLineEdit()
        val_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
        val_dataset_folder_browse_btn.clicked.connect(
            self.val_dataset_folder_browse_btn_clicked)

        # "Use dataset file(s)" group: format, training dataset path and
        # optional validation dataset path, each path with a Browse button.
        self.dataset_folder_group = QtWidgets.QGroupBox()
        self.dataset_folder_group.setTitle(_('Use dataset file(s)'))
        dataset_folder_group_layout = QtWidgets.QGridLayout()
        self.dataset_folder_group.setLayout(dataset_folder_group_layout)
        dataset_folder_group_layout.addWidget(formats_label, 0, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.formats, 1, 0, 1, 2)
        dataset_folder_group_layout.addWidget(train_dataset_label, 2, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.train_dataset_folder, 3, 0)
        dataset_folder_group_layout.addWidget(train_dataset_folder_browse_btn,
                                              3, 1)
        dataset_folder_group_layout.addWidget(val_dataset_label, 4, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.val_dataset_folder, 5, 0)
        dataset_folder_group_layout.addWidget(val_dataset_folder_browse_btn, 5,
                                              1)

        tab_dataset_layout.addWidget(self.create_dataset_group)
        tab_dataset_layout.addWidget(self.dataset_folder_group)

        # Only one of the two dataset groups starts visible. With open
        # images, the "create" group is shown. Without them, the "use
        # dataset file(s)" group is shown and the create checkbox is hidden.
        if show_dataset_create:
            self.dataset_folder_group.hide()
        else:
            self.create_dataset_checkbox.hide()
            self.create_dataset_group.hide()

        # Toggling the checkbox passes both groups to switch_visibility
        # (presumably swapping which one is shown -- see that method).
        self.create_dataset_checkbox.widget.toggled.connect(
            lambda: self.switch_visibility(self.create_dataset_group, self.
                                           dataset_folder_group))

        # Default output folder: <project folder>/<project_training_folder>.
        # This line edit is not added to any layout (the lines that would add
        # it are commented out below), so it is not displayed in this window.
        self.output_folder = QtWidgets.QLineEdit()
        self.output_folder.setText(
            os.path.join(project_folder,
                         self.parent._config['project_training_folder']))
        # self.output_folder.setReadOnly(True)
        # output_browse_btn = QtWidgets.QPushButton(_('Browse'))
        # output_browse_btn.clicked.connect(self.output_browse_btn_clicked)

        training_name_label = HelpLabel('Training_TrainingName',
                                        _('Training name'))
        self.training_name = QtWidgets.QLineEdit()

        # "Output folder" group: currently contains only the training name.
        output_folder_group = QtWidgets.QGroupBox()
        output_folder_group.setTitle(_('Output folder'))
        output_folder_group_layout = QtWidgets.QGridLayout()
        output_folder_group.setLayout(output_folder_group_layout)
        # output_folder_group_layout.addWidget(self.output_folder, 0, 0, 1, 2)
        # output_folder_group_layout.addWidget(output_browse_btn, 0, 2)
        output_folder_group_layout.addWidget(training_name_label, 1, 0, 1, 3)
        output_folder_group_layout.addWidget(self.training_name, 2, 0, 1, 3)
        tab_dataset_layout.addWidget(output_folder_group)

        # Resume Tab

        # Resume is off by default. The resume widgets stay hidden until the
        # checkbox's toggled handler runs.
        self.resume_training_checkbox = HelpCheckbox(
            'Training_Resume', _('Resume previous training'))
        self.resume_training_checkbox.widget.setChecked(False)
        tab_resume_layout.addWidget(self.resume_training_checkbox)

        self.resume_group = QtWidgets.QWidget()
        resume_group_layout = QtWidgets.QVBoxLayout()
        self.resume_group.setLayout(resume_group_layout)
        tab_resume_layout.addWidget(self.resume_group)

        self.resume_group.hide()
        self.resume_training_checkbox.widget.toggled.connect(
            self.toggle_resume_training_checkbox)

        # The training directory is read-only here; it can only be filled
        # through the Browse button's handler.
        self.resume_folder = QtWidgets.QLineEdit()
        self.resume_folder.setText('')
        self.resume_folder.setReadOnly(True)
        resume_browse_btn = QtWidgets.QPushButton(_('Browse'))
        resume_browse_btn.clicked.connect(self.resume_browse_btn_clicked)

        resume_folder_group = QtWidgets.QGroupBox()
        resume_folder_group.setTitle(_('Training directory'))
        resume_folder_group_layout = QtWidgets.QHBoxLayout()
        resume_folder_group.setLayout(resume_folder_group_layout)
        resume_folder_group_layout.addWidget(self.resume_folder)
        resume_folder_group_layout.addWidget(resume_browse_btn)
        resume_group_layout.addWidget(resume_folder_group)

        # "Resume after" selector. It starts empty and hidden; nothing in this
        # constructor fills or shows it.
        self.resume_file = QtWidgets.QComboBox()
        self.resume_file_group = QtWidgets.QGroupBox()
        self.resume_file_group.setTitle(_('Resume after'))
        resume_file_group_layout = QtWidgets.QHBoxLayout()
        self.resume_file_group.setLayout(resume_file_group_layout)
        resume_file_group_layout.addWidget(self.resume_file)
        resume_group_layout.addWidget(self.resume_file_group)
        self.resume_file_group.hide()

        # Push each tab's content to the top.
        tab_dataset_layout.addStretch()
        tab_network_layout.addStretch()
        tab_resume_layout.addStretch()

        button_box = QtWidgets.QDialogButtonBox()
        self.training_btn = button_box.addButton(
            _('Start Training'), QtWidgets.QDialogButtonBox.AcceptRole)
        self.training_btn.clicked.connect(self.training_btn_clicked)
        cancel_btn = button_box.addButton(
            _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
        cancel_btn.clicked.connect(self.cancel_btn_clicked)
        layout.addWidget(button_box)

        # Fixed width of 500 px; height from the fully built layout's size hint.
        h = self.sizeHint().height()
        self.resize(500, h)
Beispiel #21
0
 def __init__(self):
     """Initialise a COCO format handler.

     Sets ``num_samples`` to -1 and ``dataset`` and ``intermediate`` to
     ``None`` on the instance, then stores
     ``Export.config('labels_file')`` under the ``'labels'`` key of
     ``FormatCoco._files``.
     """
     super().__init__()
     # NOTE(review): -1 presumably means "sample count not known yet" -- confirm.
     self.num_samples = -1
     self.intermediate = self.dataset = None
     # ``_files`` is reached through the class, so this item assignment
     # mutates a class-level mapping shared by all instances.
     # NOTE(review): if ``_files`` is inherited from a base class, this also
     # changes the base class's mapping -- confirm that is intended.
     labels_file = Export.config('labels_file')
     FormatCoco._files['labels'] = labels_file
Beispiel #22
0
    def run(self):
        """Prepare the selected network and datasets, then run the training.

        All parameters are read from ``self.data``. If the network key is not
        registered in ``Training.config('objects')``, an error message is
        emitted through ``self.thread.message``, ``self.abort()`` is called
        and the method returns without training. Otherwise an initial update
        is sent through ``self.thread.update`` and ``network.training()`` runs
        inside the network object's ``with`` block.
        """
        logger.debug('Prepare training')

        # Optional remote-debugger hook. Failure (e.g. ptvsd not installed)
        # is deliberately ignored; catching Exception rather than using a
        # bare except lets KeyboardInterrupt / SystemExit propagate.
        try:
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        network_key = self.data['network']
        if network_key not in Training.config('objects'):
            self.thread.message.emit(
                _('Training'),
                _('Network {} could not be found').format(network_key),
                MessageType.Error)
            self.abort()
            return

        # Training settings
        # GPU ids are the positions of the checked boxes, joined into a
        # comma-separated string such as "0,2" ('' when none is checked).
        gpus = ','.join(
            str(i) for i, gpu in enumerate(self.data['gpu_checkboxes'])
            if gpu.checkState() == Qt.Checked)
        epochs = int(self.data['args_epochs'])
        batch_size = int(self.data['args_batch_size'])

        # Dataset
        dataset_format = self.data['selected_format']
        dataset_class = Export.config('objects')[dataset_format]
        train_dataset_obj = dataset_class()
        train_dataset_obj.setInputFolderOrFile(self.data['train_dataset'])
        # The validation dataset is optional; it stays None when no path was
        # given, so the name is always bound.
        val_dataset_obj = None
        if self.data['val_dataset']:
            val_dataset_obj = dataset_class()
            val_dataset_obj.setInputFolderOrFile(self.data['val_dataset'])

        labels = train_dataset_obj.getLabels()
        num_train_samples = train_dataset_obj.getNumSamples()
        # Batches per epoch, counting a final partial batch.
        num_batches = int(math.ceil(num_train_samples / batch_size))

        args = Map({
            'network': network_key,
            'train_dataset': self.data['train_dataset'],
            'validate_dataset': self.data['val_dataset'],
            'training_name': self.data['training_name'],
            'batch_size': batch_size,
            'learning_rate': float(self.data['args_learning_rate']),
            'gpus': gpus,
            'epochs': epochs,
            'early_stop_epochs': int(self.data['args_early_stop_epochs']),
            'start_epoch': self.data['start_epoch'],
            'resume': self.data['resume_training'],
        })

        # Initial update: 0 and a total of one step per batch of every epoch
        # plus 5. NOTE(review): presumably (text, value, maximum) with the +5
        # covering setup steps -- confirm against the thread's update signal.
        self.thread.update.emit(_('Loading data ...'), 0,
                                epochs * num_batches + 5)

        with Training.config('objects')[network_key]() as network:
            network.setAbortable(self.abortable)
            network.setThread(self.thread)
            network.setArgs(args)
            network.setOutputFolder(self.data['output_folder'])
            network.setTrainDataset(train_dataset_obj, dataset_format)
            network.setLabels(labels)

            if val_dataset_obj is not None:
                network.setValDataset(val_dataset_obj)

            self.checkAborted()

            network.training()