Example #1
 def update_results(self):
     data = Map(self.data)
     min_score = self.score_slider.value() / 100.0
     self.score_value.setText('{}%'.format(self.score_slider.value()))
     self.reset_image()
     for i in range(len(data.bbox[0])):
         label = int(data.classid[0][i][0])
         score = data.score[0][i][0]
         if label > -1 and score > min_score:
             label_name = _('unknown')
             if label < len(data.labels):
                 label_name = str(data.labels[label])
             xr = self.pixmap.width() / data.imgsize[1]
             yr = self.pixmap.height() / data.imgsize[0]
             x, y = data.bbox[0][i][0] * xr, data.bbox[0][i][1] * yr
             w, h = data.bbox[0][i][2] * xr - x, data.bbox[0][i][3] * yr - y
             #logger.debug('Draw bbox ({}, {}, {}, {}) for label {} ({})'.format(int(x), int(y), int(w), int(h), label_name, label))
             self.painter.scale(1, 1)
             self.painter.drawRect(x, y, w, h)
             p1 = QtCore.QPointF(x + 4, y + 12)
             p2 = QtCore.QPointF(x + 4, y + 24)
             self.painter.drawText(p1, label_name)
             self.painter.drawText(p2, '{0:.4f}'.format(score))
     self.image_label.setPixmap(self.pixmap)
     self.image_label.show()
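Every example on this page wraps a plain dict in `Map` and then reads keys as attributes (`data.bbox`, `data.labels`, `args.epochs`, ...). The project's actual `Map` implementation is not shown here; the following is only a minimal sketch of such a dict subclass with attribute access, written so the examples can be read in isolation.

class Map(dict):
    """Minimal sketch (assumption): a dict whose keys are also attributes."""

    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        # Wrap nested dicts so attribute access works recursively.
        for key, value in self.items():
            self[key] = self._wrap(value)

    @classmethod
    def _wrap(cls, value):
        return cls(value) if isinstance(value, dict) else value

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value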
Example #2
 def loadConfig(self, config_file):
     logger.debug('Load training config from file: {}'.format(config_file))
     try:
         with open(config_file, 'r') as f:
             data = json.load(f)
     except (IOError, ValueError):
         raise Exception(
             'Could not load training config from file {}'.format(config_file))
     logger.debug('Loaded training config: {}'.format(data))
     return Map(data)
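Because `loadConfig()` returns the parsed JSON wrapped in a `Map`, callers read the configuration with attribute access, as Example #4 does with `network_config.files`, `network_config.labels` and `network_config.dataset`. A hypothetical usage sketch (the path is illustrative; the keys are the ones written by `saveConfig()` in Example #9):

network = Network()
config = network.loadConfig('/path/to/training/config.json')  # hypothetical path
print(config.network)  # network key used for training
print(config.labels)   # list of class labels
print(config.files)    # model files written during training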
Example #3
    def makeRecFile(self, data_folder, output_folder, file_name_rec,
                    file_name_idx, file_name_lst):
        image_list = self.readLstFile(
            os.path.join(output_folder, file_name_lst))
        record = mx.recordio.MXIndexedRecordIO(
            os.path.join(output_folder, file_name_idx),
            os.path.join(output_folder, file_name_rec), 'w')

        self.checkAborted()

        args = Map({
            'root': data_folder,
            'pass_through': True,
            'resize': 0,
            'center_crop': False,
            'quality': 95,
            'encoding': '.jpg',
            'pack_label': True,
        })
        try:
            import Queue as queue  # Python 2
        except ImportError:
            import queue  # Python 3
        q_out = queue.Queue()
        cnt = 0
        pre_time = time.time()
        failed_images = []
        for i, item in enumerate(image_list):
            try:
                self.imageEncode(args, i, item, q_out)
                if q_out.empty():
                    continue
                _a, s, _b = q_out.get()
                record.write_idx(item[0], s)
                if cnt % 1000 == 0:
                    cur_time = time.time()
                    logger.debug('time: {} count: {}'.format(
                        cur_time - pre_time, cnt))
                    pre_time = cur_time
                cnt += 1
                self.thread.update.emit(_('Writing dataset ...'), -1, -1)
                self.checkAborted()

            except Exception as e:
                failed_images.append(item)
                logger.error(traceback.format_exc())

        if len(failed_images) > 0:
            msg = (_('The following images could not be exported:') + '\n' +
                   ', '.join(str(image) for image in failed_images))
            self.thread.message.emit(_('Warning'), msg, MessageType.Warning)
            if cnt == 0:
                self.throwUserException(
                    _('Dataset contains no images for export'))

        logger.debug('total time: {} total count: {}'.format(
            time.time() - pre_time, cnt))
Example #4
    def run(self):
        logger.debug('Prepare inference')

        try:
            # Attach the debugger to this thread if ptvsd is available
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        training_folder = self.data['training_folder']
        input_image_file = self.data['input_image_file']

        config_file = os.path.join(training_folder,
                                   Training.config('config_file'))

        network = Network()
        network.setAbortable(self.abortable)
        network.setThread(self.thread)

        network_config = network.loadConfig(config_file)
        self.thread.data.emit({'network_config': network_config})

        architecture_file = ''
        weights_file = ''
        files = network_config.files
        for f in files:
            if '.json' in f:
                architecture_file = os.path.join(training_folder, f)
            elif '.params' in f:
                weights_file = os.path.join(training_folder, f)

        dataset_folder = network_config.dataset

        inference_data = Map({
            'input_image_file': input_image_file,
            'architecture_file': architecture_file,
            'weights_file': weights_file,
            'labels': network_config.labels,
        })

        self.thread.update.emit(_('Validating ...'), 1, 3)

        network.inference(inference_data.input_image_file,
                          inference_data.labels,
                          inference_data.architecture_file,
                          inference_data.weights_file,
                          args=None)

        self.thread.update.emit(_('Finished'), 3, 3)
Example #5
    def run(self):
        logger.debug('Prepare training')

        try:
            # Attach the debugger to this thread if ptvsd is available
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        network_key = self.data['network']
        if network_key not in Training.config('objects'):
            self.thread.message.emit(
                _('Training'),
                _('Network {} could not be found').format(network_key),
                MessageType.Error)
            self.abort()
            return

        # Training settings
        gpus = []
        gpu_checkboxes = self.data['gpu_checkboxes']
        for i, gpu in enumerate(gpu_checkboxes):
            if gpu.checkState() == Qt.Checked:
                gpus.append(str(i))
        gpus = ','.join(gpus)
        epochs = int(self.data['args_epochs'])
        batch_size = int(self.data['args_batch_size'])

        # Dataset
        dataset_format = self.data['selected_format']
        train_dataset_obj = Export.config('objects')[dataset_format]()
        train_dataset_obj.setInputFolderOrFile(self.data['train_dataset'])
        if self.data['val_dataset']:
            val_dataset_obj = Export.config('objects')[dataset_format]()
            val_dataset_obj.setInputFolderOrFile(self.data['val_dataset'])

        labels = train_dataset_obj.getLabels()
        num_train_samples = train_dataset_obj.getNumSamples()
        num_batches = int(math.ceil(num_train_samples / batch_size))

        args = Map({
            'network': self.data['network'],
            'train_dataset': self.data['train_dataset'],
            'validate_dataset': self.data['val_dataset'],
            'training_name': self.data['training_name'],
            'batch_size': batch_size,
            'learning_rate': float(self.data['args_learning_rate']),
            'gpus': gpus,
            'epochs': epochs,
            'early_stop_epochs': int(self.data['args_early_stop_epochs']),
            'start_epoch': self.data['start_epoch'],
            'resume': self.data['resume_training'],
        })

        self.thread.update.emit(_('Loading data ...'), 0,
                                epochs * num_batches + 5)

        with Training.config('objects')[network_key]() as network:
            network.setAbortable(self.abortable)
            network.setThread(self.thread)
            network.setArgs(args)
            network.setOutputFolder(self.data['output_folder'])
            network.setTrainDataset(train_dataset_obj, dataset_format)
            network.setLabels(labels)

            if self.data['val_dataset']:
                network.setValDataset(val_dataset_obj)

            self.checkAborted()

            network.training()
Example #6
    def on_data(self, data):
        self.data = data
        data = Map(data)
        try:
            if 'validation' in data:
                num_new_items = len(data.validation.items())
                num_old_items = len(self.validation_values)

                large_font = QtGui.QFont('Arial', 10, QtGui.QFont.Normal)

                row = 0
                for item in data.validation.items():
                    if row >= num_old_items:
                        label = QtWidgets.QLabel(str(item[0]))
                        label.setFont(large_font)
                        value = QtWidgets.QLabel('-')
                        value.setFont(large_font)
                        value.setAlignment(QtCore.Qt.AlignRight
                                           | QtCore.Qt.AlignVCenter)
                        self.validation_group_layout.addWidget(label, row, 0)
                        self.validation_group_layout.addWidget(value, row, 1)
                        self.validation_labels.append(label)
                        self.validation_values.append(value)
                    self.validation_labels[row].setText(str(item[0]))
                    if isinstance(item[1], (int, float)):
                        self.validation_values[row].setText('{:.4f}'.format(
                            item[1]))
                    else:
                        self.validation_values[row].setText(str(item[1]))
                    row += 1

                if num_old_items > num_new_items:
                    for i in range(num_new_items, num_old_items):
                        self.validation_labels[i].setText('')
                        self.validation_values[i].setText('')
                        self.validation_group_layout.removeWidget(
                            self.validation_labels[i])
                        self.validation_group_layout.removeWidget(
                            self.validation_values[i])
                    for i in range(num_old_items, num_new_items, -1):
                        del self.validation_labels[i - 1]
                        del self.validation_values[i - 1]

            if 'progress' in data:
                progress = Map(data.progress)
                if 'epoch' in progress:
                    self.epoch_value.setText(str(progress.epoch))
                if 'epoch_max' in progress:
                    self.epoch_max_value.setText(str(progress.epoch_max))
                if 'batch' in progress:
                    self.batch_value.setText(str(progress.batch))
                if 'batch_max' in progress:
                    self.batch_max_value.setText(str(progress.batch_max))
                if 'speed' in progress:
                    self.speed_value.setText('{:.2f} {}'.format(
                        progress.speed, _('samples/sec')))

                if not self.training_has_started:
                    self.start_time = time.time()
                    self.training_has_started = True

                if 'metric' in progress:
                    num_new_items = len(progress.metric.items())
                    num_old_items = len(self.metric_values)

                    row = 0
                    for item in progress.metric.items():
                        if row >= num_old_items:
                            label = QtWidgets.QLabel(str(item[0]))
                            value = QtWidgets.QLabel('-')
                            value.setAlignment(QtCore.Qt.AlignRight
                                               | QtCore.Qt.AlignVCenter)
                            self.metric_group_layout.addWidget(label, row, 0)
                            self.metric_group_layout.addWidget(value, row, 1)
                            self.metric_labels.append(label)
                            self.metric_values.append(value)
                        self.metric_labels[row].setText(str(item[0]))
                        if isinstance(item[1], (int, float)):
                            self.metric_values[row].setText('{:.4f}'.format(
                                item[1]))
                        else:
                            self.metric_values[row].setText(str(item[1]))
                        row += 1

                    if num_old_items > num_new_items:
                        for i in range(num_new_items, num_old_items):
                            self.metric_labels[i].setText('')
                            self.metric_values[i].setText('')
                            self.metric_group_layout.removeWidget(
                                self.metric_labels[i])
                            self.metric_group_layout.removeWidget(
                                self.metric_values[i])
                        for i in range(num_old_items, num_new_items, -1):
                            del self.metric_labels[i - 1]
                            del self.metric_values[i - 1]

                # Estimate finish time
                if self.start_time is not False:
                    if all(key in progress for key in
                           ('epoch', 'epoch_max', 'batch', 'batch_max')):
                        percentage = ((progress.epoch - 1) / progress.epoch_max
                                      + progress.batch / progress.batch_max
                                      / progress.epoch_max)
                        current_time = time.time()
                        duration = current_time - self.start_time
                        seconds_left = (duration / percentage) - duration
                        self.finished_value.setText(
                            self.format_duration(seconds_left))

        except Exception as e:
            logger.error(traceback.format_exc())
Example #7
 def setArgs(self, args):
     default_args = self.getDefaultArgs()
     self.args = default_args.copy()
     self.args.update(args)
     self.args = Map(self.args)
     logger.debug(self.args)
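`setArgs()` layers the caller's overrides on top of `getDefaultArgs()` and wraps the merged dict in a `Map`, so unspecified options keep their defaults while every option remains readable as an attribute. A standalone illustration of that merge pattern, with made-up values:

default_args = {'epochs': 10, 'batch_size': 8, 'learning_rate': 0.001}
overrides = {'batch_size': 16}

args = default_args.copy()
args.update(overrides)        # caller-supplied values win
args = Map(args)

assert args.batch_size == 16  # overridden
assert args.epochs == 10      # default kept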
Example #8
    def inference(self,
                  input_image_file,
                  labels,
                  architecture_file,
                  weights_file,
                  args=None):
        default_args = {
            'threshold': 0.5,
            'print_top_n': 10,
        }
        tmp_args = default_args.copy()
        if args:
            tmp_args.update(args)
        args = Map(tmp_args)
        logger.debug('Try loading network from files "{}" and "{}"'.format(
            architecture_file, weights_file))

        self.checkAborted()

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            ctx = self.getContext()
            net = gluon.nn.SymbolBlock.imports(architecture_file, ['data'],
                                               weights_file,
                                               ctx=ctx)
            class_names = labels
            net.collect_params().reset_ctx(ctx)
            img = mx.image.imread(input_image_file)
            img = timage.resize_short_within(img,
                                             608,
                                             max_size=1024,
                                             mult_base=1)

            self.checkAborted()
            self.thread.update.emit(None, -1, -1)

            def make_tensor(img):
                np_array = np.expand_dims(np.transpose(img, (0, 1, 2)),
                                          axis=0).astype(np.float32)
                return mx.nd.array(np_array)

            image = img.asnumpy().astype('uint8')
            x = make_tensor(image)
            cid, score, bbox = net(x)

            self.thread.data.emit({
                'files': {
                    'input_image_file': input_image_file,
                    'architecture_file': architecture_file,
                    'weights_file': weights_file,
                },
                'imgsize': [image.shape[0], image.shape[1]],
                'classid': cid.asnumpy().tolist(),
                'score': score.asnumpy().tolist(),
                'bbox': bbox.asnumpy().tolist(),
                'labels': labels,
            })
            self.thread.update.emit(None, -1, -1)

            n_top = args.print_top_n
            classes = cid[0][:n_top].asnumpy().astype(
                'int32').flatten().tolist()
            scores = score[0][:n_top].asnumpy().astype(
                'float32').flatten().tolist()
            result_str = '\n'.join([
                'class: {}, score: {}'.format(classes[i], scores[i])
                for i in range(n_top)
            ])
            logger.debug('Top {} inference results:\n {}'.format(
                n_top, result_str))
Example #9
class Network(WorkerExecutor):
    def __init__(self):
        super().__init__()
        self.monitor = NetworkMonitor()
        self.dataset_format = None
        self.integration_network_name = 'custom'
        self.net_name = None
        self.model_file_name = None
        self.network = 'network'
        self.files = {
            'architecture':
            '{}-symbol.json'.format(self.integration_network_name),
            'weights': '{}-0000.params'.format(self.integration_network_name),
        }

    def training(self):
        gutils.random.seed(self.args.seed)

        # Prepare network and data
        self.args.save_prefix += self.net_name
        if not self.args.validate_dataset:
            self.args.val_interval = sys.maxsize
        self.labels = self.train_dataset.getLabels()

        self.thread.update.emit(_('Loading model ...'), None, -1)
        self.loadModel()

        self.thread.update.emit(_('Loading dataset ...'), None, -1)
        self.loadDataset()

        self.thread.update.emit(_('Start training ...'), None, -1)
        self.beforeTrain()
        last_full_epoch = self.trainBase()
        self.afterTrain(last_full_epoch)
        self.thread.update.emit(_('Finished training'), None, -1)

    def getDefaultArgs(self):
        default_args = {
            'training_name': 'unknown',
            'train_dataset': '',
            'validate_dataset': '',
            'data_shape': 0,
            'batch_size': 8,
            'gpus': '0',
            'epochs': 10,
            'resume': '',
            'start_epoch': 0,
            'num_workers': 0,
            'learning_rate': self.getDefaultLearningRate(),
            'lr_decay': 0.1,
            'lr_decay_epoch': '160,180',
            'momentum': 0.9,
            'wd': 0.0005,
            'log_interval': 1,
            'save_prefix': '',
            'save_interval': 1,
            'val_interval': 1,
            'seed': 42,
            'num_samples': -1,
            'syncbn': False,
            'mixup': False,
            'no_mixup_epochs': 20,
            'early_stop_epochs': 0,
        }
        return default_args

    def setArgs(self, args):
        default_args = self.getDefaultArgs()
        self.args = default_args.copy()
        self.args.update(args)
        self.args = Map(self.args)
        logger.debug(self.args)

    def loadDataset(self):
        train_dataset = self.train_dataset.getDatasetForTraining()
        val_dataset = None
        if self.args.validate_dataset:
            val_dataset = self.val_dataset.getDatasetForTraining()

        self.eval_metric = self.getValidationMetric()

        if self.args.num_samples < 0:
            self.args.num_samples = len(train_dataset)
        if self.args.mixup:
            from gluoncv.data import MixupDetection
            train_dataset = MixupDetection(train_dataset)

        self.train_data, self.val_data = self.getDataloader(
            train_dataset, val_dataset)

    def getDataloader(self, train_dataset, val_dataset):
        raise NotImplementedError(
            'Method getDataloader() needs to be implemented in subclasses')
        # return train_loader, val_loader (both of type gluon.data.DataLoader)

    def loadModel(self):
        self.ctx = self.getContext(self.args.gpus)
        self.net = gcv.model_zoo.get_model(self.net_name,
                                           pretrained=False,
                                           ctx=self.ctx)
        if self.args.resume.strip():
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                self.net.initialize(ctx=self.ctx)
            self.net.reset_class(self.labels)
            self.net.load_parameters(self.args.resume.strip(), ctx=self.ctx)
        else:
            model_path = os.path.normpath(
                os.path.join(os.path.dirname(__file__),
                             '../../../networks/models'))
            weights_file = os.path.join(model_path, self.model_file_name)
            self.net.load_parameters(weights_file, ctx=self.ctx)
            self.net.reset_class(self.labels)
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                self.net.initialize()

    def trainBase(self):
        try:
            last_full_epoch = self.train()
            return last_full_epoch
        except AbortTrainingException as e:
            return e.epoch

    def train(self):
        raise NotImplementedError(
            'Method train() needs to be implemented in subclasses')
        # return last_full_epoch

    def checkTrainingAborted(self, epoch):
        if self.isAborted():
            self.checkAborted()

    def beforeTrain(self):
        # Save config & architecture before training
        from labelme.config import Training
        config_file = os.path.join(self.output_folder,
                                   Training.config('config_file'))
        files = list(self.files.values())
        self.saveConfig(config_file, files)

        if self.args.early_stop_epochs > 0:
            self.monitor = NetworkMonitor(self.args.early_stop_epochs)

        self.thread.data.emit({
            'validation': {
                _('Waiting for first validation ...'): '',
            },
        })

        # Set maximum progress to 100%
        self.thread.update.emit(_('Start training ...'), 0, 100)

        self.checkAborted()
        logger.info('Start training from [Epoch {}]'.format(
            self.args.start_epoch))

    def afterTrain(self, last_full_epoch):
        self.thread.update.emit(_('Finished training'), 100, -1)

    def beforeEpoch(self, epoch, num_batches):
        self.thread.update.emit(
            _('Start training on epoch {} ...').format(epoch + 1), None, -1)
        self.thread.data.emit({
            'progress': {
                'epoch': epoch + 1,
                'epoch_max': self.args.epochs,
                'batch': 1,
                'batch_max': num_batches,
                'speed': 0,
            },
        })
        self.checkTrainingAborted(epoch)

    def afterEpoch(self, epoch):
        from labelme.config import Training
        config_file = os.path.join(self.output_folder,
                                   Training.config('config_file'))
        self.updateConfig(config_file, last_epoch=epoch + 1)
        self.thread.update.emit(
            _('Finished training on epoch {}').format(epoch + 1), None, -1)
        self.checkTrainingAborted(epoch)

    def beforeBatch(self, batch_idx, epoch, num_batches):
        self.checkTrainingAborted(epoch)

    def afterBatch(self, batch_idx, epoch, num_batches, learning_rate, speed,
                   metrics):
        i = batch_idx
        if self.args.log_interval and not (i + 1) % self.args.log_interval:
            log_msg = '[Epoch {}/{}][Batch {}/{}], LR: {:.2E}, Speed: {:.3f} samples/sec'.format(
                epoch + 1, self.args.epochs, i + 1, num_batches, learning_rate,
                speed)
            update_msg = '{}\n{} {}, {} {}/{}, {}: {:.3f} {}\n'.format(
                _('Training ...'), _('Epoch'), epoch + 1, _('Batch'), i + 1,
                num_batches, _('Speed'), speed, _('samples/sec'))
            progress_metrics = {}
            for metric in metrics:
                name, loss = metric.get()
                msg = ', {}={:.3f}'.format(name, loss)
                log_msg += msg
                update_msg += msg
                progress_metrics[name] = loss
            logger.info(log_msg)

            self.thread.data.emit({
                'progress': {
                    'epoch': epoch + 1,
                    'epoch_max': self.args.epochs,
                    'batch': i + 1,
                    'batch_max': num_batches,
                    'speed': speed,
                    'metric': progress_metrics
                },
            })

            percent = math.ceil(
                (epoch / self.args.epochs +
                 batch_idx / num_batches / self.args.epochs) * 100)
            self.thread.update.emit(update_msg, percent, -1)

        self.checkTrainingAborted(epoch)

    def validateEpoch(self, epoch, epoch_time, validate_params):
        self.checkTrainingAborted(epoch)
        if self.val_data and not (epoch + 1) % self.args.val_interval:
            logger.debug('validate: {}'.format(epoch + 1))
            self.thread.data.emit({
                'validation': {
                    _('Validating...'): '',
                },
                'progress': {
                    'speed': 0,
                }
            })

            map_name, mean_ap = self.validate(**validate_params)
            val_msg = '\n'.join(
                ['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            logger.info('[Epoch {}] Validation [{:.3f}sec]: \n{}'.format(
                epoch, epoch_time, val_msg))
            current_mAP = float(mean_ap[-1])

            val_data = {'validation': {}}
            for i, name in enumerate(map_name[:]):
                val_data['validation'][name] = mean_ap[i]
            self.thread.data.emit(val_data)

            # Early Stopping
            self.monitor.update(epoch, mean_ap[-1])
            if self.monitor.shouldStopEarly():
                raise AbortTrainingException(epoch)
        else:
            current_mAP = 0.
        return current_mAP

    def saveParams(self, best_map, current_map, epoch):
        current_map = float(current_map)
        prefix = os.path.join(self.output_folder, self.args.save_prefix)
        if current_map >= best_map[0]:
            best_map[0] = current_map
            # Save custom-0000.params and custom-symbol.json for Investigator integration
            self.saveTraining(self.integration_network_name, 0)
            #self.net.save_parameters('{:s}_best.params'.format(prefix, epoch, current_map))
            with open(prefix + '_best_map.log', 'a') as f:
                f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
        if self.args.save_interval and epoch % self.args.save_interval == 0:
            self.saveTraining(
                '{:s}_{:04d}_{:.4f}'.format(prefix, epoch, current_map), epoch)
            #self.net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))

    def validate(self, waitall=False, static_shape=False):
        self.eval_metric.reset()
        self.net.set_nms(nms_thresh=0.45, nms_topk=400)
        if waitall:
            mx.nd.waitall()
        self.net.hybridize(static_alloc=static_shape,
                           static_shape=static_shape)
        for batch in self.val_data:
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=self.ctx,
                                              batch_axis=0,
                                              even_split=False)
            label = gluon.utils.split_and_load(batch[1],
                                               ctx_list=self.ctx,
                                               batch_axis=0,
                                               even_split=False)
            det_bboxes = []
            det_ids = []
            det_scores = []
            gt_bboxes = []
            gt_ids = []
            gt_difficults = []
            for x, y in zip(data, label):
                # get prediction results
                ids, scores, bboxes = self.net(x)
                det_ids.append(ids)
                det_scores.append(scores)
                # clip to image size
                det_bboxes.append(bboxes.clip(0, batch[0].shape[2]))
                # split ground truths
                gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
                gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
                gt_difficults.append(
                    y.slice_axis(axis=-1, begin=5, end=6
                                 ) if y.shape[-1] > 5 else None)
            # update metric
            self.eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes,
                                    gt_ids, gt_difficults)
        return self.eval_metric.get()

    def getValidationMetric(self, iou_thresh=0.5):
        val_metric = VOC07MApMetric(iou_thresh=iou_thresh,
                                    class_names=self.labels)
        #val_metric = VOCMApMetric(iou_thresh=iou_thresh, class_names=self.labels)
        #val_metric = COCODetectionMetric(iou_thresh=iou_thresh, class_names=self.labels)
        return val_metric

    def getGpuSizes(self):
        raise NotImplementedError(
            'Method getGpuSizes() needs to be implemented in subclasses')

    def getDefaultLearningRate(self):
        raise NotImplementedError(
            'Method getDefaultLearningRate() needs to be implemented in subclasses'
        )

    def updateConfig(self, config_file, **kwargs):
        logger.debug('Update training config with data: {}'.format(kwargs))
        data = {}
        with open(config_file, 'r') as f:
            data = json.load(f)
        data.update(kwargs)
        with open(config_file, 'w+') as f:
            json.dump(data, f, indent=2)
        logger.debug('Updated training config in file: {}'.format(config_file))

    def saveConfig(self, config_file, files):
        args = self.args.copy()
        del args['network']
        del args['resume']
        data = {
            'network': self.args.network,
            'files': files,
            'dataset': self.dataset_format,
            'labels': self.labels,
            'args': args,
        }
        logger.debug('Create training config: {}'.format(data))
        with open(config_file, 'w+') as f:
            json.dump(data, f, indent=2)
            logger.debug(
                'Saved training config in file: {}'.format(config_file))

    def loadConfig(self, config_file):
        logger.debug('Load training config from file: {}'.format(config_file))
        try:
            with open(config_file, 'r') as f:
                data = json.load(f)
        except (IOError, ValueError):
            raise Exception(
                'Could not load training config from file {}'.format(
                    config_file))
        logger.debug('Loaded training config: {}'.format(data))
        return Map(data)

    def setOutputFolder(self, output_folder):
        self.output_folder = output_folder

    def setLabels(self, labels):
        self.labels = labels

    def setTrainDataset(self, dataset, dataset_format):
        self.train_dataset = dataset
        self.dataset_format = dataset_format

    def setValDataset(self, dataset):
        self.val_dataset = dataset

    def getContext(self, gpus=None):
        if gpus is None or gpus == '':
            return [mx.cpu()]
        ctx = [mx.gpu(int(i)) for i in gpus.split(',') if i.strip()]
        try:
            tmp = mx.nd.array([1, 2, 3], ctx=ctx[0])
        except mx.MXNetError as e:
            ctx = [mx.cpu()]
            logger.error(traceback.format_exc())
            logger.warning('Unable to use GPU. Using CPU instead')
        logger.debug('Use context: {}'.format(ctx))
        return ctx

    def saveTraining(self, network_name, epoch=0):
        # Export weights to .params file
        export_block(os.path.join(self.output_folder, network_name),
                     self.net,
                     epoch=epoch,
                     preprocess=True,
                     layout='HWC',
                     ctx=self.ctx[0])

    def inference(self,
                  input_image_file,
                  labels,
                  architecture_file,
                  weights_file,
                  args=None):
        default_args = {
            'threshold': 0.5,
            'print_top_n': 10,
        }
        tmp_args = default_args.copy()
        if args:
            tmp_args.update(args)
        args = Map(tmp_args)
        logger.debug('Try loading network from files "{}" and "{}"'.format(
            architecture_file, weights_file))

        self.checkAborted()

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            ctx = self.getContext()
            net = gluon.nn.SymbolBlock.imports(architecture_file, ['data'],
                                               weights_file,
                                               ctx=ctx)
            class_names = labels
            net.collect_params().reset_ctx(ctx)
            img = mx.image.imread(input_image_file)
            img = timage.resize_short_within(img,
                                             608,
                                             max_size=1024,
                                             mult_base=1)

            self.checkAborted()
            self.thread.update.emit(None, -1, -1)

            def make_tensor(img):
                np_array = np.expand_dims(np.transpose(img, (0, 1, 2)),
                                          axis=0).astype(np.float32)
                return mx.nd.array(np_array)

            image = img.asnumpy().astype('uint8')
            x = make_tensor(image)
            cid, score, bbox = net(x)

            self.thread.data.emit({
                'files': {
                    'input_image_file': input_image_file,
                    'architecture_file': architecture_file,
                    'weights_file': weights_file,
                },
                'imgsize': [image.shape[0], image.shape[1]],
                'classid': cid.asnumpy().tolist(),
                'score': score.asnumpy().tolist(),
                'bbox': bbox.asnumpy().tolist(),
                'labels': labels,
            })
            self.thread.update.emit(None, -1, -1)

            n_top = args.print_top_n
            classes = cid[0][:n_top].asnumpy().astype(
                'int32').flatten().tolist()
            scores = score[0][:n_top].asnumpy().astype(
                'float32').flatten().tolist()
            result_str = '\n'.join([
                'class: {}, score: {}'.format(classes[i], scores[i])
                for i in range(n_top)
            ])
            logger.debug('Top {} inference results:\n {}'.format(
                n_top, result_str))

            #ax = viz.plot_bbox(image, bbox[0], score[0], cid[0], class_names=class_names, thresh=args.threshold)
            #plt.show()

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if self.ctx:
            for context in self.ctx:
                context.empty_cache()
Example #10
    def run(self):
        logger.debug('Start import from directory')

        try:
            # Attach the debugger to this thread if ptvsd is available
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        data = Map(self.data)
        num_images = len(data.images)
        pattern = data.pattern
        output_dir = data.output_dir
        filters = data.filters

        filter_label_func = self.acceptAll
        if ('label' in filters
                and filters['label'] != StatisticsModel.STATISTICS_FILTER_ALL):
            filter_label_func = self.acceptLabel

        image_count = 0
        all_shapes = []
        items = []

        self.checkAborted()

        for i, filename in enumerate(data.images):

            self.thread.update.emit(None, i, num_images)
            self.checkAborted()

            # Search pattern
            if pattern and pattern.lower() not in filename.lower(): # re.search(pattern, filename, re.IGNORECASE) == None:
                continue

            label_file = os.path.splitext(filename)[0] + '.json'
            if output_dir:
                label_file_without_path = os.path.basename(label_file)
                label_file = os.path.normpath(os.path.join(output_dir, label_file_without_path))

            # ListItem
            item = QtWidgets.QListWidgetItem(filename)
            item.setFlags(Qt.ItemIsEnabled | Qt.ItemIsSelectable)
            item.setCheckState(Qt.Unchecked)

            self.checkAborted()

            shapes = []
            has_labels = False
            labels_for_image = set([])
            label_file_exists = os.path.isfile(label_file)

            # Labels
            if label_file_exists:
                labelFile = LabelFile(label_file)
                for label, points, line_color, fill_color, shape_type, flags in labelFile.shapes:
                    if filter_label_func(label):
                        has_labels = True
                        shape = Shape(label=label, shape_type=shape_type)
                        shapes.append(shape)
                        labels_for_image.add(label)

            # Filters
            if ('label' in filters
                    and filters['label'] != StatisticsModel.STATISTICS_FILTER_ALL):
                if filters['label'] not in labels_for_image:
                    continue
            if 'has_label' in filters:
                if filters['has_label'] == StatisticsModel.STATISTICS_FILTER_LABELED and not has_labels:
                    continue
                if filters['has_label'] == StatisticsModel.STATISTICS_FILTER_UNLABELED and has_labels:
                    continue

            image_count += 1
            items.append(item)
            if has_labels:
                item.setCheckState(Qt.Checked)
                all_shapes.append(shapes)

            if image_count % data['update_interval'] == 0:
                self.thread.data.emit({
                    'items': items,
                    'num_images': image_count,
                    'all_shapes': all_shapes,
                })
                image_count = 0
                all_shapes = []
                items = []

            self.checkAborted()

        self.thread.data.emit({
            'num_images': image_count,
            'all_shapes': all_shapes,
            'items': items,
        })
Example #11
    def run(self):
        logger.debug('Prepare export')

        try:
            # Attach the debugger to this thread if ptvsd is available
            import ptvsd
            ptvsd.debug_this_thread()
        except Exception:
            pass

        data_folder = self.data['data_folder']
        is_data_folder_valid = True
        if not data_folder:
            is_data_folder_valid = False
        data_folder = os.path.normpath(data_folder)
        if not os.path.isdir(data_folder):
            is_data_folder_valid = False
        if not is_data_folder_valid:
            self.thread.message.emit(_('Export'),
                                     _('Please enter a valid data folder'),
                                     MessageType.Warning)
            self.abort()
            return

        export_folder = self.data['export_folder']
        is_export_folder_valid = True
        if not export_folder:
            is_export_folder_valid = False
        export_folder = os.path.normpath(export_folder)
        if not os.path.isdir(export_folder):
            is_export_folder_valid = False
        if not is_export_folder_valid:
            self.thread.message.emit(_('Export'),
                                     _('Please enter a valid export folder'),
                                     MessageType.Warning)
            self.abort()
            return

        selected_labels = self.data['selected_labels']
        num_selected_labels = len(selected_labels)
        limit = self.data['max_num_labels']
        if num_selected_labels > limit:
            self.thread.message.emit(
                _('Export'),
                _('Please select a maximum of {} labels').format(limit),
                MessageType.Warning)
            self.abort()
            return
        elif num_selected_labels <= 0:
            self.thread.message.emit(_('Export'),
                                     _('Please select at least 1 label'),
                                     MessageType.Warning)
            self.abort()
            return

        dataset_name = replace_special_chars(self.data['dataset_name'])
        if not dataset_name:
            self.thread.message.emit(_('Export'),
                                     _('Please enter a valid dataset name'),
                                     MessageType.Warning)
            self.abort()
            return

        export_dataset_folder = os.path.normpath(
            os.path.join(self.data['export_folder'],
                         self.data['dataset_name']))
        if not os.path.isdir(export_dataset_folder):
            os.makedirs(export_dataset_folder)
        elif len(os.listdir(export_dataset_folder)) > 0:
            msg = _(
                'The selected output directory "{}" is not empty. All containing files will be deleted. Are you sure to continue?'
            ).format(export_dataset_folder)
            if self.doConfirm(_('Export'), msg, MessageType.Warning):
                deltree(export_dataset_folder)
                time.sleep(0.5)  # wait for deletion to be finished
                if not os.path.exists(export_dataset_folder):
                    os.makedirs(export_dataset_folder)
            else:
                self.abort()
                return

        if not os.path.isdir(export_dataset_folder):
            self.thread.message.emit(
                _('Export'),
                _('The selected output directory "{}" could not be created').
                format(export_dataset_folder), MessageType.Warning)
            self.abort()
            return

        selected_format = self.data['selected_format']
        all_formats = Export.config('formats')
        inv_formats = Export.invertDict(all_formats)
        if selected_format not in inv_formats:
            self.thread.message.emit(
                _('Export'),
                _('Export format {} could not be found').format(
                    selected_format), MessageType.Warning)
            self.abort()
            return
        else:
            self.data['format_name'] = inv_formats[selected_format]

        logger.debug('Start export')

        selected_labels = self.data['selected_labels']
        validation_ratio = self.data['validation_ratio']
        data_folder = self.data['data_folder']
        format_name = self.data['format_name']

        self.checkAborted()

        intermediate = IntermediateFormat()
        intermediate.setAbortable(self.abortable)
        intermediate.setThread(self.thread)
        intermediate.setIncludedLabels(selected_labels)
        intermediate.setValidationRatio(validation_ratio)
        intermediate.addFromLabelFiles(data_folder, shuffle=False)

        self.thread.update.emit(_('Loading data ...'), 0,
                                intermediate.getNumberOfSamples() + 5)

        args = Map({
            'validation_ratio': validation_ratio,
        })

        dataset_format = Export.config('objects')[format_name]()
        dataset_format.setAbortable(self.abortable)
        dataset_format.setThread(self.thread)
        dataset_format.setIntermediateFormat(intermediate)
        dataset_format.setInputFolderOrFile(data_folder)
        dataset_format.setOutputFolder(export_dataset_folder)
        dataset_format.setArgs(args)

        self.checkAborted()

        dataset_format.export()