def get_default_batch_size(self, network, memory_reserve=0):
        """Estimate the largest batch size that fits into the first GPU's free memory.

        ``network.getGpuSizes()`` supplies a fixed base size and a per-sample
        size. The batch size is ``floor((free - base - memory_reserve) /
        per_batch)``, where ``free`` is ``memoryFree`` of the first GPU
        reported by ``GPUtil.getGPUs()``. Units are presumably MB, matching
        the debug message; TODO confirm against ``getGpuSizes``.

        Args:
            network: Object with a ``getGpuSizes()`` method returning
                ``(base_size, size_per_batch)``.
            memory_reserve: Head-room, in the same unit as ``memoryFree``, to
                keep unused for memory peaks. Defaults to 0, i.e. nothing is
                reserved (the previous behaviour).

        Returns:
            int or None: The estimated batch size, never below 1. ``None`` if
            no GPU is found or the estimation raised (the traceback is logged).
        """
        try:
            selected_network = self.networks.currentText()
            network_size_base, network_size_per_batch = network.getGpuSizes()

            import GPUtil
            gpus = GPUtil.getGPUs()
            if not gpus:
                logger.warning('No GPU found. Unable to estimate batch size')
                return None
            gpu = gpus[0]

            # Memory left for per-sample data once the network's fixed cost
            # and the requested head-room are subtracted.
            available_memory = (gpu.memoryFree - network_size_base -
                                memory_reserve)
            # A batch needs at least one sample; a zero or negative estimate
            # (too little free memory) is not a usable batch size.
            batch_size = max(
                1, int(math.floor(available_memory / network_size_per_batch)))
            estimated_memory = gpu.memoryUsed + network_size_base + batch_size * network_size_per_batch
            logger.debug(
                'Estimating batch size: GPU {} (ID:{}) uses {}MB of {}MB. With {} ({}MB, {}MB) the estimated GPU usage is {}MB at a batch size of {}'
                .format(gpu.name, gpu.id, gpu.memoryUsed, gpu.memoryTotal,
                        selected_network, network_size_base,
                        network_size_per_batch, estimated_memory, batch_size))

            return batch_size

        except Exception:
            logger.error(traceback.format_exc())
            return None
    def validateEpoch(self, epoch, epoch_time, validate_params):
        """Validate the model after ``epoch`` when it is a validation epoch.

        Validation runs only if validation data exists and ``epoch + 1`` is a
        multiple of ``self.args.val_interval``. The metrics are logged, emitted
        through ``self.thread.data`` and the last metric is fed into the
        early-stopping monitor.

        Args:
            epoch: Zero-based epoch index.
            epoch_time: Duration used in the log message (seconds).
            validate_params: Keyword arguments forwarded to ``self.validate``.

        Returns:
            float: The last value returned as ``mean_ap`` by ``self.validate``,
            or ``0.`` when no validation ran.

        Raises:
            AbortTrainingException: When the monitor requests early stopping.
        """
        self.checkTrainingAborted(epoch)

        epoch_number = epoch + 1
        if not (self.val_data and not epoch_number % self.args.val_interval):
            return 0.

        logger.debug('validate: {}'.format(epoch_number))
        # Tell the UI that validation started (speed is reset to 0 meanwhile).
        self.thread.data.emit({
            'validation': {_('Validating...'): ''},
            'progress': {'speed': 0},
        })

        names, values = self.validate(**validate_params)
        summary = '\n'.join(
            '{}={}'.format(metric, score) for metric, score in zip(names, values))
        logger.info('[Epoch {}] Validation [{:.3f}sec]: \n{}'.format(
            epoch, epoch_time, summary))
        latest_map = float(values[-1])

        self.thread.data.emit({
            'validation': {name: values[idx] for idx, name in enumerate(names)},
        })

        # Early Stopping
        self.monitor.update(epoch, values[-1])
        if self.monitor.shouldStopEarly():
            raise AbortTrainingException(epoch)
        return latest_map
    def getLabels(self):
        """Return the labels of the RecordIO dataset ``self.input_folder_or_file``.

        If a labels file (``FormatImageRecord._files['labels']``) exists next
        to the record file, its lines are returned with line terminators
        stripped. Otherwise the record file is scanned: every fifth value of
        each header label, starting at index 4, is taken as a label id. The
        layout is presumably ``[..., class]`` per object; TODO confirm. The
        distinct ids are returned in first-seen order, and
        ``self.num_samples`` counts the collected entries.

        Returns:
            list of str: The labels.
        """
        input_folder = os.path.dirname(self.input_folder_or_file)
        label_file = os.path.join(input_folder,
                                  FormatImageRecord._files['labels'])
        if os.path.isfile(label_file):
            logger.debug('Load labels from file {}'.format(label_file))
            with open(label_file) as f:
                # Iterating a file keeps the newline; strip it so that label
                # names do not end in '\n'.
                return [line.rstrip('\r\n') for line in f]

        logger.debug('No label file found. Start reading dataset')
        # A dict de-duplicates while preserving insertion order. The previous
        # set() both lacked .append (every record raised) and would have had
        # a run-dependent iteration order for str values.
        labels = {}
        record = mx.recordio.MXRecordIO(self.input_folder_or_file, 'r')
        try:
            record.reset()
            self.num_samples = 0
            while True:
                try:
                    item = record.read()
                    if not item:
                        break
                    header, _payload = mx.recordio.unpack(item)
                    for i in range(4, len(header.label), 5):
                        labels[str(header.label[i])] = None
                        self.num_samples = self.num_samples + 1
                except Exception:
                    logger.error(traceback.format_exc())
        finally:
            record.close()

        return list(labels)
 def dataset_folder_browse_btn_clicked(self, mode='train'):
     """Ask the user for a dataset file or folder and check its format.

     If the selected format defines a file extension, a file dialog filtered
     by it is shown; otherwise a directory picker is used. The dialog starts
     in the project's dataset folder.

     Args:
         mode: Not used by this method.

     Returns:
         The normalized selected path, or ``False`` when the dialog was
         cancelled or ``Export.detectDatasetFormat`` returned ``None`` (a
         warning is shown in that case).
     """
     extension = Export.config('extensions')[self.selected_format]
     format_name = Export.config('formats')[self.selected_format]
     ext_filter = ('{} {}({})'.format(format_name, _('files'), extension)
                   if extension != False else False)

     project_folder = self.parent.settings.value('settings/project/folder', '')
     logger.debug(
         'Restored value "{}" for setting settings/project/folder'.format(
             project_folder))
     start_dir = os.path.join(project_folder,
                              self.parent._config['project_dataset_folder'])

     if ext_filter:
         selection, _chosen_filter = QtWidgets.QFileDialog.getOpenFileName(
             self, _('Select dataset file'), start_dir, ext_filter)
     else:
         selection = QtWidgets.QFileDialog.getExistingDirectory(
             self, _('Select dataset folder'), start_dir)

     if not selection:
         return False

     selection = os.path.normpath(selection)
     detected = Export.detectDatasetFormat(selection)
     logger.debug('Detected dataset format {} for directory {}'.format(
         detected, selection))
     if detected is not None:
         return selection

     QtWidgets.QMessageBox().warning(
         self, _('Training'), _('Could not detect format of selected dataset'))
     return False
Beispiel #5
0
def deltree(target):
    """Recursively delete the directory ``target`` and everything inside it.

    Symbolic links are removed as links and never followed, so a link that
    points outside ``target`` cannot cause files elsewhere to be deleted.
    A debug message is logged for every directory that is removed.

    Args:
        target: Path of the directory to delete.

    Raises:
        OSError: If an entry or directory cannot be removed.
    """
    # Take the directory listing first so entries are not deleted while the
    # directory iterator is still open.
    with os.scandir(target) as iterator:
        entries = list(iterator)
    for entry in entries:
        if entry.is_dir(follow_symlinks=False):
            deltree(entry.path)
        else:
            # Regular files and symlinks (including links to directories).
            os.remove(entry.path)
    os.rmdir(target)
    logger.debug('Deleted folder {}'.format(target))
 def loadConfig(self, config_file):
     """Load the training configuration stored as JSON in ``config_file``.

     Args:
         config_file: Path of the JSON config file.

     Returns:
         Map: The parsed configuration wrapped in ``Map``.

     Raises:
         OSError: If the file cannot be opened.
         ValueError: If the content is not valid JSON (``json.JSONDecodeError``).
     """
     logger.debug('Load training config from file: {}'.format(config_file))
     # The former trailing ``raise`` was unreachable: the ``with`` block either
     # returns or propagates the open/parse error itself.
     with open(config_file, 'r') as f:
         data = json.load(f)
     logger.debug('Loaded training config: {}'.format(data))
     return Map(data)
    def makeRecFile(self, data_folder, output_folder, file_name_rec,
                    file_name_idx, file_name_lst):
        """Pack the images listed in a ``.lst`` file into an indexed RecordIO file.

        Args:
            data_folder: Root folder for the image paths of the list file
                (passed to ``imageEncode`` as ``root``).
            output_folder: Folder that holds the list file and receives the
                record and index files.
            file_name_rec: File name of the record file to write.
            file_name_idx: File name of the index file to write.
            file_name_lst: File name of the list file to read.

        Images that cannot be encoded or written are skipped and reported via
        ``self.thread.message``. If none of them could be written,
        ``throwUserException`` is called.
        """
        import queue

        image_list = self.readLstFile(
            os.path.join(output_folder, file_name_lst))
        record = mx.recordio.MXIndexedRecordIO(
            os.path.join(output_folder, file_name_idx),
            os.path.join(output_folder, file_name_rec), 'w')

        cnt = 0
        failed_images = []
        start_time = time.time()
        try:
            self.checkAborted()

            args = Map({
                'root': data_folder,
                'pass_through': True,
                'resize': 0,
                'center_crop': False,
                'quality': 95,
                'encoding': '.jpg',
                'pack_label': True,
            })
            q_out = queue.Queue()
            pre_time = start_time
            for i, item in enumerate(image_list):
                try:
                    self.imageEncode(args, i, item, q_out)
                    if not q_out.empty():
                        _a, s, _b = q_out.get()
                        record.write_idx(item[0], s)
                        if cnt % 1000 == 0:
                            cur_time = time.time()
                            logger.debug('time: {} count: {}'.format(
                                cur_time - pre_time, cnt))
                            pre_time = cur_time
                        cnt += 1
                        self.thread.update.emit(_('Writing dataset ...'), -1,
                                                -1)
                except Exception:
                    failed_images.append(item)
                    logger.error(traceback.format_exc())
                # Outside the try block so that an exception raised by
                # checkAborted (presumably how an abort is signalled) is not
                # swallowed and recorded as a failed image.
                self.checkAborted()
        finally:
            # The index/record files must be closed even on abort or error.
            record.close()

        if len(failed_images) > 0:
            # List items may be non-str (item[0] is used as record index), so
            # convert explicitly before joining.
            msg = _('The following images could not be exported:'
                    ) + '\n' + ', '.join(str(image) for image in failed_images)
            self.thread.message.emit(_('Warning'), msg, MessageType.Warning)
            if cnt == 0:
                self.throwUserException(
                    _('Dataset contains no images for export'))

        # Measured from the start; pre_time is reset at every checkpoint.
        logger.debug('total time: {} total count: {}'.format(
            time.time() - start_time, cnt))
 def on_finish(self):
     """Handle completion of the worker thread.

     Always cancels the progress display. Unless the worker was aborted, the
     thread is reset and ``finish_func`` is called if set; otherwise the
     dialog is closed.
     """
     logger.debug('on_finish')
     self.cancel_progress()
     if self.worker_executor.isAborted():
         return
     self.reset_thread()
     callback = self.finish_func
     if callback is None:
         self.close()
     else:
         callback()
 def on_format_change(self, value):
     """Select the dataset format whose display name is ``value``.

     ``value`` is looked up in ``Export.invertDict(Export.config('formats'))``.
     When found, the matching key is stored in ``self.selected_format``;
     otherwise the miss is only logged.
     """
     name_to_key = Export.invertDict(Export.config('formats'))
     if value not in name_to_key:
         logger.debug('Dataset format not found: {}'.format(value))
         return
     self.selected_format = name_to_key[value]
     logger.debug('Selected dataset format: {}'.format(self.selected_format))
Beispiel #10
0
    def finish_training(self):
        """Report successful training to the user, then reset and close.

        Sets the progress bar to 4, shows an information box, resets the
        worker thread and closes the dialog.
        """
        logger.debug('finish_training: {}'.format(self.data))
        self.progress_bar.setValue(4)

        QtWidgets.QMessageBox().information(
            self, _('Training'), _('Network has been trained successfully'))
        self.reset_thread()
        self.close()
 def updateConfig(self, config_file, **kwargs):
     """Merge ``kwargs`` into the JSON config stored in ``config_file``.

     The file is read, every keyword argument overwrites (or adds) the key of
     the same name, and the result is written back with 2-space indentation.
     """
     logger.debug('Update training config with data: {}'.format(kwargs))
     with open(config_file, 'r') as source:
         config = json.load(source)
     config.update(kwargs)
     with open(config_file, 'w+') as target:
         json.dump(config, target, indent=2)
     logger.debug('Updated training config in file: {}'.format(config_file))
Beispiel #12
0
 def export_browse_btn_clicked(self):
     """Ask for the output dataset file and show the chosen path.

     The save dialog starts in the remembered ``merge/last_export_dir``. On
     confirmation, that setting is updated to the file's folder and the
     normalized path is written to ``self.export_file``.
     """
     settings = self.parent.settings
     last_dir = settings.value('merge/last_export_dir', '')
     logger.debug('Restored value "{}" for setting merge/last_export_dir'.format(last_dir))
     # TODO: Replace config_file_extension
     label = _('Dataset file')
     filters = label + ' (*{})'.format(Export.config('config_file_extension'))
     chosen, _chosen_filter = QtWidgets.QFileDialog.getSaveFileName(
         self, _('Save output file as'), last_dir, filters)
     if not chosen:
         return
     chosen = os.path.normpath(chosen)
     settings.setValue('merge/last_export_dir', os.path.dirname(chosen))
     self.export_file.setText(chosen)
 def on_error(self, error_msg, is_custom_msg=False):
     """Cancel the progress display and show an error message box.

     Args:
         error_msg: Text shown when ``is_custom_msg`` is true.
         is_custom_msg: When false, a generic message pointing to the log
             files is shown instead of ``error_msg``.
     """
     logger.debug('on_error')
     self.cancel_progress()
     box = QtWidgets.QMessageBox()
     title = _('Error')
     if is_custom_msg:
         text = error_msg
     else:
         text = _('An error occured. For further details look into the log files')
     box.warning(self, title, text)
 def on_abort(self):
     """Abort the running worker and wait for its thread to finish.

     Sets the progress maximum to 0, and also updates the label text when the
     progress is a ``QProgressDialog``. It then asks the current worker
     object to abort, waits for the worker thread, and finally cancels the
     progress display.
     """
     logger.debug('on_abort')
     progress = self.progress
     if isinstance(progress, QtWidgets.QProgressDialog):
         progress.setLabelText(_('Cancelling ...'))
     progress.setMaximum(0)

     worker_object = self.current_worker_object
     if worker_object:
         worker_object.abort()

     worker_thread = Application.getWorker(self.current_worker_idx)
     if worker_thread:
         worker_thread.wait()
     self.cancel_progress()
 def getContext(self, gpus=None):
     """Return the MXNet contexts for a comma-separated string of GPU ids.

     Args:
         gpus: GPU ids such as ``'0,1'``. ``None``, ``''`` or a string
             without any id (e.g. ``','``) selects the CPU.

     Returns:
         list: ``[mx.gpu(i), ...]``, or ``[mx.cpu()]`` when no GPU id is given
         or allocating a small test array on the first GPU raises
         ``mx.MXNetError``.
     """
     gpu_ids = [] if gpus is None else [i for i in gpus.split(',') if i.strip()]
     if not gpu_ids:
         # Also covers strings like ',' that previously left ctx empty and
         # crashed on ctx[0].
         return [mx.cpu()]
     ctx = [mx.gpu(int(i)) for i in gpu_ids]
     try:
         # Probe the first GPU with a tiny allocation to verify it is usable.
         mx.nd.array([1, 2, 3], ctx=ctx[0])
     except mx.MXNetError:
         ctx = [mx.cpu()]
         logger.error(traceback.format_exc())
         logger.warning('Unable to use GPU. Using CPU instead')
     logger.debug('Use context: {}'.format(ctx))
     return ctx
    def export_before_training(self):
        """Export the opened images as a dataset and start training afterwards.

        Uses the default dataset format from ``training_defaults``, every
        label in ``self.parent.uniqLabelList`` and the validation percentage
        entered in ``self.validation``. A warning is shown and nothing is
        exported when no image folder is open or no labels exist. Otherwise an
        ``ExportExecutor`` is run with ``self.start_training`` as the callback.
        """
        selected_format = self.parent._config['training_defaults']['dataset_format']
        dataset_name = replace_special_chars(self.dataset_name.text())

        data_folder = self.parent.lastOpenDir
        if data_folder is None:
            QtWidgets.QMessageBox().warning(
                self, _('Training'), _('Please open a folder with images first'))
            return

        label_list = self.parent.uniqLabelList
        all_labels = [label_list.item(row).text()
                      for row in range(len(label_list))]
        if not all_labels:
            QtWidgets.QMessageBox().warning(
                self, _('Training'), _('No labels found in dataset'))
            return

        project_folder = self.parent.settings.value('settings/project/folder', '')
        project_dataset_folder = self.parent._config['project_dataset_folder']
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))
        export_folder = os.path.join(project_folder, project_dataset_folder)

        # The text field holds a percentage; convert it to a 0..1 ratio.
        validation_ratio = int(self.validation.text()) / 100.0

        self.dataset_export_data = dict(
            dataset_name=dataset_name,
            format=selected_format,
            output_folder=os.path.join(export_folder, dataset_name),
            validation_ratio=validation_ratio,
        )
        self.selected_format = selected_format

        export_data = dict(
            data_folder=data_folder,
            export_folder=export_folder,
            selected_labels=all_labels,
            validation_ratio=validation_ratio,
            dataset_name=dataset_name,
            max_num_labels=Export.config('limits')['max_num_labels'],
            selected_format=Export.config('formats')[selected_format],
        )

        # Execution
        self.run_thread(ExportExecutor(export_data), self.start_training)
 def training_folder_browse_btn_clicked(self):
     """Let the user pick a training directory and show it in ``self.training_folder``.

     The dialog starts in ``project_training_folder`` below the project folder
     stored in the ``settings/project/folder`` setting.
     """
     project_folder = self.parent.settings.value('settings/project/folder', '')
     logger.debug(
         'Restored value "{}" for setting settings/project/folder'.format(
             project_folder))
     start_dir = os.path.join(project_folder,
                              self.parent._config['project_training_folder'])
     chosen = QtWidgets.QFileDialog.getExistingDirectory(
         self, _('Select training directory'), start_dir)
     if not chosen:
         return
     self.training_folder.setText(os.path.normpath(chosen))
Beispiel #18
0
 def export_browse_btn_clicked(self):
     """Let the user pick the export output folder and show it in ``self.export_folder``.

     The dialog starts in ``project_dataset_folder`` below the project folder
     stored in the ``settings/project/folder`` setting.
     """
     project_folder = self.parent.settings.value('settings/project/folder', '')
     logger.debug(
         'Restored value "{}" for setting settings/project/folder'.format(
             project_folder))
     start_dir = os.path.join(project_folder,
                              self.parent._config['project_dataset_folder'])
     chosen = QtWidgets.QFileDialog.getExistingDirectory(
         self, _('Select output folder'), start_dir)
     if not chosen:
         return
     self.export_folder.setText(os.path.normpath(chosen))
 def input_image_file_browse_btn_clicked(self):
     """Let the user pick an input image and show it in ``self.input_image_file``.

     The dialog starts in the remembered ``validation/last_input_image_dir``
     and offers JPEG, PNG and BMP files. On confirmation, the setting is
     updated to the image's folder.
     """
     settings = self.parent.settings
     last_dir = settings.value('validation/last_input_image_dir', '')
     logger.debug(
         'Restored value "{}" for setting validation/last_input_image_dir'.format(
             last_dir))
     filters = _('Image files') + ' (*.jpg *.jpeg *.png *.bmp)'
     image_file, _chosen_filter = QtWidgets.QFileDialog.getOpenFileName(
         self, _('Select input image'), last_dir, filters)
     if not image_file:
         return
     image_file = os.path.normpath(image_file)
     settings.setValue('validation/last_input_image_dir',
                       os.path.dirname(image_file))
     self.input_image_file.setText(image_file)
    def run(self):
        """Run inference on one image using a trained network.

        Reads ``training_folder`` and ``input_image_file`` from ``self.data``
        and loads the training config (file ``Training.config('config_file')``
        inside the training folder). The config is emitted as
        ``{'network_config': ...}``. The files it lists whose names contain
        ``.json`` (architecture) and ``.params`` (weights) are then passed to
        ``Network.inference``. Progress is reported via ``self.thread.update``.
        """
        logger.debug('Prepare inference')

        # Best effort: register this worker thread with an attached ptvsd
        # debugger. A missing module or debugger is ignored, but
        # KeyboardInterrupt/SystemExit are no longer swallowed as with a bare
        # except.
        try:
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        training_folder = self.data['training_folder']
        input_image_file = self.data['input_image_file']

        config_file = os.path.join(training_folder,
                                   Training.config('config_file'))

        network = Network()
        network.setAbortable(self.abortable)
        network.setThread(self.thread)

        network_config = network.loadConfig(config_file)
        self.thread.data.emit({'network_config': network_config})

        architecture_file = ''
        weights_file = ''
        for f in network_config.files:
            if '.json' in f:
                architecture_file = os.path.join(training_folder, f)
            elif '.params' in f:
                weights_file = os.path.join(training_folder, f)

        self.thread.update.emit(_('Validating ...'), 1, 3)

        network.inference(input_image_file,
                          network_config.labels,
                          architecture_file,
                          weights_file,
                          args=None)

        self.thread.update.emit(_('Finished'), 3, 3)
 def getTrainValidateSamplesPerLabel(self, label, shuffle=False):
     """Split the samples of ``label`` into training and validation lists.

     After an optional shuffle, the last
     ``int(len(samples) * self.validation_ratio)`` samples become validation
     samples, in reverse order. The list stored in ``self.samplesPerLabel``
     is not modified; both returned lists are new lists.

     Args:
         label: Key into ``self.samplesPerLabel``.
         shuffle: Shuffle the samples before splitting.

     Returns:
         tuple: ``(train_samples, val_samples)``.

     Raises:
         Exception: If ``label`` is not in ``self.samplesPerLabel``.
     """
     if label not in self.samplesPerLabel:
         raise Exception('Label {} not found'.format(label))
     # Copy: shuffling/splitting the stored list in place would drain it for
     # later calls.
     samples = list(self.samplesPerLabel[label])
     if shuffle:
         random.shuffle(samples)

     num_val_samples = 0
     if self.validation_ratio > 0.0:
         # Clamp so a ratio above 1 cannot request more samples than exist.
         num_val_samples = min(int(len(samples) * self.validation_ratio),
                               len(samples))
         logger.debug('Use {} validate samples for label {}'.format(
             num_val_samples, label))

     split = len(samples) - num_val_samples
     # Reversed to keep the order of the former pop(-1) loop.
     val_samples = samples[split:][::-1]
     train_samples = samples[:split]
     if num_val_samples > 0:
         self.checkAborted()
     return train_samples, val_samples
Beispiel #22
0
 def output_browse_btn_clicked(self):
     """Let the user pick the import output folder and show it in ``self.output_folder``.

     The dialog starts in ``project_import_folder`` below the project folder.
     The chosen folder is stored in the ``import/last_output_dir`` setting.
     """
     settings = self.parent.settings
     project_folder = settings.value('settings/project/folder', '')
     logger.debug(
         'Restored value "{}" for setting settings/project/folder'.format(
             project_folder))
     start_dir = os.path.join(project_folder,
                              self.parent._config['project_import_folder'])
     chosen = QtWidgets.QFileDialog.getExistingDirectory(
         self, _('Select output folder'), start_dir)
     if not chosen:
         return
     chosen = os.path.normpath(chosen)
     settings.setValue('import/last_output_dir', chosen)
     self.output_folder.setText(chosen)
    def update_config_info(self, config):
        """Show the values of a training config in the config summary labels.

        Args:
            config: Training config mapping. The top-level ``network``,
                ``last_epoch`` and ``dataset`` keys and the entries of
                ``config['args']`` are displayed when present; absent keys
                (including a missing ``args``) leave their label unchanged.

        Network and dataset-format ids are translated to display names via
        ``Training.config('networks')`` and ``Export.config('formats')``.
        Dataset paths located inside the project folder are shown relative to
        it, prefixed with the translated text "Project folder".
        """
        args = config['args'] if 'args' in config else {}

        if 'network' in config:
            network_name = config['network']
            if network_name in Training.config('networks'):
                network_name = Training.config('networks')[network_name]
            self.config_network_value.setText(network_name)
        if 'training_name' in args:
            self.config_training_name_value.setText(
                str(args['training_name']))
        if 'epochs' in args:
            self.config_epochs_value.setText(str(args['epochs']))
        if 'last_epoch' in config:
            self.config_trained_epochs_value.setText(str(config['last_epoch']))
        if 'early_stop_epochs' in args:
            self.config_early_stop_value.setText(
                str(args['early_stop_epochs']))
        if 'batch_size' in args:
            self.config_batch_size_value.setText(str(args['batch_size']))
        if 'learning_rate' in args:
            self.config_learning_rate_value.setText(
                str(args['learning_rate']))
        if 'dataset' in config:
            dataset_format = config['dataset']
            if dataset_format in Export.config('formats'):
                dataset_format = Export.config('formats')[dataset_format]
            self.config_dataset_format_value.setText(dataset_format)

        project_folder = self.parent.parent.settings.value(
            'settings/project/folder', '')
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))

        def shorten_path(path):
            """Replace the project-folder part of ``path`` by "Project folder"."""
            # An empty project folder (the setting's default) would otherwise
            # match every path.
            if not project_folder or not path:
                return path
            root = os.path.normpath(project_folder)
            candidate = os.path.normpath(path)
            if candidate == root:
                return _('Project folder')
            # Compare whole path components: a character-wise prefix would let
            # '/proj' match '/project2/...'.
            root_prefix = root if root.endswith(os.sep) else root + os.sep
            if candidate.startswith(root_prefix):
                return _('Project folder') + os.sep + candidate[len(root_prefix):]
            return path

        if 'train_dataset' in args:
            self.config_dataset_train_value.setText(
                shorten_path(args['train_dataset']))
        if 'validate_dataset' in args:
            self.config_dataset_val_value.setText(
                shorten_path(args['validate_dataset']))
 def saveConfig(self, config_file, files):
     """Write the training configuration to ``config_file`` as indented JSON.

     Stored are the network name, the given ``files``, the dataset format,
     the labels, and a copy of ``self.args`` without its ``network`` and
     ``resume`` entries.
     """
     stored_args = self.args.copy()
     for excluded_key in ('network', 'resume'):
         del stored_args[excluded_key]
     data = dict(
         network=self.args.network,
         files=files,
         dataset=self.dataset_format,
         labels=self.labels,
         args=stored_args,
     )
     logger.debug('Create training config: {}'.format(data))
     with open(config_file, 'w+') as target:
         json.dump(data, target, indent=2)
         logger.debug(
             'Saved training config in file: {}'.format(config_file))
 def getDataloader(self, train_dataset, val_dataset):
     """Create the YOLOv3 training and validation data loaders.

     Training batches stack six fields (image and generated targets) and pad
     one field with -1. Unless ``args.no_random_shape`` is set, a
     ``RandomTransformDataLoader`` chooses among input sizes 320..608
     (multiples of 32) with ``interval=10``. Validation data is loaded in
     dataset order.

     Returns:
         tuple: ``(train_loader, val_loader)``; ``val_loader`` is ``None``
         when ``val_dataset`` is ``None``.
     """
     width, height = self.args.data_shape, self.args.data_shape
     # stack image, all targets generated
     batchify_fn = Tuple(*([Stack() for _field in range(6)] + [Pad(axis=0, pad_val=-1)]))
     if self.args.no_random_shape:
         logger.debug('no random shape')
         train_loader = gluon.data.DataLoader(
             train_dataset.transform(YOLO3DefaultTrainTransform(width, height, self.net, mixup=self.args.mixup)),
             self.args.batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=self.args.num_workers)
     else:
         logger.debug('with random shape')
         transform_fns = [YOLO3DefaultTrainTransform(x * 32, x * 32, self.net, mixup=self.args.mixup) for x in range(10, 20)]
         train_loader = RandomTransformDataLoader(
             transform_fns, train_dataset, batch_size=self.args.batch_size, interval=10, last_batch='rollover',
             shuffle=True, batchify_fn=batchify_fn, num_workers=self.args.num_workers)
     val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
     val_loader = None
     if val_dataset is not None:
         # Shuffling validation data adds work and makes runs non-deterministic
         # without changing the metric; the GluonCV reference uses False.
         val_loader = gluon.data.DataLoader(
             val_dataset.transform(YOLO3DefaultValTransform(width, height)),
             self.args.batch_size, shuffle=False, batchify_fn=val_batchify_fn,
             last_batch='keep', num_workers=self.args.num_workers)
     return train_loader, val_loader
Beispiel #26
0
def get_config(config_file_or_yaml=None, config_from_args=None):
    """Build the application config from defaults, a YAML source and CLI args.

    Precedence (later wins): ``get_default_config()``, then
    ``config_file_or_yaml``, then ``config_from_args``. Each override is
    merged with ``update_dict`` and validated by ``validate_config_item``.

    Args:
        config_file_or_yaml: A YAML string holding a mapping, or the path of a
            YAML file. Any value that does not parse to a dict is treated as
            a path.
        config_from_args: Mapping of overrides, e.g. parsed CLI arguments.

    Returns:
        The merged config.
    """
    # 1. default config
    config = get_default_config()

    # 2. specified as file or yaml
    if config_file_or_yaml is not None:
        config_from_yaml = yaml.safe_load(config_file_or_yaml)
        if not isinstance(config_from_yaml, dict):
            # Open the original string, not its YAML-parsed value: a path such
            # as "2020" or "true" would otherwise become an int/bool and be
            # passed to open() as a file descriptor.
            with open(config_file_or_yaml) as f:
                logger.debug(
                    'Loading config file from: {}'.format(config_file_or_yaml))
                config_from_yaml = yaml.safe_load(f)
            if config_from_yaml is None:
                # An empty config file contributes no overrides.
                config_from_yaml = {}
        update_dict(config,
                    config_from_yaml,
                    validate_item=validate_config_item)

    # 3. command line argument or specified config file
    if config_from_args is not None:
        update_dict(config,
                    config_from_args,
                    validate_item=validate_config_item)

    return config
Beispiel #27
0
 def data_browse_btn_clicked(self):
     """Let the user pick the dataset to import and show it in ``self.data_folder``.

     If the selected format defines a file extension, a file dialog filtered
     by it is shown; otherwise a directory picker is used. The dialog starts
     in the project's dataset folder. The parent folder of the selection is
     stored in the ``import/last_data_dir`` setting.
     """
     extension = Export.config('extensions')[self.selected_format]
     format_name = Export.config('formats')[self.selected_format]
     ext_filter = ('{} {}({})'.format(format_name, _('files'), extension)
                   if extension != False else False)

     settings = self.parent.settings
     project_folder = settings.value('settings/project/folder', '')
     logger.debug(
         'Restored value "{}" for setting settings/project/folder'.format(
             project_folder))
     start_dir = os.path.join(project_folder,
                              self.parent._config['project_dataset_folder'])

     if ext_filter:
         selection, _chosen_filter = QtWidgets.QFileDialog.getOpenFileName(
             self, _('Select dataset file'), start_dir, ext_filter)
     else:
         selection = QtWidgets.QFileDialog.getExistingDirectory(
             self, _('Select dataset folder'), start_dir)

     if not selection:
         return
     selection = os.path.normpath(selection)
     settings.setValue('import/last_data_dir', os.path.dirname(selection))
     self.data_folder.setText(selection)
Beispiel #28
0
def main():
    """Parse the command line, build the config and run the GUI.

    Exits via ``sys.exit`` in these cases:
    - after printing the version (``--version``);
    - after clearing the Qt settings (``--reset-config``);
    - when label validation is requested without any labels;
    - finally, with the return code of the Qt event loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--version', '-V', action='store_true', help='show version'
    )
    parser.add_argument(
        '--reset-config', action='store_true', help='reset qt config'
    )
    parser.add_argument(
        '--logger-level',
        default='info',
        choices=['debug', 'info', 'warning', 'fatal', 'error'],
        help='logger level',
    )
    parser.add_argument('filename', nargs='?', help='image or label filename')
    parser.add_argument(
        '--output',
        '-O',
        '-o',
        help='output file or directory (if it ends with .json it is '
             'recognized as file, else as directory)'
    )
    default_config_file = os.path.join(os.path.expanduser('~'), '.labelmerc')
    parser.add_argument(
        '--config',
        dest='config_file',
        help='config file (default: %s)' % default_config_file,
        default=default_config_file,
    )
    # config for the gui
    parser.add_argument(
        '--nodata',
        dest='store_data',
        action='store_false',
        help='stop storing image data to JSON file',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--autosave',
        dest='auto_save',
        action='store_true',
        help='auto save',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--nosortlabels',
        dest='sort_labels',
        action='store_false',
        help='stop sorting labels',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--flags',
        help='comma separated list of flags OR file containing flags',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--labelflags',
        dest='label_flags',
        # Raw strings: '\d' is an invalid escape in a normal literal; the
        # resulting text is unchanged.
        help=r'yaml string of label specific flags OR file containing json '
             r'string of label specific flags (ex. {person-\d+: [male, tall], '
             r'dog-\d+: [black, brown, white], .*: [occluded]})',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--labels',
        help='comma separated list of labels OR file containing labels',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--validatelabel',
        dest='validate_label',
        choices=['exact', 'instance'],
        help='label validation types',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--keep-prev',
        action='store_true',
        help='keep annotation of previous frame',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--epsilon',
        type=float,
        help='epsilon to find nearest vertex on canvas',
        default=argparse.SUPPRESS,
    )
    parser.add_argument(
        '--debug-mode',
        dest='debug_mode',
        action='store_true',
        help='start in debug mode',
        default=argparse.SUPPRESS,
    )
    args = parser.parse_args()

    if args.version:
        print('{0} {1}'.format(__appname__, __version__))
        sys.exit(0)

    if hasattr(args, 'debug_mode'):
        logger.addStreamHandler(logging.DEBUG)
    else:
        level = args.logger_level.upper()
        logger.addStreamHandler(getattr(logging, level))

    if hasattr(args, 'flags'):
        if os.path.isfile(args.flags):
            with codecs.open(args.flags, 'r', encoding='utf-8') as f:
                args.flags = [l.strip() for l in f if l.strip()]
        else:
            args.flags = [l for l in args.flags.split(',') if l]

    if hasattr(args, 'labels'):
        if os.path.isfile(args.labels):
            with codecs.open(args.labels, 'r', encoding='utf-8') as f:
                args.labels = [l.strip() for l in f if l.strip()]
        else:
            args.labels = [l for l in args.labels.split(',') if l]

    if hasattr(args, 'label_flags'):
        # safe_load: the value comes from the command line or a user file and
        # must not construct arbitrary Python objects. yaml.load() without a
        # Loader is also rejected by PyYAML >= 6.
        if os.path.isfile(args.label_flags):
            with codecs.open(args.label_flags, 'r', encoding='utf-8') as f:
                args.label_flags = yaml.safe_load(f)
        else:
            args.label_flags = yaml.safe_load(args.label_flags)

    config_from_args = args.__dict__
    config_from_args.pop('version')
    reset_config = config_from_args.pop('reset_config')
    filename = config_from_args.pop('filename')
    output = config_from_args.pop('output')
    config_file = config_from_args.pop('config_file')
    # NOTE(review): a get_config with the signature
    # (config_file_or_yaml, config_from_args), i.e. the opposite argument
    # order, also exists in this code collection. Confirm which signature
    # this project's get_config uses.
    config = get_config(config_from_args, config_file)

    # localization
    current_path = os.path.dirname(os.path.abspath(__file__))
    locale_dir = os.path.join(current_path, 'locale')
    if os.path.isfile(os.path.join(locale_dir, config['language'], 'LC_MESSAGES', 'labelme.po')):
        lang = gettext.translation('labelme', localedir=locale_dir, languages=[config['language']])
        lang.install()
    else:
        gettext.install('labelme')
    try:
        locale.setlocale(locale.LC_ALL, config['language'])
    except locale.Error:
        # A locale that is not installed on this system must not abort
        # start-up.
        logger.warning('Locale {} is not available. Keeping current locale'.format(config['language']))

    if not config['labels'] and config['validate_label']:
        logger.error('--labels must be specified with --validatelabel or '
                     'validate_label: true in the config file '
                     '(ex. ~/.labelmerc).')
        sys.exit(1)

    output_file = None
    output_dir = None
    if output is not None:
        if output.endswith('.json'):
            output_file = output
        else:
            output_dir = output

    # MXNet environment variables
    if 'MXNET_GPU_MEM_POOL_TYPE' in os.environ:
        logger.debug('Environment variable MXNET_GPU_MEM_POOL_TYPE = {}'.format(os.environ['MXNET_GPU_MEM_POOL_TYPE']))
    if 'MXNET_CUDNN_AUTOTUNE_DEFAULT' in os.environ:
        logger.debug('Environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT = {}'.format(os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT']))
    if 'MXNET_HOME' in os.environ:
        logger.debug('Environment variable MXNET_HOME = {}'.format(os.environ['MXNET_HOME']))

    # Enable high dpi for 4k monitors
    QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    # The attribute is defined on Qt; checking QStyleFactory was always False,
    # so high-dpi pixmaps were never enabled.
    if hasattr(QtCore.Qt, 'AA_UseHighDpiPixmaps'):
        QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps)

    app = QtWidgets.QApplication(sys.argv)
    app.setApplicationName(__appname__)
    app.setWindowIcon(newIcon('digivod'))
    win = MainWindow(
        config=config,
        filename=filename,
        output_file=output_file,
        output_dir=output_dir,
    )

    if reset_config:
        logger.info('Resetting Qt config: %s' % win.settings.fileName())
        win.settings.clear()
        sys.exit(0)

    win.showMaximized()
    win.raise_()
    sys.excepthook = excepthook
    win.check_startup()
    sys.exit(app.exec_())
Beispiel #29
0
 def getNumSamples(self):
     """Return how many samples the dataset from ``_loadDataset()`` holds."""
     logger.debug('Count samples in dataset')
     return len(self._loadDataset())
    def __init__(self, parent=None):
        """Build the training dialog.

        The dialog consists of three tabs ("Dataset", "Network" and
        "Resume training") followed by "Start Training" / "Cancel" buttons.

        Args:
            parent: Owning main window. The code below reads
                ``self.parent.settings``, ``self.parent._config`` and
                ``self.parent.imageList``; presumably the base class stores
                ``parent`` as ``self.parent`` (TODO confirm), so in practice
                ``parent`` must not be None.
        """
        super().__init__(parent)
        self.setWindowTitle(_('Training'))
        self.set_default_window_flags(self)
        self.setWindowModality(Qt.NonModal)

        self.dataset_format_init = False
        project_folder = self.parent.settings.value('settings/project/folder',
                                                    '')
        logger.debug(
            'Restored value "{}" for setting settings/project/folder'.format(
                project_folder))

        layout = QtWidgets.QVBoxLayout()
        self.setLayout(layout)

        self.tabs = QtWidgets.QTabWidget()
        layout.addWidget(self.tabs)

        tab_dataset = QtWidgets.QWidget()
        tab_dataset_layout = QtWidgets.QVBoxLayout()
        tab_dataset.setLayout(tab_dataset_layout)
        self.tabs.addTab(tab_dataset, _('Dataset'))

        tab_network = QtWidgets.QWidget()
        tab_network_layout = QtWidgets.QVBoxLayout()
        tab_network.setLayout(tab_network_layout)
        self.tabs.addTab(tab_network, _('Network'))

        tab_resume = QtWidgets.QWidget()
        tab_resume_layout = QtWidgets.QVBoxLayout()
        tab_resume.setLayout(tab_resume_layout)
        self.tabs.addTab(tab_resume, _('Resume training'))

        # Network Tab

        # Only the display names of the configured networks are shown.
        self.networks = QtWidgets.QComboBox()
        for network_display_name in Training.config('networks').values():
            self.networks.addItem(network_display_name)
        self.networks.currentIndexChanged.connect(
            self.network_selection_changed)

        network_group = HelpGroupBox('Training_NetworkArchitecture',
                                     _('Network'))
        network_group_layout = QtWidgets.QVBoxLayout()
        network_group.widget.setLayout(network_group_layout)
        network_group_layout.addWidget(self.networks)
        tab_network_layout.addWidget(network_group)

        training_defaults = self.parent._config['training_defaults']
        network = self.get_current_network()

        args_epochs_label = HelpLabel('Training_SettingsEpochs', _('Epochs'))
        self.args_epochs = QtWidgets.QSpinBox()
        self.args_epochs.setMinimum(1)
        self.args_epochs.setMaximum(500)
        self.args_epochs.setValue(training_defaults['epochs'])

        default_batch_size = self.get_default_batch_size(network)
        if default_batch_size is None:
            # get_default_batch_size() logs and returns None when the GPU
            # query fails (e.g. GPUtil is not installed or no GPU exists).
            # QSpinBox.setValue() rejects None with a TypeError, so fall back
            # to the spin box minimum instead of crashing the dialog.
            default_batch_size = 1
        args_batch_size_label = HelpLabel('Training_SettingsBatchSize',
                                          _('Batch size'))
        self.args_batch_size = QtWidgets.QSpinBox()
        self.args_batch_size.setMinimum(1)
        self.args_batch_size.setMaximum(100)
        # Out-of-range estimates are clamped to [1, 100] by the spin box.
        self.args_batch_size.setValue(default_batch_size)

        default_learning_rate = self.get_default_learning_rate(network)
        args_learning_rate_label = HelpLabel('Training_SettingsLearningRate',
                                             _('Learning rate'))
        self.args_learning_rate = QtWidgets.QDoubleSpinBox()
        self.args_learning_rate.setMinimum(1e-7)
        self.args_learning_rate.setMaximum(1.0)
        self.args_learning_rate.setSingleStep(1e-7)
        self.args_learning_rate.setDecimals(7)
        self.args_learning_rate.setValue(default_learning_rate)

        args_early_stop_epochs_label = HelpLabel('Training_SettingsEarlyStop',
                                                 _('Early stop epochs'))
        self.args_early_stop_epochs = QtWidgets.QSpinBox()
        self.args_early_stop_epochs.setMinimum(0)
        self.args_early_stop_epochs.setMaximum(100)
        self.args_early_stop_epochs.setValue(
            training_defaults['early_stop_epochs'])

        self.gpu_label_text = _('GPU')
        args_gpus_label = QtWidgets.QLabel(_('GPUs'))
        no_gpus_available_label = QtWidgets.QLabel(_('No GPUs available'))
        # mxnet's list_gpus() returns integer GPU indices (range(n)), which
        # is why the first GPU is pre-selected via ``i == 0`` below.
        self.gpus = mx.test_utils.list_gpus()
        self.gpu_checkboxes = []
        for i in self.gpus:
            checkbox = QtWidgets.QCheckBox('{} {}'.format(
                self.gpu_label_text, i))
            checkbox.setChecked(i == 0)
            self.gpu_checkboxes.append(checkbox)

        settings_group = QtWidgets.QGroupBox()
        settings_group.setTitle(_('Settings'))
        settings_group_layout = QtWidgets.QGridLayout()
        settings_group.setLayout(settings_group_layout)
        settings_group_layout.addWidget(args_epochs_label, 0, 0)
        settings_group_layout.addWidget(self.args_epochs, 0, 1)
        settings_group_layout.addWidget(args_batch_size_label, 1, 0)
        settings_group_layout.addWidget(self.args_batch_size, 1, 1)
        settings_group_layout.addWidget(args_learning_rate_label, 2, 0)
        settings_group_layout.addWidget(self.args_learning_rate, 2, 1)
        settings_group_layout.addWidget(args_early_stop_epochs_label, 3, 0)
        settings_group_layout.addWidget(self.args_early_stop_epochs, 3, 1)

        # The "GPUs" label sits in row 4; one checkbox per GPU is stacked in
        # column 1 starting at that same row.
        settings_group_layout.addWidget(args_gpus_label, 4, 0)
        if self.gpu_checkboxes:
            for row, checkbox in enumerate(self.gpu_checkboxes, start=4):
                settings_group_layout.addWidget(checkbox, row, 1)
        else:
            settings_group_layout.addWidget(no_gpus_available_label, 4, 1)
        tab_network_layout.addWidget(settings_group)

        # Dataset Tab

        # Offer "create dataset" only when images are currently open.
        image_list = self.parent.imageList
        show_dataset_create = len(image_list) > 0
        self.create_dataset_checkbox = HelpCheckbox(
            'Training_CreateDataset', _('Create dataset from opened images'))
        self.create_dataset_checkbox.widget.setChecked(show_dataset_create)
        tab_dataset_layout.addWidget(self.create_dataset_checkbox)

        validation_label = HelpLabel('Training_ValidationRatio',
                                     _('Validation ratio'))
        self.validation = QtWidgets.QSpinBox()
        self.validation.setMinimum(0)
        self.validation.setMaximum(90)
        self.validation.setValue(10)
        self.validation.setFixedWidth(50)
        validation_description_label = QtWidgets.QLabel(_('% of dataset'))

        dataset_name_label = QtWidgets.QLabel(_('Dataset name'))
        self.dataset_name = QtWidgets.QLineEdit()

        self.create_dataset_group = QtWidgets.QGroupBox()
        self.create_dataset_group.setTitle(_('Create dataset'))
        create_dataset_group_layout = QtWidgets.QGridLayout()
        self.create_dataset_group.setLayout(create_dataset_group_layout)
        create_dataset_group_layout.addWidget(dataset_name_label, 0, 0, 1, 2)
        create_dataset_group_layout.addWidget(self.dataset_name, 1, 0, 1, 2)
        create_dataset_group_layout.addWidget(validation_label, 2, 0, 1, 2)
        create_dataset_group_layout.addWidget(self.validation, 3, 0)
        create_dataset_group_layout.addWidget(validation_description_label, 3,
                                              1)

        formats_label = HelpLabel('Training_DatasetFormat', _('Format'))
        self.formats = QtWidgets.QComboBox()
        for format_display_name in Export.config('formats').values():
            self.formats.addItem(format_display_name)
        self.formats.setCurrentIndex(0)
        self.formats.currentTextChanged.connect(self.on_format_change)
        # Internal key of the first (pre-selected) format.
        self.selected_format = list(Export.config('formats').keys())[0]

        train_dataset_label = HelpLabel('Training_TrainingDataset',
                                        _('Training dataset'))
        self.train_dataset_folder = QtWidgets.QLineEdit()
        train_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
        train_dataset_folder_browse_btn.clicked.connect(
            self.train_dataset_folder_browse_btn_clicked)

        val_label_text = '{} ({})'.format(_('Validation dataset'),
                                          _('optional'))
        val_dataset_label = HelpLabel('Training_ValidationDataset',
                                      val_label_text)
        self.val_dataset_folder = QtWidgets.QLineEdit()
        val_dataset_folder_browse_btn = QtWidgets.QPushButton(_('Browse'))
        val_dataset_folder_browse_btn.clicked.connect(
            self.val_dataset_folder_browse_btn_clicked)

        self.dataset_folder_group = QtWidgets.QGroupBox()
        self.dataset_folder_group.setTitle(_('Use dataset file(s)'))
        dataset_folder_group_layout = QtWidgets.QGridLayout()
        self.dataset_folder_group.setLayout(dataset_folder_group_layout)
        dataset_folder_group_layout.addWidget(formats_label, 0, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.formats, 1, 0, 1, 2)
        dataset_folder_group_layout.addWidget(train_dataset_label, 2, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.train_dataset_folder, 3, 0)
        dataset_folder_group_layout.addWidget(train_dataset_folder_browse_btn,
                                              3, 1)
        dataset_folder_group_layout.addWidget(val_dataset_label, 4, 0, 1, 2)
        dataset_folder_group_layout.addWidget(self.val_dataset_folder, 5, 0)
        dataset_folder_group_layout.addWidget(val_dataset_folder_browse_btn, 5,
                                              1)

        tab_dataset_layout.addWidget(self.create_dataset_group)
        tab_dataset_layout.addWidget(self.dataset_folder_group)

        # Exactly one of "create dataset" / "use dataset file(s)" is visible;
        # without open images the create option is hidden entirely.
        if show_dataset_create:
            self.dataset_folder_group.hide()
        else:
            self.create_dataset_checkbox.hide()
            self.create_dataset_group.hide()

        self.create_dataset_checkbox.widget.toggled.connect(
            lambda: self.switch_visibility(self.create_dataset_group, self.
                                           dataset_folder_group))

        # Output folder is derived from the project folder; it is not shown
        # in the layout (the browse UI is currently disabled).
        self.output_folder = QtWidgets.QLineEdit()
        self.output_folder.setText(
            os.path.join(project_folder,
                         self.parent._config['project_training_folder']))
        # self.output_folder.setReadOnly(True)
        # output_browse_btn = QtWidgets.QPushButton(_('Browse'))
        # output_browse_btn.clicked.connect(self.output_browse_btn_clicked)

        training_name_label = HelpLabel('Training_TrainingName',
                                        _('Training name'))
        self.training_name = QtWidgets.QLineEdit()

        output_folder_group = QtWidgets.QGroupBox()
        output_folder_group.setTitle(_('Output folder'))
        output_folder_group_layout = QtWidgets.QGridLayout()
        output_folder_group.setLayout(output_folder_group_layout)
        # output_folder_group_layout.addWidget(self.output_folder, 0, 0, 1, 2)
        # output_folder_group_layout.addWidget(output_browse_btn, 0, 2)
        output_folder_group_layout.addWidget(training_name_label, 1, 0, 1, 3)
        output_folder_group_layout.addWidget(self.training_name, 2, 0, 1, 3)
        tab_dataset_layout.addWidget(output_folder_group)

        # Resume Tab

        self.resume_training_checkbox = HelpCheckbox(
            'Training_Resume', _('Resume previous training'))
        self.resume_training_checkbox.widget.setChecked(False)
        tab_resume_layout.addWidget(self.resume_training_checkbox)

        # Resume controls stay hidden until the checkbox is toggled.
        self.resume_group = QtWidgets.QWidget()
        resume_group_layout = QtWidgets.QVBoxLayout()
        self.resume_group.setLayout(resume_group_layout)
        tab_resume_layout.addWidget(self.resume_group)

        self.resume_group.hide()
        self.resume_training_checkbox.widget.toggled.connect(
            self.toggle_resume_training_checkbox)

        self.resume_folder = QtWidgets.QLineEdit()
        self.resume_folder.setText('')
        self.resume_folder.setReadOnly(True)
        resume_browse_btn = QtWidgets.QPushButton(_('Browse'))
        resume_browse_btn.clicked.connect(self.resume_browse_btn_clicked)

        resume_folder_group = QtWidgets.QGroupBox()
        resume_folder_group.setTitle(_('Training directory'))
        resume_folder_group_layout = QtWidgets.QHBoxLayout()
        resume_folder_group.setLayout(resume_folder_group_layout)
        resume_folder_group_layout.addWidget(self.resume_folder)
        resume_folder_group_layout.addWidget(resume_browse_btn)
        resume_group_layout.addWidget(resume_folder_group)

        self.resume_file = QtWidgets.QComboBox()
        self.resume_file_group = QtWidgets.QGroupBox()
        self.resume_file_group.setTitle(_('Resume after'))
        resume_file_group_layout = QtWidgets.QHBoxLayout()
        self.resume_file_group.setLayout(resume_file_group_layout)
        resume_file_group_layout.addWidget(self.resume_file)
        resume_group_layout.addWidget(self.resume_file_group)
        self.resume_file_group.hide()

        tab_dataset_layout.addStretch()
        tab_network_layout.addStretch()
        tab_resume_layout.addStretch()

        button_box = QtWidgets.QDialogButtonBox()
        self.training_btn = button_box.addButton(
            _('Start Training'), QtWidgets.QDialogButtonBox.AcceptRole)
        self.training_btn.clicked.connect(self.training_btn_clicked)
        cancel_btn = button_box.addButton(
            _('Cancel'), QtWidgets.QDialogButtonBox.RejectRole)
        cancel_btn.clicked.connect(self.cancel_btn_clicked)
        layout.addWidget(button_box)

        # Fixed width of 500px; height taken from the layout's size hint.
        h = self.sizeHint().height()
        self.resize(500, h)