示例#1
0
    def _load_train_config(self):
        """Load the training config from the model dir and derive inference settings.

        Populates ``self.train_config``, ``self.input_size_wh``,
        ``self.class_title_to_idx``, ``self.train_classes`` and
        ``self.out_class_mapping``.

        Raises:
            RuntimeError: if the config reader reports that no training
                config exists in ``self.helper.paths.model_dir``.
        """
        config_reader = TrainConfigRW(self.helper.paths.model_dir)
        if not config_reader.train_config_exists:
            raise RuntimeError(
                "Unable to run inference, config from training wasn't found.")
        self.train_config = config_reader.load()

        # Input size is stored in the config as separate width/height entries;
        # keep it as a (width, height) tuple.
        size_cfg = self.train_config['settings']['input_size']
        width, height = size_cfg['width'], size_cfg['height']
        self.input_size_wh = (width, height)
        size_payload = {'input_size': {'width': width, 'height': height}}
        logger.info('Model input size is read (for auto-rescale).',
                    extra=size_payload)

        self.class_title_to_idx = self.train_config['class_title_to_idx']
        self.train_classes = sly.FigClasses(self.train_config['out_classes'])
        logger.info('Read model internal class mapping',
                    extra={'class_mapping': self.class_title_to_idx})
        logger.info('Read model out classes',
                    extra={'classes': self.train_classes.py_container})

        # Restrict the title -> index mapping to the titles of the out classes.
        title_to_idx = {}
        for class_desc in self.train_classes:
            title = class_desc['title']
            title_to_idx[title] = self.class_title_to_idx[title]
        self.out_class_mapping = title_to_idx
示例#2
0
    def _check_prev_model_config(self):
        """Check that the previous model's class mapping equals the current one.

        Loads the training config found in ``self.helper.paths.model_dir`` and
        compares its 'class_title_to_idx' entry with ``self.class_title_to_idx``.

        Raises:
            RuntimeError: if no previous training config exists, or if the
                two class mappings differ.
        """
        config_reader = TrainConfigRW(self.helper.paths.model_dir)
        if not config_reader.train_config_exists:
            raise RuntimeError("Unable to continue_training, config for previous training wasn't found.")

        # A previous config without a stored mapping is compared as an empty dict.
        previous_mapping = config_reader.load().get('class_title_to_idx', {})
        if self.class_title_to_idx != previous_mapping:
            raise RuntimeError('Unable to continue training, class mapping is inconsistent with previous model.')
示例#3
0
    def _dump_model(self, is_best, opt_data):
        """Write the output config and a model checkpoint, then report the save.

        Args:
            is_best: forwarded unchanged to ``checkpoints_saver.saved``.
            opt_data: forwarded unchanged to ``checkpoints_saver.saved``.
        """
        checkpoint_dir = self.helper.checkpoints_saver.get_dir_to_write()

        # The config is written next to the checkpoint, in the same directory.
        config_writer = TrainConfigRW(checkpoint_dir)
        config_writer.save(self.out_config)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt'))

        # Notify the checkpoints saver only after both files have been written.
        self.helper.checkpoints_saver.saved(is_best, opt_data)
    def _load_train_config(self):
        """Load the training config from the task's model dir and derive settings.

        Populates ``self.train_config``, ``self.input_size_wh``,
        ``self.class_title_to_idx``, ``self.train_classes`` and
        ``self.out_class_mapping``.

        Raises:
            RuntimeError: if the config reader reports that no training
                config exists in the model directory.
        """
        task_paths = sly.TaskPaths(determine_in_project=False)
        config_reader = TrainConfigRW(task_paths.model_dir)
        if not config_reader.train_config_exists:
            raise RuntimeError("Unable to run inference, config from training wasn't found.")
        self.train_config = config_reader.load()

        # TODO: fixed value -- the input is square with a hard-coded side
        # and is not read from the loaded config.
        side = 713
        self.input_size_wh = (side, side)
        size_payload = {
            'input_size': {
                'width': self.input_size_wh[0],
                'height': self.input_size_wh[1],
            }
        }
        logger.info('Model input size is read (for auto-rescale).', extra=size_payload)

        self.class_title_to_idx = self.train_config['class_title_to_idx']
        self.train_classes = sly.FigClasses(self.train_config['out_classes'])
        logger.info('Read model internal class mapping', extra={'class_mapping': self.class_title_to_idx})
        logger.info('Read model out classes', extra={'classes': self.train_classes.py_container})

        # Restrict the title -> index mapping to the titles of the out classes.
        title_to_idx = {}
        for class_desc in self.train_classes:
            title = class_desc['title']
            title_to_idx[title] = self.class_title_to_idx[title]
        self.out_class_mapping = title_to_idx