def get_current_network(self):
    """Instantiate the network object for the entry selected in the combo box.

    The combo box's current text is compared against the values of
    ``Training.config('networks')`` (key -> display text). The matching key
    is looked up in ``Training.config('objects')`` and called with no
    arguments.

    Returns:
        The newly created network object, or ``None`` if no entry matches
        the current selection or an error occurs (the traceback is logged).
    """
    try:
        selected_network = self.networks.currentText()
        networks = Training.config('networks')
        # Exact comparison, not substring containment: with ``in`` an empty
        # selection ('' is contained in every string) or a display name that
        # is part of another one would resolve to the wrong network.
        func_name = next(
            (key for key, text in networks.items() if text == selected_network),
            None)
        if func_name is None:
            return None
        return Training.config('objects')[func_name]()
    except Exception:
        logger.error(traceback.format_exc())
        return None
    def update_config_info(self, config):
        """Display the values of a saved training config in the info labels.

        Args:
            config: Training config dict. Top-level keys read here are
                ``network``, ``last_epoch``, ``dataset`` and ``args``; the
                nested ``args`` dict may hold ``training_name``, ``epochs``,
                ``early_stop_epochs``, ``batch_size``, ``learning_rate``,
                ``train_dataset`` and ``validate_dataset``. A missing key
                leaves the corresponding label unchanged.
        """
        # Tolerate configs without an ``args`` section, matching the
        # membership checks already done for the other top-level keys.
        args = config.get('args') or {}

        if 'network' in config:
            network_name = config['network']
            networks = Training.config('networks')
            if network_name in networks:
                network_name = networks[network_name]
            self.config_network_value.setText(network_name)
        if 'training_name' in args:
            self.config_training_name_value.setText(
                str(args['training_name']))
        if 'epochs' in args:
            self.config_epochs_value.setText(str(args['epochs']))
        if 'last_epoch' in config:
            self.config_trained_epochs_value.setText(str(config['last_epoch']))
        if 'early_stop_epochs' in args:
            self.config_early_stop_value.setText(
                str(args['early_stop_epochs']))
        if 'batch_size' in args:
            self.config_batch_size_value.setText(str(args['batch_size']))
        if 'learning_rate' in args:
            self.config_learning_rate_value.setText(
                str(args['learning_rate']))
        if 'dataset' in config:
            dataset_format = config['dataset']
            formats = Export.config('formats')
            if dataset_format in formats:
                dataset_format = formats[dataset_format]
            self.config_dataset_format_value.setText(dataset_format)

        project_folder = self.parent.parent.settings.value(
            'settings/project/folder', '')
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))

        def shorten(path):
            """Replace a leading project-folder part of *path* by a placeholder.

            Whole path components are compared: a plain character prefix
            (``os.path.commonprefix``) would treat ``/data/proj2`` as inside
            ``/data/proj``, and an unset project folder ('') would match
            every path.
            """
            # Empty or missing (None) dataset paths are displayed as ''.
            if not path:
                return ''
            if not project_folder:
                return path
            root = os.path.normpath(project_folder).rstrip(os.sep)
            candidate = os.path.normpath(path)
            if candidate == root or candidate.startswith(root + os.sep):
                return _('Project folder') + candidate[len(root):]
            return path

        if 'train_dataset' in args:
            self.config_dataset_train_value.setText(
                shorten(args['train_dataset']))
        if 'validate_dataset' in args:
            self.config_dataset_val_value.setText(
                shorten(args['validate_dataset']))
Example #3
0
 def afterEpoch(self, epoch):
     """Record and report that a (zero-based) epoch has finished.

     Writes ``last_epoch = epoch + 1`` into the training config file in
     ``self.output_folder``, emits a progress message on ``self.thread`` and
     then lets ``checkTrainingAborted`` inspect the epoch.
     """
     from labelme.config import Training

     completed_epochs = epoch + 1
     config_path = os.path.join(self.output_folder,
                                Training.config('config_file'))
     self.updateConfig(config_path, last_epoch=completed_epochs)

     message = _('Finished training on epoch {}').format(completed_epochs)
     self.thread.update.emit(message, None, -1)

     self.checkTrainingAborted(epoch)
 def set_current_network(self, network_key):
     """Select the combo-box entry showing the display text of *network_key*.

     The display text is ``Training.config('networks')[network_key]``; the
     first item whose text equals it becomes current. Nothing changes when
     no item matches. Errors (e.g. an unknown key) are logged, not raised.
     """
     try:
         wanted_text = Training.config('networks')[network_key]
         matching_indices = (
             idx for idx in range(self.networks.count())
             if self.networks.itemText(idx) == wanted_text)
         found = next(matching_indices, None)
         if found is not None:
             self.networks.setCurrentIndex(found)
     except Exception:
         logger.error(traceback.format_exc())
 def get_current_network_key(self):
     """Return the ``Training.config('networks')`` key of the selected entry.

     Returns:
         The key whose display text equals the combo box's current text, or
         ``None`` if nothing matches or an error occurs (the traceback is
         logged).
     """
     try:
         selected_network = self.networks.currentText()
         networks = Training.config('networks')
         # Exact comparison, not substring containment: with ``in`` an empty
         # selection ('' is contained in every string) or a name that is part
         # of another network's name would resolve to the wrong key.
         for key, text in networks.items():
             if text == selected_network:
                 return key
         return None
     except Exception:
         logger.error(traceback.format_exc())
         return None
    def run(self):
        """Run inference on one image with the results of a finished training.

        Reads ``training_folder`` and ``input_image_file`` from ``self.data``,
        loads the training config from the training folder, emits it through
        ``self.thread.data`` and calls ``Network.inference`` with the
        architecture (``.json``) and weights (``.params``) files listed in the
        config. Progress is reported through ``self.thread.update``.
        """
        logger.debug('Prepare inference')

        # Attach the ptvsd debugger to this worker thread when it is
        # available. This is best effort, so failures are ignored -- but only
        # ``Exception``, so KeyboardInterrupt/SystemExit still propagate.
        try:
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        training_folder = self.data['training_folder']
        input_image_file = self.data['input_image_file']

        config_file = os.path.join(training_folder,
                                   Training.config('config_file'))

        network = Network()
        network.setAbortable(self.abortable)
        network.setThread(self.thread)

        network_config = network.loadConfig(config_file)
        self.thread.data.emit({'network_config': network_config})

        # Pick the architecture and weights files out of the saved file list;
        # if several entries match, the last one of each kind wins.
        architecture_file = ''
        weights_file = ''
        for file_name in network_config.files:
            if '.json' in file_name:
                architecture_file = os.path.join(training_folder, file_name)
            elif '.params' in file_name:
                weights_file = os.path.join(training_folder, file_name)

        self.thread.update.emit(_('Validating ...'), 1, 3)

        network.inference(input_image_file,
                          network_config.labels,
                          architecture_file,
                          weights_file,
                          args=None)

        self.thread.update.emit(_('Finished'), 3, 3)
Example #7
0
    def beforeTrain(self):
        """Prepare the output folder and the UI before the first epoch.

        Saves the training config (listing the values of ``self.files``) to
        ``self.output_folder``. Creates a ``NetworkMonitor`` when
        ``self.args.early_stop_epochs`` is positive. Resets the validation
        display and the progress bar, checks for an abort request and logs
        the start epoch.
        """
        # Save config & architecture before training
        from labelme.config import Training
        config_path = os.path.join(self.output_folder,
                                   Training.config('config_file'))
        self.saveConfig(config_path, list(self.files.values()))

        patience = self.args.early_stop_epochs
        if patience > 0:
            self.monitor = NetworkMonitor(patience)

        placeholder = {_('Waiting for first validation ...'): ''}
        self.thread.data.emit({'validation': placeholder})

        # Set maximum progress to 100%
        self.thread.update.emit(_('Start training ...'), 0, 100)

        self.checkAborted()
        start_epoch = self.args.start_epoch
        logger.info('Start training from [Epoch {}]'.format(start_epoch))
    def prepare_resume_training(self, training_folder):
        """Pre-fill the dialog from the config of an earlier training run.

        Reads ``Training.config('config_file')`` from *training_folder*,
        copies its dataset, output, network and GPU settings into the widgets
        and lists the saved ``.params`` checkpoints in ``self.resume_file``,
        highest epoch first. On any error a warning box is shown and
        ``self.resume_folder`` is cleared.

        Args:
            training_folder: Folder of the previous training run.
        """
        try:
            config_file = os.path.join(training_folder,
                                       Training.config('config_file'))
            with open(config_file, 'r') as f:
                json_data = json.load(f)
            args = json_data['args']

            # Dataset tab
            self.create_dataset_checkbox.widget.setChecked(False)
            self.set_current_format(json_data['dataset'])
            self.train_dataset_folder.setText(args['train_dataset'])
            if args['validate_dataset']:
                self.val_dataset_folder.setText(args['validate_dataset'])
            else:
                self.val_dataset_folder.setText('')
            output_folder = os.path.normpath(
                os.path.join(training_folder, '..'))
            self.output_folder.setText(output_folder)
            self.training_name.setText(args['training_name'] + '_resume')

            # Network tab
            self.set_current_network(json_data['network'])
            self.args_epochs.setValue(args['epochs'])
            self.args_batch_size.setValue(args['batch_size'])
            self.args_learning_rate.setValue(args['learning_rate'])
            self.args_early_stop_epochs.setValue(args['early_stop_epochs'])
            gpus = args['gpus'].split(',')
            # Skip the GPU label text and the following space in each
            # checkbox caption to get the id compared with ``gpus``.
            id_offset = len(self.gpu_label_text) + 1
            for box in self.gpu_checkboxes:
                box.setChecked(box.text()[id_offset:] in gpus)

            # Resume: checkpoints are matched as
            # '<save_prefix>_<epoch>_<accuracy>.params'.
            self.resume_file.clear()
            params_pattern = '{}_*_*.params'.format(args['save_prefix'])
            params_path = os.path.join(training_folder, params_pattern)
            checkpoints = []
            for params_file in glob.glob(params_path):
                parts = os.path.splitext(params_file)[0].split('_')
                try:
                    epoch = int(parts[-2]) + 1
                    accuracy = float(parts[-1])
                except ValueError:
                    # Matches the wildcard but not the numeric naming scheme.
                    continue
                checkpoints.append((epoch, accuracy, params_file))
            # glob() order depends on the file system, and lexical order
            # would put epoch 10 before epoch 9: sort numerically so the
            # newest checkpoint is listed (and selected) first.
            checkpoints.sort(key=lambda item: item[0], reverse=True)
            for epoch, accuracy, params_file in checkpoints:
                label = '{} {} ({}={})'.format(_('Epoch'), epoch,
                                               _('Accuracy'), accuracy)
                self.resume_file.addItem(label, (params_file, epoch))
            self.resume_file.setCurrentIndex(0)
            self.resume_file_group.show()

            self.training_btn.setText(_('Resume training'))

        except Exception:
            logger.error(traceback.format_exc())
            mb = QtWidgets.QMessageBox
            mb.warning(self, _('Training'),
                       _('Applying config of previous training failed'))
            self.resume_folder.setText('')
    def __init__(self, parent=None):
        """Build the training dialog.

        Creates three tabs (Dataset, Network and Resume training) and a
        button box with "Start Training" and "Cancel".

        Args:
            parent: Forwarded to the base-class constructor.
        """
        super().__init__(parent)
        self.setWindowTitle(_('Training'))
        self.set_default_window_flags(self)
        self.setWindowModality(Qt.NonModal)

        self.dataset_format_init = False
        # NOTE(review): self.parent is used below as an object exposing
        # ``settings``, ``_config`` and ``imageList`` (presumably the main
        # window, stored by the base class) -- confirm, since on a plain
        # QWidget ``parent`` is a method.
        project_folder = self.parent.settings.value('settings/project/folder',
                                                    '')
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))

        layout = QtWidgets.QVBoxLayout()
        self.setLayout(layout)

        self.tabs = QtWidgets.QTabWidget()
        layout.addWidget(self.tabs)

        tab_dataset = QtWidgets.QWidget()
        tab_dataset_layout = QtWidgets.QVBoxLayout()
        tab_dataset.setLayout(tab_dataset_layout)
        self.tabs.addTab(tab_dataset, _('Dataset'))

        tab_network = QtWidgets.QWidget()
        tab_network_layout = QtWidgets.QVBoxLayout()
        tab_network.setLayout(tab_network_layout)
        self.tabs.addTab(tab_network, _('Network'))

        tab_resume = QtWidgets.QWidget()
        tab_resume_layout = QtWidgets.QVBoxLayout()
        tab_resume.setLayout(tab_resume_layout)
        self.tabs.addTab(tab_resume, _('Resume training'))

        # Network Tab

        # One entry per configured network, showing its display text.
        self.networks = QtWidgets.QComboBox()
        for key, val in Training.config('networks').items():
            self.networks.addItem(val)
        # Connected after populating, so adding the items above does not
        # invoke network_selection_changed.
        self.networks.currentIndexChanged.connect(
            self.network_selection_changed)

        network_group = HelpGroupBox('Training_NetworkArchitecture',
                                     _('Network'))
        network_group_layout = QtWidgets.QVBoxLayout()
        network_group.widget.setLayout(network_group_layout)
        network_group_layout.addWidget(self.networks)
        tab_network_layout.addWidget(network_group)

        training_defaults = self.parent._config['training_defaults']
        # Instantiate the currently selected network only to query its
        # default batch size and learning rate; get_current_network returns
        # None when no entry matches or instantiation fails.
        network = self.get_current_network()

        args_epochs_label = HelpLabel('Training_SettingsEpochs', _('Epochs'))
        self.args_epochs = QtWidgets.QSpinBox()
        self.args_epochs.setMinimum(1)
        self.args_epochs.setMaximum(500)
        self.args_epochs.setValue(training_defaults['epochs'])

        default_batch_size = self.get_default_batch_size(network)
        args_batch_size_label = HelpLabel('Training_SettingsBatchSize',
                                          _('Batch size'))
        self.args_batch_size = QtWidgets.QSpinBox()
        self.args_batch_size.setMinimum(1)
        self.args_batch_size.setMaximum(100)
        self.args_batch_size.setValue(default_batch_size)

        default_learning_rate = self.get_default_learning_rate(network)
        args_learning_rate_label = HelpLabel('Training_SettingsLearningRate',
                                             _('Learning rate'))
        self.args_learning_rate = QtWidgets.QDoubleSpinBox()
        self.args_learning_rate.setMinimum(1e-7)
        self.args_learning_rate.setMaximum(1.0)
        self.args_learning_rate.setSingleStep(1e-7)
        self.args_learning_rate.setDecimals(7)
        self.args_learning_rate.setValue(default_learning_rate)

        args_early_stop_epochs_label = HelpLabel('Training_SettingsEarlyStop',
                                                 _('Early stop epochs'))
        self.args_early_stop_epochs = QtWidgets.QSpinBox()
        self.args_early_stop_epochs.setMinimum(0)
        self.args_early_stop_epochs.setMaximum(100)
        self.args_early_stop_epochs.setValue(
            training_defaults['early_stop_epochs'])

        # One checkbox per GPU, captioned '<GPU label> <id>'.
        self.gpu_label_text = _('GPU')
        args_gpus_label = QtWidgets.QLabel(_('GPUs'))
        no_gpus_available_label = QtWidgets.QLabel(_('No GPUs available'))
        self.gpus = mx.test_utils.list_gpus()  # ['0']
        self.gpu_checkboxes = []
        for i in self.gpus:
            checkbox = QtWidgets.QCheckBox('{} {}'.format(
                self.gpu_label_text, i))
            # NOTE(review): only pre-checks GPU 0 if list_gpus() yields ints;
            # the ['0'] comment above suggests strings -- confirm.
            checkbox.setChecked(i == 0)
            self.gpu_checkboxes.append(checkbox)

        settings_group = QtWidgets.QGroupBox()
        settings_group.setTitle(_('Settings'))
        settings_group_layout = QtWidgets.QGridLayout()
        settings_group.setLayout(settings_group_layout)
        settings_group_layout.addWidget(args_epochs_label, 0, 0)
        settings_group_layout.addWidget(self.args_epochs, 0, 1)
        settings_group_layout.addWidget(args_batch_size_label, 1, 0)
        settings_group_layout.addWidget(self.args_batch_size, 1, 1)
        settings_group_layout.addWidget(args_learning_rate_label, 2, 0)
        settings_group_layout.addWidget(self.args_learning_rate, 2, 1)
        settings_group_layout.addWidget(args_early_stop_epochs_label, 3, 0)
        settings_group_layout.addWidget(self.args_early_stop_epochs, 3, 1)

        # GPU checkboxes are stacked in column 1 starting at row 4, or a
        # "No GPUs available" label is shown there instead.
        settings_group_layout.addWidget(args_gpus_label, 4, 0)
        if len(self.gpu_checkboxes) > 0:
            row = 4
            for i, checkbox in enumerate(self.gpu_checkboxes):
                settings_group_layout.addWidget(checkbox, row, 1)
                row += 1
        else:
            settings_group_layout.addWidget(no_gpus_available_label, 4, 1)
        tab_network_layout.addWidget(settings_group)

        # Dataset Tab

        # "Create dataset from opened images" is only offered (and checked)
        # when the parent's image list is non-empty.
        image_list = self.parent.imageList
        show_dataset_create = len(image_list) > 0
        self.create_dataset_checkbox = HelpCheckbox(
            'Training_CreateDataset', _('Create dataset from opened images'))
        self.create_dataset_checkbox.widget.setChecked(show_dataset_create)
        tab_dataset_layout.addWidget(self.create_dataset_checkbox)

        # Share of the created dataset used for validation, in percent
        # (0-90, default 10).
        validation_label = HelpLabel('Training_ValidationRatio',
                                     _('Validation ratio'))
        self.validation = QtWidgets.QSpinBox()
        self.validation.setValue(10)
        self.validation.setMinimum(0)
        self.validation.setMaximum(90)
        self.validation.setFixedWidth(50)
        validation_description_label = QtWidgets.QLabel(_('% of dataset'))

        dataset_name_label = QtWidgets.QLabel(_('Dataset name'))
        self.dataset_name = QtWidgets.QLineEdit()

        self.create_dataset_group = QtWidgets.QGroupBox()
        self.create_dataset_group.setTitle(_('Create dataset'))
        create_dataset_group_layout = QtWidgets.QGridLayout()
        self.create_dataset_group.setLayout(create_dataset_group_layout)
        create_dataset_group_layout.addWidget(dataset_name_label, 0, 0, 1, 2)
        create_dataset_group_layout.addWidget(self.dataset_name, 1, 0, 1, 2)
        create_dataset_group_layout.addWidget(validation_label, 2, 0, 1, 2)
        create_dataset_group_layout.addWidget(self.validation, 3, 0)
        create_dataset_group_layout.addWidget(validation_description_label, 3,
                                              1)

        formats_label = HelpLabel('Training_DatasetFormat', _('Format'))
        self.formats = QtWidgets.QComboBox()
        for key, val in Export.config('formats').items():
            self.formats.addItem(val)
        self.formats.setCurrentIndex(0)
        # Connected after populating and selecting index 0, so construction
        # does not invoke on_format_change.
        self.formats.currentTextChanged.connect(self.on_format_change)
        # Key of the first format, i.e. the entry selected at index 0 above.
        self.selected_format = list(Export.config('formats').keys())[0]

        train_dataset_label = HelpLabel('Training_TrainingDataset',
                                        _('Training dataset'))
        self.train_dataset_folder = QtWidgets.QLineEdit()
        train_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
        train_dataset_folder_browse_btn.clicked.connect(
            self.train_dataset_folder_browse_btn_clicked)

        val_label_text = '{} ({})'.format(_('Validation dataset'),
                                          _('optional'))
        val_dataset_label = HelpLabel('Training_ValidationDataset',
                                      val_label_text)
        self.val_dataset_folder = QtWidgets.QLineEdit()
        val_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
        val_dataset_folder_browse_btn.clicked.connect(
            self.val_dataset_folder_browse_btn_clicked)

        self.dataset_folder_group = QtWidgets.QGroupBox()
        self.dataset_folder_group.setTitle(_('Use dataset file(s)'))
        dataset_folder_group_layout = QtWidgets.QGridLayout()
        self.dataset_folder_group.setLayout(dataset_folder_group_layout)
        dataset_folder_group_layout.addWidget(formats_label, 0, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.formats, 1, 0, 1, 2)
        dataset_folder_group_layout.addWidget(train_dataset_label, 2, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.train_dataset_folder, 3, 0)
        dataset_folder_group_layout.addWidget(train_dataset_folder_browse_btn,
                                              3, 1)
        dataset_folder_group_layout.addWidget(val_dataset_label, 4, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.val_dataset_folder, 5, 0)
        dataset_folder_group_layout.addWidget(val_dataset_folder_browse_btn, 5,
                                              1)

        tab_dataset_layout.addWidget(self.create_dataset_group)
        tab_dataset_layout.addWidget(self.dataset_folder_group)

        # Initially show either "Create dataset" (images are open) or
        # "Use dataset file(s)" (no images open, checkbox hidden too).
        if show_dataset_create:
            self.dataset_folder_group.hide()
        else:
            self.create_dataset_checkbox.hide()
            self.create_dataset_group.hide()

        # Presumably swaps the visibility of the two dataset groups --
        # see switch_visibility.
        self.create_dataset_checkbox.widget.toggled.connect(
            lambda: self.switch_visibility(self.create_dataset_group, self.
                                           dataset_folder_group))

        # Default output path <project folder>/<project_training_folder>.
        # The widget is not added to any layout (see the commented-out lines
        # below), so it only carries the value.
        self.output_folder = QtWidgets.QLineEdit()
        self.output_folder.setText(
            os.path.join(project_folder,
                         self.parent._config['project_training_folder']))
        # self.output_folder.setReadOnly(True)
        # output_browse_btn = QtWidgets.QPushButton(_('Browse'))
        # output_browse_btn.clicked.connect(self.output_browse_btn_clicked)

        training_name_label = HelpLabel('Training_TrainingName',
                                        _('Training name'))
        self.training_name = QtWidgets.QLineEdit()

        output_folder_group = QtWidgets.QGroupBox()
        output_folder_group.setTitle(_('Output folder'))
        output_folder_group_layout = QtWidgets.QGridLayout()
        output_folder_group.setLayout(output_folder_group_layout)
        # output_folder_group_layout.addWidget(self.output_folder, 0, 0, 1, 2)
        # output_folder_group_layout.addWidget(output_browse_btn, 0, 2)
        output_folder_group_layout.addWidget(training_name_label, 1, 0, 1, 3)
        output_folder_group_layout.addWidget(self.training_name, 2, 0, 1, 3)
        tab_dataset_layout.addWidget(output_folder_group)

        # Resume Tab

        self.resume_training_checkbox = HelpCheckbox(
            'Training_Resume', _('Resume previous training'))
        self.resume_training_checkbox.widget.setChecked(False)
        tab_resume_layout.addWidget(self.resume_training_checkbox)

        # Resume options start hidden; toggle_resume_training_checkbox
        # presumably shows/hides them with the checkbox.
        self.resume_group = QtWidgets.QWidget()
        resume_group_layout = QtWidgets.QVBoxLayout()
        self.resume_group.setLayout(resume_group_layout)
        tab_resume_layout.addWidget(self.resume_group)

        self.resume_group.hide()
        self.resume_training_checkbox.widget.toggled.connect(
            self.toggle_resume_training_checkbox)

        # Read-only; presumably filled by resume_browse_btn_clicked.
        self.resume_folder = QtWidgets.QLineEdit()
        self.resume_folder.setText('')
        self.resume_folder.setReadOnly(True)
        resume_browse_btn = QtWidgets.QPushButton(_('Browse'))
        resume_browse_btn.clicked.connect(self.resume_browse_btn_clicked)

        resume_folder_group = QtWidgets.QGroupBox()
        resume_folder_group.setTitle(_('Training directory'))
        resume_folder_group_layout = QtWidgets.QHBoxLayout()
        resume_folder_group.setLayout(resume_folder_group_layout)
        resume_folder_group_layout.addWidget(self.resume_folder)
        resume_folder_group_layout.addWidget(resume_browse_btn)
        resume_group_layout.addWidget(resume_folder_group)

        # Checkpoint selector; hidden until checkpoints have been listed.
        self.resume_file = QtWidgets.QComboBox()
        self.resume_file_group = QtWidgets.QGroupBox()
        self.resume_file_group.setTitle(_('Resume after'))
        resume_file_group_layout = QtWidgets.QHBoxLayout()
        self.resume_file_group.setLayout(resume_file_group_layout)
        resume_file_group_layout.addWidget(self.resume_file)
        resume_group_layout.addWidget(self.resume_file_group)
        self.resume_file_group.hide()

        tab_dataset_layout.addStretch()
        tab_network_layout.addStretch()
        tab_resume_layout.addStretch()

        button_box = QtWidgets.QDialogButtonBox()
        self.training_btn = button_box.addButton(
            _('Start Training'), QtWidgets.QDialogButtonBox.AcceptRole)
        self.training_btn.clicked.connect(self.training_btn_clicked)
        cancel_btn = button_box.addButton(
            _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
        cancel_btn.clicked.connect(self.cancel_btn_clicked)
        layout.addWidget(button_box)

        # Fixed width of 500 px; height taken from the layout's size hint.
        h = self.sizeHint().height()
        self.resize(500, h)
Example #10
0
    def run(self):
        """Start a training run from the settings collected in ``self.data``.

        Validates the selected network key, builds the training args from
        the dialog values, creates dataset objects for the training (and
        optional validation) dataset, and runs ``training()`` on the network
        object created from ``Training.config('objects')``. If the network
        key is unknown, an error message is emitted and the task is aborted.
        """
        logger.debug('Prepare training')

        # Best-effort attach of the ptvsd debugger to this worker thread.
        # NOTE(review): the bare ``except`` also swallows KeyboardInterrupt /
        # SystemExit; ``except Exception`` would be narrower.
        try:
            import ptvsd
            ptvsd.debug_this_thread()
        except:
            pass

        network_key = self.data['network']
        if network_key not in Training.config('objects'):
            self.thread.message.emit(
                _('Training'),
                _('Network {} could not be found').format(network_key),
                MessageType.Error)
            self.abort()
            return

        # Training settings
        # GPU ids are the positions of the checked checkboxes, joined with ','.
        # Assumes the checkboxes are ordered by GPU id starting at 0 -- confirm.
        gpus = []
        gpu_checkboxes = self.data['gpu_checkboxes']
        for i, gpu in enumerate(gpu_checkboxes):
            if gpu.checkState() == Qt.Checked:
                gpus.append(str(i))
        gpus = ','.join(gpus)
        epochs = int(self.data['args_epochs'])
        batch_size = int(self.data['args_batch_size'])

        # Dataset
        # val_dataset_obj only exists when a validation dataset is set; it is
        # used further down under the same condition.
        dataset_format = self.data['selected_format']
        train_dataset_obj = Export.config('objects')[dataset_format]()
        train_dataset_obj.setInputFolderOrFile(self.data['train_dataset'])
        if self.data['val_dataset']:
            val_dataset_obj = Export.config('objects')[dataset_format]()
            val_dataset_obj.setInputFolderOrFile(self.data['val_dataset'])

        labels = train_dataset_obj.getLabels()
        num_train_samples = train_dataset_obj.getNumSamples()
        num_batches = int(math.ceil(num_train_samples / batch_size))

        args = Map({
            'network':
            self.data['network'],
            'train_dataset':
            self.data['train_dataset'],
            'validate_dataset':
            self.data['val_dataset'],
            'training_name':
            self.data['training_name'],
            'batch_size':
            batch_size,
            'learning_rate':
            float(self.data['args_learning_rate']),
            'gpus':
            gpus,
            'epochs':
            epochs,
            'early_stop_epochs':
            int(self.data['args_early_stop_epochs']),
            'start_epoch':
            self.data['start_epoch'],
            'resume':
            self.data['resume_training'],
        })

        # Progress maximum: one step per batch per epoch plus 5 extra steps
        # (presumably for loading/setup -- TODO confirm).
        self.thread.update.emit(_('Loading data ...'), 0,
                                epochs * num_batches + 5)

        # The network object is used as a context manager.
        with Training.config('objects')[network_key]() as network:
            network.setAbortable(self.abortable)
            network.setThread(self.thread)
            network.setArgs(args)
            network.setOutputFolder(self.data['output_folder'])
            network.setTrainDataset(train_dataset_obj, dataset_format)
            network.setLabels(labels)

            if self.data['val_dataset']:
                network.setValDataset(val_dataset_obj)

            self.checkAborted()

            network.training()