Exemplo n.º 1
0
def build_stages(models_info, preprocessors_config, launcher, model_args):
    """Build the ordered list of MTCNN pipeline stages from configuration.

    Stages are considered in the fixed order ``pnet`` -> ``rnet`` -> ``onet``.
    ``pnet`` is mandatory; ``rnet`` and ``onet`` are built only when they
    appear in ``models_info``.

    Args:
        models_info: mapping from stage name (``'pnet'``, ``'rnet'``,
            ``'onet'``) to that stage's config dict. A stage config may be
            updated in place: its ``'model'`` key is filled from
            ``model_args`` when neither ``'model'`` nor ``'caffe_model'`` is
            configured.
        preprocessors_config: preprocessing steps shared by all stages. They
            are appended after each stage's own ``'preprocessing'`` steps.
        launcher: launcher object. ``launcher.config['framework']`` selects
            the stage implementation, and the launcher is passed to every
            constructed stage.
        model_args: optional sequence of model paths. A single entry is used
            for every stage lacking a model. With several entries, the entry
            indexed by the number of stages already built is used.

    Returns:
        list of constructed stage objects, in pipeline order.

    Raises:
        ConfigError: if ``pnet`` is missing, if a stage has no implementation
            for the selected framework, or if no stages were built.
    """
    def merge_preprocessing(model_specific, common_preprocessing):
        # Return a new list rather than extending ``model_specific`` in place.
        # That list is owned by ``models_info``, so mutating it would make the
        # common steps pile up every time stages are built from one config.
        if model_specific:
            return list(model_specific) + list(common_preprocessing or [])
        return common_preprocessing

    required_stages = ['pnet']
    stages_mapping = OrderedDict([('pnet', {
        'caffe': CaffeProposalStage,
        'dlsdk': DLSDKProposalStage,
        'dummy': DummyProposalStage
    }), ('rnet', {
        'caffe': CaffeRefineStage,
        'dlsdk': DLSDKRefineStage
    }), ('onet', {
        'caffe': CaffeOutputStage,
        'dlsdk': DLSDKOutputStage
    })])
    framework = launcher.config['framework']
    stages = []
    for stage_name, stage_classes in stages_mapping.items():
        if stage_name not in models_info:
            if stage_name not in required_stages:
                continue
            raise ConfigError(
                '{} required for evaluation'.format(stage_name))
        model_config = models_info[stage_name]

        # Precomputed predictions (when not asked to store them) are served
        # by the 'dummy' implementation instead of running a real model.
        if 'predictions' in model_config and not model_config.get(
                'store_predictions', False):
            stage_framework = 'dummy'
        else:
            stage_framework = framework

        # Fill in a model path from the command-line arguments if the stage
        # config does not define one. One argument is shared by all stages;
        # otherwise the argument is picked by the index of the stage being
        # built.
        if not contains_any(model_config, ['model', 'caffe_model'
                                           ]) and stage_framework != 'dummy':
            if model_args:
                model_config['model'] = model_args[
                    len(stages) if len(model_args) > 1 else 0]

        # Check the looked-up class, not the mapping. The mapping is never
        # empty, so checking it would let ``None`` through to the call below.
        stage_class = stage_classes.get(stage_framework)
        if stage_class is None:
            raise ConfigError('{} stage does not support {} framework'.format(
                stage_name, stage_framework))

        stage_preprocess = merge_preprocessing(
            model_config.get('preprocessing', []),
            preprocessors_config)
        preprocessor = PreprocessingExecutor(stage_preprocess)
        stages.append(stage_class(model_config, preprocessor, launcher))

    # Defensive: 'pnet' is required, so this is reached only if the stage
    # tables above change.
    if not stages:
        raise ConfigError(
            'please provide information about MTCNN pipeline stages')

    return stages
Exemplo n.º 2
0
def test_contains_any():
    """contains_any is truthy iff at least one candidate occurs in the container."""
    # (container, candidates, expected truthiness of contains_any)
    cases = [
        ([1, 2, 3], [1], True),          # single shared element
        ([1, 2, 3], [4, 5, 2], True),    # shared element not in first position
        ([1, 2, 3], [4, 5], False),      # no overlap at all
    ]
    for container, candidates, expected in cases:
        assert bool(contains_any(container, candidates)) is expected