def _load_train_config(self):
    """Read the training config from the task's model dir and cache its parts.

    Populates ``train_config``, ``input_size_wh``, ``class_title_to_idx``,
    ``train_classes`` and ``out_class_mapping`` on the instance.

    Raises:
        RuntimeError: if ``TrainConfigRW`` reports that no training config
            exists in the model directory.
    """
    task_paths = sly.TaskPaths(determine_in_project=False)
    config_rw = TrainConfigRW(task_paths.model_dir)
    if not config_rw.train_config_exists:
        raise RuntimeError(
            'Unable to run inference, config from training wasn\'t found.')
    self.train_config = config_rw.load()

    # The input size stored at training time is kept as a (width, height) pair.
    input_size = self.train_config['settings']['input_size']
    width, height = input_size['width'], input_size['height']
    self.input_size_wh = (width, height)
    logger.info('Model input size is read (for auto-rescale).',
                extra={'input_size': {'width': width, 'height': height}})

    self.class_title_to_idx = self.train_config['class_title_to_idx']
    self.train_classes = sly.FigClasses(self.train_config['out_classes'])
    logger.info('Read model internal class mapping',
                extra={'class_mapping': self.class_title_to_idx})
    logger.info('Read model out classes',
                extra={'classes': self.train_classes.py_container})

    # Restrict the title -> index mapping to the titles of the output classes.
    title_to_idx = {}
    for fig_class in self.train_classes:
        title = fig_class['title']
        title_to_idx[title] = self.class_title_to_idx[title]
    self.out_class_mapping = title_to_idx
def _load_train_config(self):
    """Read the training config from the model dir and build the class mapping.

    Populates ``train_config``, ``class_title_to_idx``, ``train_classes`` and
    ``inv_mapping`` on the instance.

    Raises:
        RuntimeError: if ``TrainConfigRW`` reports that no training config
            exists in the model directory.
    """
    config_rw = TrainConfigRW(self.helper.paths.model_dir)
    if not config_rw.train_config_exists:
        raise RuntimeError(
            'Unable to run inference, config from training wasn\'t found.')
    self.train_config = config_rw.load()

    # input shape is fixed for Faster with NasNet encoder
    fixed_input_size = {'width': 1200, 'height': 1200}
    logger.info('Model input size is read (for auto-rescale).',
                extra={'input_size': fixed_input_size})

    self.class_title_to_idx = self.train_config['mapping']
    self.train_classes = sly.FigClasses(self.train_config['classes'])
    logger.info('Read model internal class mapping',
                extra={'class_mapping': self.class_title_to_idx})
    logger.info('Read model out classes',
                extra={'classes': self.train_classes.py_container})

    # Restrict the title -> index mapping to the titles of the output classes.
    title_to_idx = {}
    for fig_class in self.train_classes:
        title = fig_class['title']
        title_to_idx[title] = self.class_title_to_idx[title]
    # NOTE(review): inverse_mapping presumably yields index -> title; confirm
    # against its definition.
    self.inv_mapping = inverse_mapping(title_to_idx)
def _check_prev_model_config(self):
    """Ensure the previous model's class mapping equals the current one.

    Raises:
        RuntimeError: if the previous training config is missing, or if its
            ``class_title_to_idx`` entry differs from
            ``self.class_title_to_idx``.
    """
    config_rw = TrainConfigRW(self.helper.paths.model_dir)
    if not config_rw.train_config_exists:
        raise RuntimeError('Unable to continue_training, config for previous training wasn\'t found.')

    # A previous config without the key is compared as an empty mapping.
    previous_mapping = config_rw.load().get('class_title_to_idx', {})
    if self.class_title_to_idx == previous_mapping:
        return
    raise RuntimeError('Unable to continue training, class mapping is inconsistent with previous model.')
def dump_model(saver, sess, is_best, opt_data):
    """Save the training config and model weights as a new checkpoint.

    NOTE(review): ``self`` is not a parameter here, so this is presumably a
    closure defined inside a method; confirm against the enclosing scope.

    Args:
        saver: object whose ``save(sess, path)`` writes the weights.
        sess: session handed through to ``saver.save``.
        is_best: forwarded to ``checkpoints_saver.saved``.
        opt_data: forwarded to ``checkpoints_saver.saved``.
    """
    checkpoints = self.helper.checkpoints_saver
    ckpt_dir = checkpoints.get_dir_to_write()

    # The config is written next to the weights in the same checkpoint dir.
    TrainConfigRW(ckpt_dir).save(self.out_config)
    weights_path = os.path.join(ckpt_dir, 'model_weights', 'model.ckpt')
    saver.save(sess, weights_path)

    # Report the finished checkpoint only after both artifacts are written.
    checkpoints.saved(is_best, opt_data)
def main():
    """Build a training config from a file of class names and save it.

    Reads ``args.in_file`` and keeps its non-blank lines, stripped, as class
    names. The saved config has empty ``settings``, the classes built by
    ``construct_detection_classes`` and a name -> line-order index mapping,
    and is written through ``TrainConfigRW(args.out_dir)``.
    """
    args = parse_args()

    # Keep only non-blank lines, stripped of surrounding whitespace.
    with open(args.in_file) as in_file:
        class_names = [name for name in map(str.strip, in_file) if name]

    out_classes = sly.FigClasses()
    for det_class in construct_detection_classes(class_names):
        out_classes.add(det_class)

    # Each class index is the position of its name among the kept lines.
    name_to_idx = {name: idx for idx, name in enumerate(class_names)}
    config = {
        'settings': {},
        'out_classes': out_classes.py_container,
        'class_title_to_idx': name_to_idx,
    }

    config_rw = TrainConfigRW(args.out_dir)
    config_rw.save(config)
    print('Done: {} -> {}'.format(args.in_file, config_rw.train_config_fpath))
def _load_train_config(self):
    """Read the training config from the model dir and build the class mapping.

    Populates ``train_config``, ``class_title_to_idx``, ``train_classes`` and
    ``inv_mapping`` on the instance.

    Raises:
        RuntimeError: if ``TrainConfigRW`` reports that no training config
            exists in the model directory.
    """
    config_rw = TrainConfigRW(self.helper.paths.model_dir)
    if not config_rw.train_config_exists:
        raise RuntimeError(
            'Unable to run inference, config from training wasn\'t found.')
    self.train_config = config_rw.load()

    self.class_title_to_idx = self.train_config['mapping']
    self.train_classes = sly.FigClasses(self.train_config['classes'])
    logger.info('Read model internal class mapping',
                extra={'class_mapping': self.class_title_to_idx})
    logger.info('Read model out classes',
                extra={'classes': self.train_classes.py_container})

    # Restrict the title -> index mapping to the titles of the output classes.
    title_to_idx = {}
    for fig_class in self.train_classes:
        title = fig_class['title']
        title_to_idx[title] = self.class_title_to_idx[title]
    # NOTE(review): inverse_mapping presumably yields index -> title; confirm
    # against its definition.
    self.inv_mapping = inverse_mapping(title_to_idx)
def _load_train_config(self):
    """Read the training config and derive the classes and ordered names.

    Populates ``train_classes`` and ``train_names`` on the instance.
    ``train_names`` lists class titles by index 0..N-1, so the stored mapping
    must use exactly those indices.

    Raises:
        RuntimeError: if ``TrainConfigRW`` reports that no training config
            exists in the model directory.
    """
    config_rw = TrainConfigRW(self.helper.paths.model_dir)
    if not config_rw.train_config_exists:
        raise RuntimeError(
            'Unable to run inference, config from training wasn\'t found.')
    config = config_rw.load()

    self.train_classes = sly.FigClasses(config['out_classes'])
    title_to_idx = config['class_title_to_idx']

    # Invert the mapping, then list the titles in ascending index order.
    idx_to_title = {idx: title for title, idx in title_to_idx.items()}
    ordered_titles = []
    for idx in range(len(idx_to_title)):
        ordered_titles.append(idx_to_title[idx])
    self.train_names = ordered_titles

    logger.info('Read model internal class mapping',
                extra={'class_mapping': title_to_idx})
    logger.info('Read model out classes',
                extra={'classes': self.train_classes.py_container})
def _create_checkpoints_dir(self):
    """Create one checkpoint directory per epoch and write its configs.

    For each epoch in ``range(self.config['epochs'])`` a directory named with
    the zero-padded 8-digit epoch number is created under the results dir.
    Each directory receives ``model.cfg`` (from ``self.yolo_config``) and the
    training config (from ``self.out_config``).
    """
    for epoch_idx in range(self.config['epochs']):
        results_dir = self.helper.paths.results_dir
        epoch_dir = os.path.join(results_dir, '{:08}'.format(epoch_idx))
        sly.mkdir(epoch_dir)

        model_cfg_path = os.path.join(epoch_dir, 'model.cfg')
        save_config(self.yolo_config, model_cfg_path)
        TrainConfigRW(epoch_dir).save(self.out_config)
def _dump_model(self, is_best, opt_data):
    """Save the training config and model weights as a new checkpoint.

    Args:
        is_best: forwarded to ``checkpoints_saver.saved``.
        opt_data: forwarded to ``checkpoints_saver.saved``.
    """
    checkpoints = self.helper.checkpoints_saver
    ckpt_dir = checkpoints.get_dir_to_write()

    # The config and the weights go into the same checkpoint directory.
    TrainConfigRW(ckpt_dir).save(self.out_config)
    WeightsRW(ckpt_dir).save(self.model)

    # Report the finished checkpoint only after both artifacts are written.
    checkpoints.saved(is_best, opt_data)