def _load_train_config(self):
    """Load the config saved at training time and derive inference settings.

    Reads the training config from ``self.helper.paths.model_dir`` and
    populates:

    * ``self.train_config`` -- the loaded config object;
    * ``self.input_size_wh`` -- ``(width, height)`` from
      ``settings.input_size`` in the config;
    * ``self.class_title_to_idx`` -- the config's ``class_title_to_idx``;
    * ``self.train_classes`` -- ``sly.FigClasses`` built from ``out_classes``;
    * ``self.out_class_mapping`` -- title -> index for every class in
      ``self.train_classes``.

    Raises:
        RuntimeError: if no training config exists in the model directory.
    """
    config_rw = TrainConfigRW(self.helper.paths.model_dir)
    if not config_rw.train_config_exists:
        raise RuntimeError(
            'Unable to run inference, config from training wasn\'t found.')

    cfg = config_rw.load()
    self.train_config = cfg

    # Input size is stored as {'width': ..., 'height': ...} under settings.
    size_spec = cfg['settings']['input_size']
    width, height = size_spec['width'], size_spec['height']
    self.input_size_wh = (width, height)
    logger.info('Model input size is read (for auto-rescale).',
                extra={'input_size': {'width': width, 'height': height}})

    self.class_title_to_idx = cfg['class_title_to_idx']
    self.train_classes = sly.FigClasses(cfg['out_classes'])
    logger.info('Read model internal class mapping',
                extra={'class_mapping': self.class_title_to_idx})
    logger.info('Read model out classes',
                extra={'classes': self.train_classes.py_container})

    # Build into a local first so the attribute is only set once every
    # title has been resolved (a missing title raises KeyError).
    mapping = {}
    for fig_class in self.train_classes:
        title = fig_class['title']
        mapping[title] = self.class_title_to_idx[title]
    self.out_class_mapping = mapping
def _check_prev_model_config(self):
    """Verify the previous model is compatible before continuing training.

    Loads the training config found in ``self.helper.paths.model_dir`` and
    compares its ``class_title_to_idx`` with ``self.class_title_to_idx``.
    A config without that key is treated as having an empty mapping.

    Raises:
        RuntimeError: if the previous training config is missing, or if its
            class mapping differs from the current one.
    """
    prev_rw = TrainConfigRW(self.helper.paths.model_dir)
    if not prev_rw.train_config_exists:
        raise RuntimeError('Unable to continue_training, config for previous training wasn\'t found.')

    previous_mapping = prev_rw.load().get('class_title_to_idx', {})
    if self.class_title_to_idx != previous_mapping:
        raise RuntimeError('Unable to continue training, class mapping is inconsistent with previous model.')
def _dump_model(self, is_best, opt_data):
    """Write one checkpoint and report it to the checkpoint saver.

    Obtains the target directory from ``self.helper.checkpoints_saver``,
    saves ``self.out_config`` there via ``TrainConfigRW``, saves
    ``self.sess`` with ``self.saver`` to ``<dir>/model.ckpt``, then calls
    ``checkpoints_saver.saved(is_best, opt_data)``.

    Args:
        is_best: forwarded unchanged to ``checkpoints_saver.saved``.
        opt_data: forwarded unchanged to ``checkpoints_saver.saved``.
    """
    ckpt_saver = self.helper.checkpoints_saver
    target_dir = ckpt_saver.get_dir_to_write()

    TrainConfigRW(target_dir).save(self.out_config)
    self.saver.save(self.sess, os.path.join(target_dir, 'model.ckpt'))

    # Notify only after both the config and the weights are on disk.
    ckpt_saver.saved(is_best, opt_data)
def _load_train_config(self, input_size=713):
    """Load the config saved at training time and derive inference settings.

    Reads the training config from the task's model directory
    (``sly.TaskPaths(determine_in_project=False).model_dir``) and populates:

    * ``self.train_config`` -- the loaded config object;
    * ``self.input_size_wh`` -- ``(input_size, input_size)``; the config is
      NOT consulted for the input size in this variant;
    * ``self.class_title_to_idx`` -- the config's ``class_title_to_idx``;
    * ``self.train_classes`` -- ``sly.FigClasses`` built from ``out_classes``;
    * ``self.out_class_mapping`` -- title -> index for every class in
      ``self.train_classes``.

    Args:
        input_size: value used for both width and height of
            ``self.input_size_wh``. Defaults to 713, the value previously
            hard-coded here, so existing callers see no change.

    Raises:
        RuntimeError: if no training config exists in the model directory.
        KeyError: if an output class title is missing from
            ``class_title_to_idx``.
    """
    model_dir = sly.TaskPaths(determine_in_project=False).model_dir
    train_config_rw = TrainConfigRW(model_dir)
    if not train_config_rw.train_config_exists:
        raise RuntimeError('Unable to run inference, config from training wasn\'t found.')
    self.train_config = train_config_rw.load()

    # Square input; formerly a fixed 713 (see @TODO in history), now
    # overridable by the caller.
    self.input_size_wh = (input_size, input_size)
    logger.info('Model input size is read (for auto-rescale).', extra={'input_size': {
        'width': self.input_size_wh[0], 'height': self.input_size_wh[1]
    }})

    self.class_title_to_idx = self.train_config['class_title_to_idx']
    self.train_classes = sly.FigClasses(self.train_config['out_classes'])
    logger.info('Read model internal class mapping', extra={'class_mapping': self.class_title_to_idx})
    logger.info('Read model out classes', extra={'classes': self.train_classes.py_container})

    # Distinct names instead of reusing ``x`` for both the class object and
    # its title, which the previous nested comprehension did.
    self.out_class_mapping = {
        title: self.class_title_to_idx[title]
        for title in (fig_class['title'] for fig_class in self.train_classes)
    }