Esempio n. 1
0
    def _determine_model_configuration(self):
        """Resolve the architecture's model configuration and reconcile the training config with it.

        Sets ``self.model_configuration`` from the model config found at
        ``sly.TaskPaths.MODEL_CONFIG_PATH``. If the (old-style) training config still carries a
        ``model_configuration`` entry, that entry must equal the resolved value; it is then removed
        from ``self.config`` so it is not written out into the new-style model config.

        Raises:
            RuntimeError: if the training config's ``model_configuration`` differs from the resolved
                model configuration (logged as critical before raising).
        """
        self.model_configuration = ObjectDetectionTrainer._determine_architecture_model_configuration(
            sly.TaskPaths.MODEL_CONFIG_PATH)

        # Old-style configs may still contain model_configuration. Nothing to reconcile if absent.
        training_model_configuration = self.config.get('model_configuration', None)
        if training_model_configuration is None:
            return

        if training_model_configuration != self.model_configuration:
            error_msg = (
                'Unable to start training. model_configuration in the training config is not consistent with '
                'selected model architecture. Make sure you have selected the right model plugin and remove '
                'model_configuration from the training config as it is not required anymore.'
            )
            logger.critical(error_msg,
                            extra={
                                'training_model_configuration': training_model_configuration,
                                'model_configuration': self.model_configuration,
                            })
            raise RuntimeError(error_msg)

        # Consistent with the actual model config: drop the redundant key before the new-style
        # model config is written.
        del self.config['model_configuration']
Esempio n. 2
0
def infer_training_class_to_idx_map(weights_init_type, in_project_class_to_idx, model_config_fpath,
                                    class_to_idx_config_key, special_class_ids=None):
    """Choose the class title -> integer index mapping to train with.

    Args:
        weights_init_type: ``TRANSFER_LEARNING`` to use ``in_project_class_to_idx`` as-is, or
            ``CONTINUE_TRAINING`` to reuse the mapping stored in the previous model's config.
        in_project_class_to_idx: dict mapping class titles of the input project to indices.
        model_config_fpath: path to the previous model's config (only read when continuing training).
        class_to_idx_config_key: key under which that config stores the class mapping.
        special_class_ids: optional dict of class title -> index that the effective mapping must match.

    Returns:
        The effective class title -> index mapping.

    Raises:
        RuntimeError: for an unknown ``weights_init_type``, or when a class listed in
            ``special_class_ids`` got a different index than requested. Errors from
            ``read_validate_model_class_to_idx_map`` propagate unchanged.
        KeyError: if a title in ``special_class_ids`` is absent from the effective mapping.
    """
    if weights_init_type == TRANSFER_LEARNING:
        logger.info('Transfer learning mode, using a class mapping created from scratch.')
        class_title_to_idx = in_project_class_to_idx
    elif weights_init_type == CONTINUE_TRAINING:
        logger.info('Continued training mode, reusing the existing class mapping from the model.')
        class_title_to_idx = read_validate_model_class_to_idx_map(
            model_config_fpath=model_config_fpath,
            in_project_classes_set=set(in_project_class_to_idx.keys()),
            class_to_idx_config_key=class_to_idx_config_key)
    else:
        raise RuntimeError('Unknown weights init type: {}'.format(weights_init_type))

    if special_class_ids is not None:
        for class_title, requested_class_id in special_class_ids.items():
            effective_class_id = class_title_to_idx[class_title]
            if requested_class_id != effective_class_id:
                # The parentheses make .format() apply to the whole concatenated message, so all
                # three placeholders are filled in order.
                error_msg = ('Unable to start training. Effective integer id for class {} does not match the '
                             'requested value in the training config ({} vs {}).').format(
                                 class_title, effective_class_id, requested_class_id)
                logger.critical(error_msg, extra={'class_title_to_idx': class_title_to_idx,
                                                  'special_class_ids': special_class_ids})
                raise RuntimeError(error_msg)
    return class_title_to_idx
Esempio n. 3
0
def create_task(task_msg, docker_api):
    """Build the task object described by an incoming task message.

    The class is looked up in ``_task_class_mapping`` by ``task_msg['task_type']`` and constructed
    with ``task_msg``. If that class is a ``TaskDockerized`` subclass, the new instance also gets
    ``docker_api`` assigned to its ``docker_api`` attribute.

    Raises:
        RuntimeError: if no class is registered for the task type (logged as critical first).
        KeyError: if ``task_msg`` has no ``'task_type'`` entry.
    """
    task_type = task_msg['task_type']
    task_cls = _task_class_mapping.get(task_type)
    if task_cls is None:
        logger.critical('unknown task type', extra={'task_msg': task_msg})
        raise RuntimeError('unknown task type')

    task = task_cls(task_msg)
    # Only dockerized tasks receive the docker_api handle.
    if issubclass(task_cls, TaskDockerized):
        task.docker_api = docker_api
    return task
Esempio n. 4
0
def read_validate_model_class_to_idx_map(model_config_fpath, in_project_classes_set, class_to_idx_config_key):
    """Read the class title -> int index mapping from a model config and check it against the input classes.

    Args:
        model_config_fpath: path to the JSON model config of a previous training run.
        in_project_classes_set: set of class titles present in the input project.
        class_to_idx_config_key: key in the model config under which the mapping is stored.

    Returns:
        A shallow copy of the stored class title -> index mapping.

    Raises:
        RuntimeError: if the config file does not exist, has no mapping under
            ``class_to_idx_config_key``, or the mapping's classes differ from
            ``in_project_classes_set`` (the last case is also logged as critical).
    """
    if not fs.file_exists(model_config_fpath):
        raise RuntimeError('Unable to continue_training, config for previous training wasn\'t found.')

    # JSON text is UTF-8; read it as such instead of relying on the platform's locale encoding,
    # which can fail or mis-decode non-ASCII class titles.
    with open(model_config_fpath, encoding='utf-8') as fin:
        model_config = json.load(fin)

    model_class_mapping = model_config.get(class_to_idx_config_key, None)
    if model_class_mapping is None:
        raise RuntimeError('Unable to continue_training, model does not have class mapping information.')
    model_classes_set = set(model_class_mapping.keys())

    # Continued training requires exactly the same set of classes as the stored model.
    if model_classes_set != in_project_classes_set:
        error_message_text = 'Unable to continue_training, sets of classes for model and dataset do not match.'
        logger.critical(
            error_message_text, extra={'model_classes': model_classes_set, 'dataset_classes': in_project_classes_set})
        raise RuntimeError(error_message_text)
    # Copy so callers cannot mutate the freshly loaded config through the returned mapping.
    return model_class_mapping.copy()