Exemplo n.º 1
0
    def _load_train_config(self):
        """Load the training config stored with the model and derive class info.

        Populates:
          * ``self.train_config`` -- the loaded training config;
          * ``self.input_size_wh`` -- ``(width, height)`` taken from
            ``settings.input_size`` of the config;
          * ``self.class_title_to_idx`` -- the config's ``class_title_to_idx``;
          * ``self.train_classes`` -- ``sly.FigClasses`` built from
            ``out_classes``;
          * ``self.out_class_mapping`` -- title -> internal index for every
            class in ``out_classes``.

        Raises:
            RuntimeError: if no training config exists in the model directory.
            KeyError: if an output class title is absent from
                ``class_title_to_idx``.
        """
        config_rw = TrainConfigRW(
            sly.TaskPaths(determine_in_project=False).model_dir)
        if not config_rw.train_config_exists:
            raise RuntimeError(
                'Unable to run inference, config from training wasn\'t found.')
        self.train_config = config_rw.load()

        size_cfg = self.train_config['settings']['input_size']
        width, height = size_cfg['width'], size_cfg['height']
        self.input_size_wh = (width, height)
        logger.info('Model input size is read (for auto-rescale).',
                    extra={'input_size': {'width': width, 'height': height}})

        self.class_title_to_idx = self.train_config['class_title_to_idx']
        self.train_classes = sly.FigClasses(self.train_config['out_classes'])
        logger.info('Read model internal class mapping',
                    extra={'class_mapping': self.class_title_to_idx})
        logger.info('Read model out classes',
                    extra={'classes': self.train_classes.py_container})

        # Restrict the internal mapping to the classes the model outputs.
        # Built locally first so the attribute is only set once complete.
        out_mapping = {}
        for fig_class in self.train_classes:
            title = fig_class['title']
            out_mapping[title] = self.class_title_to_idx[title]
        self.out_class_mapping = out_mapping
Exemplo n.º 2
0
    def _load_train_config(self):
        """Load the training config and build the inverse class mapping.

        The model input size is not read from the config: it is a fixed
        1200x1200 and is only logged here.

        Populates ``self.train_config``, ``self.class_title_to_idx`` (from the
        config's ``'mapping'`` key), ``self.train_classes`` (from
        ``'classes'``) and ``self.inv_mapping`` (``inverse_mapping`` applied to
        the title -> index mapping of the output classes).

        Raises:
            RuntimeError: if no training config exists in the model directory.
        """
        config_rw = TrainConfigRW(self.helper.paths.model_dir)
        if not config_rw.train_config_exists:
            raise RuntimeError(
                'Unable to run inference, config from training wasn\'t found.')
        self.train_config = config_rw.load()

        # Input shape is fixed for Faster with NasNet encoder.
        fixed_side = 1200
        logger.info('Model input size is read (for auto-rescale).',
                    extra={'input_size': {'width': fixed_side,
                                          'height': fixed_side}})

        self.class_title_to_idx = self.train_config['mapping']
        self.train_classes = sly.FigClasses(self.train_config['classes'])
        logger.info('Read model internal class mapping',
                    extra={'class_mapping': self.class_title_to_idx})
        logger.info('Read model out classes',
                    extra={'classes': self.train_classes.py_container})

        titles = [fig_class['title'] for fig_class in self.train_classes]
        title_to_idx = {title: self.class_title_to_idx[title]
                        for title in titles}
        self.inv_mapping = inverse_mapping(title_to_idx)
Exemplo n.º 3
0
    def _check_prev_model_config(self):
        """Verify the class mapping matches the model training continues from.

        Loads the training config from ``self.helper.paths.model_dir`` and
        compares its ``'class_title_to_idx'`` entry (an empty dict when the key
        is absent) against ``self.class_title_to_idx``.

        Raises:
            RuntimeError: if the previous config is missing or the two class
                mappings differ.
        """
        prev_config_rw = TrainConfigRW(self.helper.paths.model_dir)
        if not prev_config_rw.train_config_exists:
            raise RuntimeError('Unable to continue_training, config for previous training wasn\'t found.')

        prev_mapping = prev_config_rw.load().get('class_title_to_idx', {})
        mappings_match = prev_mapping == self.class_title_to_idx
        if not mappings_match:
            raise RuntimeError('Unable to continue training, class mapping is inconsistent with previous model.')
Exemplo n.º 4
0
        def dump_model(saver, sess, is_best, opt_data):
            """Write one checkpoint: the train config plus the session weights.

            The target directory is obtained from
            ``self.helper.checkpoints_saver``. ``self.out_config`` is saved
            there, then ``saver.save`` writes ``sess`` to
            ``<dir>/model_weights/model.ckpt`` (presumably a TensorFlow
            ``tf.train.Saver`` -- confirm at the call site), and finally the
            checkpoints saver is notified via ``saved(is_best, opt_data)``.
            """
            ckpt_dir = self.helper.checkpoints_saver.get_dir_to_write()
            TrainConfigRW(ckpt_dir).save(self.out_config)

            weights_path = os.path.join(ckpt_dir, 'model_weights', 'model.ckpt')
            saver.save(sess, weights_path)

            self.helper.checkpoints_saver.saved(is_best, opt_data)
Exemplo n.º 5
0
def main():
    """Build a training config from a plain-text list of class names.

    Reads ``args.in_file`` (one class title per line, surrounding whitespace
    stripped, blank lines ignored), builds detection classes with
    ``construct_detection_classes`` and a title -> index mapping in file
    order, then saves the config with ``TrainConfigRW`` into ``args.out_dir``.

    Raises:
        RuntimeError: if the input lists the same class title more than once.
            A duplicate would leave a gap in the index mapping, because the
            later index overwrites the earlier one.
    """
    from collections import Counter

    args = parse_args()
    # Explicit encoding: class titles may be non-ASCII, and the default text
    # encoding depends on the platform/locale.
    with open(args.in_file, encoding='utf-8') as f:
        lines = [title for title in (line.strip() for line in f) if title]

    duplicates = sorted(t for t, cnt in Counter(lines).items() if cnt > 1)
    if duplicates:
        raise RuntimeError('Duplicate class names in {}: {}'.format(
            args.in_file, duplicates))

    out_classes = sly.FigClasses()
    for fig_class in construct_detection_classes(lines):
        out_classes.add(fig_class)
    cls_mapping = {title: idx for idx, title in enumerate(lines)}
    res_cfg = {
        'settings': {},
        'out_classes': out_classes.py_container,
        'class_title_to_idx': cls_mapping,
    }

    saver = TrainConfigRW(args.out_dir)
    saver.save(res_cfg)
    print('Done: {} -> {}'.format(args.in_file, saver.train_config_fpath))
Exemplo n.º 6
0
    def _load_train_config(self):
        """Load the training config and build the inverse class mapping.

        Populates ``self.train_config``, ``self.class_title_to_idx`` (from the
        config's ``'mapping'`` key), ``self.train_classes`` (from
        ``'classes'``) and ``self.inv_mapping`` (``inverse_mapping`` applied to
        the title -> index mapping of the output classes).

        Raises:
            RuntimeError: if no training config exists in the model directory.
        """
        config_rw = TrainConfigRW(self.helper.paths.model_dir)
        if not config_rw.train_config_exists:
            raise RuntimeError(
                'Unable to run inference, config from training wasn\'t found.')
        self.train_config = config_rw.load()

        self.class_title_to_idx = self.train_config['mapping']
        self.train_classes = sly.FigClasses(self.train_config['classes'])
        logger.info('Read model internal class mapping',
                    extra={'class_mapping': self.class_title_to_idx})
        logger.info('Read model out classes',
                    extra={'classes': self.train_classes.py_container})

        titles = [fig_class['title'] for fig_class in self.train_classes]
        title_to_idx = {title: self.class_title_to_idx[title]
                        for title in titles}
        self.inv_mapping = inverse_mapping(title_to_idx)
Exemplo n.º 7
0
    def _load_train_config(self):
        """Load the training config and derive index-ordered class names.

        Populates ``self.train_classes`` (``sly.FigClasses`` built from the
        config's ``'out_classes'``) and ``self.train_names``: position ``i``
        holds the class title that ``'class_title_to_idx'`` maps to ``i``.

        Raises:
            RuntimeError: if no training config exists in the model directory,
                or if the indices in ``class_title_to_idx`` are not exactly
                ``0..N-1`` with no duplicates (the ordered name list cannot
                be built otherwise).
        """
        model_dir = self.helper.paths.model_dir
        train_config_rw = TrainConfigRW(model_dir)
        if not train_config_rw.train_config_exists:
            raise RuntimeError(
                'Unable to run inference, config from training wasn\'t found.')
        train_config = train_config_rw.load()

        self.train_classes = sly.FigClasses(train_config['out_classes'])
        tr_class_mapping = train_config['class_title_to_idx']

        # Invert title -> idx into a list ordered by idx. This requires the
        # indices to form a permutation of 0..N-1. Without this check, a gap
        # would surface as a bare KeyError, and duplicate indices would
        # silently drop classes from the list.
        expected_indices = list(range(len(tr_class_mapping)))
        if sorted(tr_class_mapping.values()) != expected_indices:
            raise RuntimeError(
                'Class mapping indices must be unique and contiguous from 0, '
                'got: {}'.format(tr_class_mapping))
        rev_mapping = {v: k for k, v in tr_class_mapping.items()}
        self.train_names = [rev_mapping[i] for i in expected_indices]

        logger.info('Read model internal class mapping',
                    extra={'class_mapping': tr_class_mapping})
        logger.info('Read model out classes',
                    extra={'classes': self.train_classes.py_container})
Exemplo n.º 8
0
 def _create_checkpoints_dir(self):
     """Pre-create one checkpoint directory per epoch under the results dir.

     Directories are named by the epoch number zero-padded to 8 digits. Each
     receives ``model.cfg`` (``self.yolo_config`` written via ``save_config``)
     and the train config ``self.out_config``.
     """
     epoch_count = self.config['epochs']
     for epoch_idx in range(epoch_count):
         dir_name = '{:08}'.format(epoch_idx)
         ckpt_dir = os.path.join(self.helper.paths.results_dir, dir_name)
         sly.mkdir(ckpt_dir)

         cfg_path = os.path.join(ckpt_dir, 'model.cfg')
         save_config(self.yolo_config, cfg_path)
         TrainConfigRW(ckpt_dir).save(self.out_config)
Exemplo n.º 9
0
 def _dump_model(self, is_best, opt_data):
     """Save the current model as a checkpoint and report it to the saver.

     Writes ``self.out_config`` (via ``TrainConfigRW``) and ``self.model``
     (via ``WeightsRW``) into the directory supplied by
     ``self.helper.checkpoints_saver``, then notifies that saver with
     ``saved(is_best, opt_data)``.
     """
     target_dir = self.helper.checkpoints_saver.get_dir_to_write()

     config_writer = TrainConfigRW(target_dir)
     config_writer.save(self.out_config)
     weights_writer = WeightsRW(target_dir)
     weights_writer.save(self.model)

     self.helper.checkpoints_saver.saved(is_best, opt_data)