def test_adapter_validation_scheme(self):
    """The full Adapter scheme is a dict covering every registered provider."""
    full_scheme = Adapter.validation_scheme()
    assert isinstance(full_scheme, dict)
    assert len(full_scheme) == len(Adapter.providers)
    assert contains_all(full_scheme, Adapter.providers)
    # Per-provider lookup must agree with the matching entry of the full scheme.
    classification_scheme = Adapter.validation_scheme('classification')
    assert set(full_scheme['classification']) == set(classification_scheme)
def __init__(self, network_info, launcher, models_args, is_blob=None):
    """Create the detector and the encoder/decoder recognizers of the pipeline.

    Args:
        network_info: configuration dict; must describe 'detector',
            'recognizer_encoder' and 'recognizer_decoder' sub-networks.
        launcher: launcher used to instantiate each sub-network.
        models_args: positional model paths used to fill in a missing
            'model' entry (the first path is reused when fewer are given).
        is_blob: forwarded as '_model_is_blob' for auto-filled models.

    Raises:
        ConfigError: if any of the three required sections is missing.
    """
    detector = network_info.get('detector', {})
    recognizer_encoder = network_info.get('recognizer_encoder', {})
    recognizer_decoder = network_info.get('recognizer_decoder', {})
    # Guard each auto-fill with `models_args` so an empty model-args list does
    # not raise a bare IndexError here (same pattern as the colorization
    # evaluator's from_configs); a truly missing model is reported later by
    # the sub-network creators with a proper configuration error.
    if 'model' not in detector and models_args:
        detector['model'] = models_args[0]
        detector['_model_is_blob'] = is_blob
    if 'model' not in recognizer_encoder and models_args:
        # Fall back to the first path when only one model path was provided.
        recognizer_encoder['model'] = models_args[1 if len(models_args) > 1 else 0]
        recognizer_encoder['_model_is_blob'] = is_blob
    if 'model' not in recognizer_decoder and models_args:
        recognizer_decoder['model'] = models_args[2 if len(models_args) > 2 else 0]
        recognizer_decoder['_model_is_blob'] = is_blob
    network_info.update({
        'detector': detector,
        'recognizer_encoder': recognizer_encoder,
        'recognizer_decoder': recognizer_decoder
    })
    if not contains_all(network_info, ['detector', 'recognizer_encoder', 'recognizer_decoder']):
        raise ConfigError('network_info should contains detector, encoder and decoder fields')
    self.detector = create_detector(network_info['detector'], launcher)
    self.recognizer_encoder = create_recognizer(network_info['recognizer_encoder'], launcher, 'encoder')
    self.recognizer_decoder = create_recognizer(network_info['recognizer_decoder'], launcher, 'decoder')
    self.recognizer_decoder_inputs = network_info['recognizer_decoder_inputs']
    self.recognizer_decoder_outputs = network_info['recognizer_decoder_outputs']
    self.max_seq_len = int(network_info['max_seq_len'])
    self.adapter = create_adapter(network_info['adapter'])
    self.alphabet = network_info['alphabet']
    self.sos_index = int(network_info['sos_index'])
    self.eos_index = int(network_info['eos_index'])
def test_infer_with_additional_outputs(self, models_dir):
    """Extra requested outputs are exposed without displacing the default blob."""
    model = get_dlsdk_test_model(models_dir, {'outputs': ['fc1', 'fc2']})
    network_outputs = list(model.network.outputs.keys())
    assert contains_all(network_outputs, ['fc1', 'fc2', 'fc3'])
    # The default output blob stays the original final layer.
    assert model.output_blob == 'fc3'
def test_annotation_conversion_validation_scheme(self):
    """The converter scheme is a dict covering every registered provider."""
    full_scheme = BaseFormatConverter.validation_scheme()
    assert isinstance(full_scheme, dict)
    assert len(full_scheme) == len(BaseFormatConverter.providers)
    assert contains_all(full_scheme, BaseFormatConverter.providers)
    # Per-provider lookup must agree with the matching entry of the full scheme.
    imagenet_scheme = BaseFormatConverter.validation_scheme('imagenet')
    assert set(full_scheme['imagenet']) == set(imagenet_scheme)
def test_preprocessing_validation_scheme(self):
    """Preprocessor exposes a list scheme; per-provider lookup yields field dicts."""
    full_scheme = Preprocessor.validation_scheme()
    assert isinstance(full_scheme, list)
    assert len(full_scheme) == len(Preprocessor.providers)
    # A single provider resolves to a dict of typed config fields.
    resize_scheme = Preprocessor.validation_scheme('auto_resize')
    assert isinstance(resize_scheme, dict)
    assert contains_all(resize_scheme, ['type'])
    assert isinstance(resize_scheme['type'], StringField)
def test_metric_validation_scheme(self):
    """Metric exposes a list scheme; the accuracy provider has typed fields."""
    full_scheme = Metric.validation_scheme()
    assert isinstance(full_scheme, list)
    assert len(full_scheme) == len(Metric.providers)
    # A single provider resolves to a dict of typed config fields.
    accuracy_scheme = Metric.validation_scheme('accuracy')
    assert isinstance(accuracy_scheme, dict)
    assert contains_all(accuracy_scheme, ['type', 'top_k'])
    assert isinstance(accuracy_scheme['type'], StringField)
    assert isinstance(accuracy_scheme['top_k'], NumberField)
def test_postprocessing_validation_scheme(self):
    """Postprocessor exposes a list scheme; per-provider lookup yields typed fields."""
    full_scheme = Postprocessor.validation_scheme()
    assert isinstance(full_scheme, list)
    assert len(full_scheme) == len(Postprocessor.providers)
    # A single provider resolves to a dict of typed config fields.
    boxes_scheme = Postprocessor.validation_scheme('resize_prediction_boxes')
    assert isinstance(boxes_scheme, dict)
    assert contains_all(boxes_scheme, ['type', 'rescale'])
    assert isinstance(boxes_scheme['type'], StringField)
    assert isinstance(boxes_scheme['rescale'], BoolField)
def test_infer_with_additional_outputs(self, data_dir, models_dir):
    """Inference still works and classifies correctly when extra outputs are requested."""
    model = get_dlsdk_test_model(models_dir, {'outputs': ['fc1', 'fc2']})
    input_blob = get_image(data_dir / '1.jpg', model.inputs['data'])
    predictions = model.predict(['1.jpg'], [input_blob])
    network_outputs = list(model.network.outputs.keys())
    assert contains_all(network_outputs, ['fc1', 'fc2', 'fc3'])
    # The adapter still reads from the default final layer.
    assert model.adapter.output_blob == 'fc3'
    assert predictions[0].label == 6
def __init__(self, network_info, launcher):
    """Create the encoder/decoder pair of the pipeline.

    Raises:
        ConfigError: if 'encoder' or 'decoder' is missing from network_info.
    """
    super().__init__(network_info, launcher)
    if not contains_all(network_info, ['encoder', 'decoder']):
        raise ConfigError('network_info should contains encoder and decoder fields')
    encoder_config = network_info['encoder']
    decoder_config = network_info['decoder']
    # Number of frames accumulated before the decoder is invoked.
    self.num_processing_frames = decoder_config.get('num_processing_frames', 16)
    self.processing_frames_buffer = []
    self.encoder = create_encoder(encoder_config, launcher)
    self.decoder = create_decoder(decoder_config, launcher)
    self.store_encoder_predictions = encoder_config.get('store_predictions', False)
    # Accumulator for raw encoder outputs, only kept when storing is enabled.
    self._encoder_predictions = [] if self.store_encoder_predictions else None
def test_common_validation_scheme(self):
    """Top-level scheme holds one 'models' entry with launcher/dataset sub-schemes."""
    scheme = ModelEvaluator.validation_scheme()
    assert isinstance(scheme, dict)
    assert len(scheme) == 1
    assert 'models' in scheme
    models_entry = scheme['models']
    assert len(models_entry) == 1
    model_scheme = models_entry[0]
    assert contains_all(model_scheme, ['name', 'launchers', 'datasets'])
    assert isinstance(model_scheme['name'], StringField)
    # Launchers/datasets entries reference the provider classes themselves...
    assert model_scheme['launchers'].__name__ == Launcher.__name__
    assert model_scheme['datasets'].__name__ == Dataset.__name__
    # ...whose own schemes are lists.
    assert isinstance(model_scheme['launchers'].validation_scheme(), list)
    assert isinstance(model_scheme['datasets'].validation_scheme(), list)
def from_configs(cls, config):
    """Build the colorization evaluator and its sub-models from a parsed config.

    Args:
        config: top-level configuration with 'datasets', 'launchers' and
            optional 'network_info' / '_models' / '_model_is_blob' entries.

    Returns:
        An instance of *cls* wired with dataset, reader, preprocessing,
        metrics executor, launcher and the test/check colorization models.

    Raises:
        ConfigError: on an unsupported reader type, an unsupported launcher
            framework, or missing network configuration.
    """
    dataset_config = config['datasets'][0]
    dataset = Dataset(dataset_config)
    data_reader_config = dataset_config.get('reader', 'opencv_imread')
    data_source = dataset_config['data_source']
    # The reader may be given as a bare provider name or a full config dict.
    if isinstance(data_reader_config, str):
        reader = BaseReader.provide(data_reader_config, data_source)
    elif isinstance(data_reader_config, dict):
        reader = BaseReader.provide(data_reader_config['type'], data_source, data_reader_config)
    else:
        raise ConfigError('reader should be dict or string')
    preprocessing = PreprocessingExecutor(dataset_config.get('preprocessing', []), dataset.name)
    metrics_executor = MetricsExecutor(dataset_config['metrics'], dataset)
    launcher_settings = config['launchers'][0]
    supported_frameworks = ['dlsdk']
    # Idiomatic membership test (was `not x in seq`, flagged by PEP 8 / linters).
    if launcher_settings['framework'] not in supported_frameworks:
        raise ConfigError('{} framework not supported'.format(launcher_settings['framework']))
    # Models are loaded per sub-network below, so delay launcher model loading.
    launcher = create_launcher(launcher_settings, delayed_model_loading=True)
    network_info = config.get('network_info', {})
    colorization_network = network_info.get('colorization_network', {})
    verification_network = network_info.get('verification_network', {})
    model_args = config.get('_models', [])
    models_is_blob = config.get('_model_is_blob')
    # Fill missing model paths from the positional CLI arguments; reuse the
    # first path when only one was supplied.
    if 'model' not in colorization_network and model_args:
        colorization_network['model'] = model_args[0]
        colorization_network['_model_is_blob'] = models_is_blob
    if 'model' not in verification_network and model_args:
        verification_network['model'] = model_args[1 if len(model_args) > 1 else 0]
        verification_network['_model_is_blob'] = models_is_blob
    network_info.update({
        'colorization_network': colorization_network,
        'verification_network': verification_network
    })
    if not contains_all(network_info, ['colorization_network', 'verification_network']):
        raise ConfigError('configuration for colorization_network/verification_network does not exist')
    test_model = ColorizationTestModel(network_info['colorization_network'], launcher)
    check_model = ColorizationCheckModel(network_info['verification_network'], launcher)
    return cls(dataset, reader, preprocessing, metrics_executor, launcher, test_model, check_model)
def __init__(self, network_info, launcher):
    """Create detector, recognizer encoder/decoder and decoding settings.

    Raises:
        ConfigError: if any of the three required network sections is missing.
    """
    super().__init__(network_info, launcher)
    required_sections = ['detector', 'recognizer_encoder', 'recognizer_decoder']
    if not contains_all(network_info, required_sections):
        raise ConfigError('network_info should contains detector, encoder and decoder fields')
    self.detector = create_detector(network_info['detector'], launcher)
    self.recognizer_encoder = create_recognizer(network_info['recognizer_encoder'], launcher)
    self.recognizer_decoder = create_recognizer(network_info['recognizer_decoder'], launcher)
    self.recognizer_decoder_inputs = network_info['recognizer_decoder_inputs']
    self.recognizer_decoder_outputs = network_info['recognizer_decoder_outputs']
    self.max_seq_len = int(network_info['max_seq_len'])
    self.adapter = create_adapter(network_info['adapter'])
    self.alphabet = network_info['alphabet']
    # Start-/end-of-sequence token indices used during decoding.
    self.sos_index = int(network_info['sos_index'])
    self.eos_index = int(network_info['eos_index'])
def test_dataset_validation_scheme(self):
    """Dataset scheme lists each public parameter with its expected field type."""
    scheme = Dataset.validation_scheme()
    assert isinstance(scheme, list)
    assert len(scheme) == 1
    entry = scheme[0]
    public_params = [key for key in Dataset.parameters() if not key.startswith('_')]
    assert isinstance(entry, dict)
    assert contains_all(entry, public_params)
    # The scheme has exactly the public parameters, nothing more.
    assert len(entry) == len(public_params)
    expected_field_types = {
        'name': StringField,
        'annotation': PathField,
        'data_source': PathField,
        'dataset_meta': PathField,
        'subsample_size': BaseField,
        'shuffle': BoolField,
        'subsample_seed': NumberField,
        'analyze_dataset': BoolField,
        'segmentation_masks_source': PathField,
        'additional_data_source': PathField,
        'batch': NumberField,
    }
    for param, field_type in expected_field_types.items():
        assert isinstance(entry[param], field_type)
    # Nested sections reference their provider classes directly.
    expected_providers = {
        'reader': BaseReader,
        'preprocessing': Preprocessor,
        'postprocessing': Postprocessor,
        'metrics': Metric,
        'annotation_conversion': BaseFormatConverter,
    }
    for param, provider in expected_providers.items():
        assert entry[param] == provider
def test_contains_all():
    """contains_all checks every candidate sequence against the container."""
    container = [1, 2, 3]
    assert contains_all(container, [1, 2])
    # Several sequences may be passed at once; all must be contained.
    assert contains_all(container, [1, 2], [3])
    assert not contains_all(container, [1, 5])