def test_infer_with_additional_outputs(self, models_dir):
    """Requesting extra outputs adds them to the network while 'fc3' stays the output blob."""
    model = get_dlsdk_test_model(models_dir, {'outputs': ['fc1', 'fc2']})

    # The requested outputs plus the model's own output 'fc3' must all be present.
    expected_outputs = ['fc1', 'fc2', 'fc3']
    network_outputs = list(model.network.outputs.keys())
    assert contains_all(network_outputs, expected_outputs)

    # Adding extra outputs must not change which blob is reported as the output.
    assert model.output_blob == 'fc3'
def test_dataset_validation_scheme(self):
    """Dataset's scheme is a single dict covering every public parameter with the right field/provider."""
    full_scheme = Dataset.validation_scheme()
    assert isinstance(full_scheme, list)
    assert len(full_scheme) == 1

    # Parameters whose names start with '_' are private and excluded from the scheme.
    public_params = [param for param in Dataset.parameters() if not param.startswith('_')]
    scheme = full_scheme[0]
    assert isinstance(scheme, dict)
    assert contains_all(scheme, public_params)
    assert len(scheme) == len(public_params)

    # Plain configuration entries are described by field instances of these types.
    expected_field_types = {
        'name': StringField,
        'annotation': PathField,
        'data_source': PathField,
        'dataset_meta': PathField,
        'subsample_size': BaseField,
        'shuffle': BoolField,
        'subsample_seed': NumberField,
        'analyze_dataset': BoolField,
        'segmentation_masks_source': PathField,
        'additional_data_source': PathField,
        'batch': NumberField,
    }
    for param, field_type in expected_field_types.items():
        assert isinstance(scheme[param], field_type)

    # Pluggable sections are described by their provider base class itself.
    expected_providers = {
        'reader': BaseReader,
        'preprocessing': Preprocessor,
        'postprocessing': Postprocessor,
        'metrics': Metric,
        'annotation_conversion': BaseFormatConverter,
    }
    for param, provider in expected_providers.items():
        assert scheme[param] == provider
def test_preprocessing_validation_scheme(self):
    """Preprocessor lists one scheme per provider; 'auto_resize' exposes a string 'type' field."""
    every_scheme = Preprocessor.validation_scheme()
    assert isinstance(every_scheme, list)
    assert len(every_scheme) == len(Preprocessor.providers)

    # Requesting a single provider by name returns that provider's own dict scheme.
    auto_resize = Preprocessor.validation_scheme('auto_resize')
    assert isinstance(auto_resize, dict)
    assert contains_all(auto_resize, ['type'])
    assert isinstance(auto_resize['type'], StringField)
def test_postprocessing_validation_scheme(self):
    """Postprocessor lists one scheme per provider; 'resize_prediction_boxes' has typed fields."""
    every_scheme = Postprocessor.validation_scheme()
    assert isinstance(every_scheme, list)
    assert len(every_scheme) == len(Postprocessor.providers)

    resize_scheme = Postprocessor.validation_scheme('resize_prediction_boxes')
    # Expected field name -> field class; the keys also drive the presence check.
    expected_fields = {'type': StringField, 'rescale': BoolField}
    assert isinstance(resize_scheme, dict)
    assert contains_all(resize_scheme, list(expected_fields))
    for field_name, field_cls in expected_fields.items():
        assert isinstance(resize_scheme[field_name], field_cls)
def test_metric_validation_scheme(self):
    """Metric lists one scheme per provider; 'accuracy' exposes typed 'type' and 'top_k' fields."""
    every_scheme = Metric.validation_scheme()
    assert isinstance(every_scheme, list)
    assert len(every_scheme) == len(Metric.providers)

    accuracy_scheme = Metric.validation_scheme('accuracy')
    # Expected field name -> field class; the keys also drive the presence check.
    expected_fields = {'type': StringField, 'top_k': NumberField}
    assert isinstance(accuracy_scheme, dict)
    assert contains_all(accuracy_scheme, list(expected_fields))
    for field_name, field_cls in expected_fields.items():
        assert isinstance(accuracy_scheme[field_name], field_cls)
def test_common_validation_scheme(self):
    """ModelEvaluator's scheme has a single 'models' entry pointing at launcher/dataset providers."""
    top_level = ModelEvaluator.validation_scheme()
    assert isinstance(top_level, dict)
    assert len(top_level) == 1
    assert 'models' in top_level

    models = top_level['models']
    assert len(models) == 1
    model_scheme = models[0]
    assert contains_all(model_scheme, ['name', 'launchers', 'datasets'])
    assert isinstance(model_scheme['name'], StringField)

    # Nested sections hold provider classes, checked here by class name
    # rather than identity; each must expose its own list-shaped scheme.
    launchers = model_scheme['launchers']
    datasets = model_scheme['datasets']
    assert launchers.__name__ == Launcher.__name__
    assert datasets.__name__ == Dataset.__name__
    assert isinstance(launchers.validation_scheme(), list)
    assert isinstance(datasets.validation_scheme(), list)
def test_contains_all():
    """Exercise contains_all with one subset, several subsets, and a missing element."""
    # Each tuple is the full argument list for one call expected to succeed;
    # the second one passes two candidate collections at once.
    satisfied_calls = (
        ([1, 2, 3], [1, 2]),
        ([1, 2, 3], [1, 2], [3]),
    )
    for call_args in satisfied_calls:
        assert contains_all(*call_args)

    # 5 is absent from the container, so the check must fail.
    assert not contains_all([1, 2, 3], [1, 5])
def test_annotation_conversion_validation_scheme(self):
    """Converter scheme is a dict keyed by every provider, consistent with per-provider lookups."""
    schemes_by_provider = BaseFormatConverter.validation_scheme()
    assert isinstance(schemes_by_provider, dict)
    assert len(schemes_by_provider) == len(BaseFormatConverter.providers)
    assert contains_all(schemes_by_provider, BaseFormatConverter.providers)

    # The 'imagenet' entry in the aggregated scheme and the scheme fetched
    # directly by name must contain the same members (iteration-wise).
    from_aggregate = set(schemes_by_provider['imagenet'])
    from_lookup = set(BaseFormatConverter.validation_scheme('imagenet'))
    assert from_aggregate == from_lookup
def test_adapter_validation_scheme(self):
    """Adapter scheme is a dict keyed by every provider, consistent with per-provider lookups."""
    schemes_by_provider = Adapter.validation_scheme()
    assert isinstance(schemes_by_provider, dict)
    assert len(schemes_by_provider) == len(Adapter.providers)
    assert contains_all(schemes_by_provider, Adapter.providers)

    # The 'classification' entry in the aggregated scheme and the scheme
    # fetched directly by name must contain the same members (iteration-wise).
    from_aggregate = set(schemes_by_provider['classification'])
    from_lookup = set(Adapter.validation_scheme('classification'))
    assert from_aggregate == from_lookup