def _determine_model_configuration(self):
    """Determine the model configuration and reconcile it with a legacy training config.

    Sets ``self.model_configuration`` using
    ``ObjectDetectionTrainer._determine_architecture_model_configuration`` on the file at
    ``sly.TaskPaths.MODEL_CONFIG_PATH``.

    Old-style training configs may still carry a ``model_configuration`` field. If that field
    is present (and not None), it must equal the detected configuration. It is then removed
    from ``self.config`` so it is not carried into the new-style model config that gets
    written later.

    Raises:
        RuntimeError: if ``self.config['model_configuration']`` is present and differs from
            the detected model configuration.
    """
    self.model_configuration = ObjectDetectionTrainer._determine_architecture_model_configuration(
        sly.TaskPaths.MODEL_CONFIG_PATH)

    training_model_configuration = self.config.get('model_configuration', None)
    if training_model_configuration is None:
        # New-style training config: nothing to reconcile.
        return

    if training_model_configuration != self.model_configuration:
        # The message names the actual config key so users know exactly what to remove.
        error_msg = (
            'Unable to start training. model_configuration in the training config is not consistent with '
            'selected model architecture. Make sure you have selected the right model plugin and remove '
            'model_configuration from the training config as it is not required anymore.'
        )
        logger.critical(error_msg, extra={
            'training_model_configuration': training_model_configuration,
            'model_configuration': self.model_configuration,
        })
        raise RuntimeError(error_msg)

    # Consistent legacy field: drop it so it does not leak into the new-style model config.
    del self.config['model_configuration']
def infer_training_class_to_idx_map(weights_init_type, in_project_class_to_idx, model_config_fpath,
                                    class_to_idx_config_key, special_class_ids=None):
    """Choose the class -> integer index mapping to train with.

    Args:
        weights_init_type: ``TRANSFER_LEARNING`` or ``CONTINUE_TRAINING``.
        in_project_class_to_idx: mapping built from the input project's classes. It is used
            directly for transfer learning; for continued training only its keys are used, as
            the expected class set.
        model_config_fpath: path to the previous model's JSON config (continued training only).
        class_to_idx_config_key: key in that config holding the stored class mapping.
        special_class_ids: optional ``{class_title: required_index}``. Every listed class must
            end up with exactly that index.

    Returns:
        The class -> index mapping. In transfer-learning mode this is the
        ``in_project_class_to_idx`` object itself (not a copy). In continued-training mode it
        is a copy of the mapping stored in the model config.

    Raises:
        RuntimeError: on an unknown ``weights_init_type``, if the stored model mapping cannot be
            read or validated, or if a special class's effective index differs from the
            requested one.
        KeyError: if a class listed in ``special_class_ids`` is absent from the chosen mapping.
    """
    if weights_init_type == TRANSFER_LEARNING:
        logger.info('Transfer learning mode, using a class mapping created from scratch.')
        class_title_to_idx = in_project_class_to_idx
    elif weights_init_type == CONTINUE_TRAINING:
        logger.info('Continued training mode, reusing the existing class mapping from the model.')
        class_title_to_idx = read_validate_model_class_to_idx_map(
            model_config_fpath=model_config_fpath,
            in_project_classes_set=set(in_project_class_to_idx.keys()),
            class_to_idx_config_key=class_to_idx_config_key)
    else:
        raise RuntimeError('Unknown weights init type: {}'.format(weights_init_type))

    if special_class_ids is not None:
        for class_title, requested_class_id in special_class_ids.items():
            effective_class_id = class_title_to_idx[class_title]
            if requested_class_id != effective_class_id:
                # Adjacent literals are concatenated at compile time, so .format() applies to the
                # whole message. (With '+', it applied only to the second half, leaving a literal
                # '{}' and shifting the arguments.)
                error_msg = ('Unable to start training. Effective integer id for class {} does not match the '
                             'requested value in the training config ({} vs {}).').format(
                    class_title, effective_class_id, requested_class_id)
                logger.critical(error_msg, extra={'class_title_to_idx': class_title_to_idx,
                                                  'special_class_ids': special_class_ids})
                raise RuntimeError(error_msg)
    return class_title_to_idx
def create_task(task_msg, docker_api):
    """Build the task object that handles ``task_msg``.

    The concrete class is looked up in ``_task_class_mapping`` by ``task_msg['task_type']``
    and constructed with the whole message. If that class is a ``TaskDockerized`` subclass,
    the given ``docker_api`` is attached to the new instance as ``docker_api``.

    Raises:
        KeyError: if ``task_msg`` has no ``'task_type'`` entry.
        RuntimeError: if the task type has no registered class (logged as critical first).
    """
    task_type = task_msg['task_type']
    task_class = _task_class_mapping.get(task_type, None)
    if task_class is None:
        logger.critical('unknown task type', extra={'task_msg': task_msg})
        raise RuntimeError('unknown task type')

    task = task_class(task_msg)
    # Only TaskDockerized subclasses receive the Docker API handle.
    needs_docker = issubclass(task_class, TaskDockerized)
    if needs_docker:
        task.docker_api = docker_api
    return task
def read_validate_model_class_to_idx_map(model_config_fpath, in_project_classes_set, class_to_idx_config_key):
    """Read the class -> int index mapping from a model config and check it matches the input classes.

    Args:
        model_config_fpath: path to the JSON config saved with the previous model.
        in_project_classes_set: set of class titles present in the input project.
        class_to_idx_config_key: top-level key in the config that holds the class mapping.

    Returns:
        A shallow copy of the stored mapping, so callers cannot mutate the loaded config.

    Raises:
        RuntimeError: if the config file does not exist, if it lacks
            ``class_to_idx_config_key``, or if the stored class set differs from
            ``in_project_classes_set``.
    """
    if not fs.file_exists(model_config_fpath):
        raise RuntimeError('Unable to continue_training, config for previous training wasn\'t found.')

    # JSON text is UTF-8 (RFC 8259). Be explicit so decoding does not depend on the platform
    # locale; class titles are user data and may contain non-ASCII characters.
    with open(model_config_fpath, encoding='utf-8') as fin:
        model_config = json.load(fin)

    model_class_mapping = model_config.get(class_to_idx_config_key, None)
    if model_class_mapping is None:
        raise RuntimeError('Unable to continue_training, model does not have class mapping information.')

    model_classes_set = set(model_class_mapping.keys())
    if model_classes_set != in_project_classes_set:
        error_message_text = 'Unable to continue_training, sets of classes for model and dataset do not match.'
        logger.critical(
            error_message_text, extra={'model_classes': model_classes_set, 'dataset_classes': in_project_classes_set})
        raise RuntimeError(error_message_text)
    return model_class_mapping.copy()