def test_infer_with_additional_outputs(self, models_dir):
    """Requesting extra outputs adds them to the network's output set.

    The test model is built with ``fc1`` and ``fc2`` listed under
    ``'outputs'``. Afterwards the network outputs must include ``fc1``,
    ``fc2`` and ``fc3``, while ``output_blob`` is still ``'fc3'``.
    """
    extra_outputs_config = {'outputs': ['fc1', 'fc2']}
    launcher = get_dlsdk_test_model(models_dir, extra_outputs_config)

    outputs_map = launcher.network.outputs
    exposed_outputs = list(outputs_map.keys())

    assert contains_all(exposed_outputs, ['fc1', 'fc2', 'fc3'])
    assert launcher.output_blob == 'fc3'
    def test_dataset_validation_scheme(self):
        """Dataset.validation_scheme() covers every public dataset parameter.

        The scheme must be a one-element list wrapping a dict. That dict has
        exactly one entry per public (non-underscore) name from
        ``Dataset.parameters()``. Simple parameters map to field validator
        instances. The reader/processing/metric/converter sections compare
        equal to their respective classes.
        """
        scheme_list = Dataset.validation_scheme()
        assert isinstance(scheme_list, list)
        assert len(scheme_list) == 1

        public_params = [name for name in Dataset.parameters() if not name.startswith('_')]
        scheme = scheme_list[0]
        assert isinstance(scheme, dict)
        assert contains_all(scheme, public_params)
        assert len(scheme) == len(public_params)

        # Parameters validated by a field object: (parameter name, field type).
        field_expectations = (
            ('name', StringField),
            ('annotation', PathField),
            ('data_source', PathField),
            ('dataset_meta', PathField),
            ('subsample_size', BaseField),
            ('shuffle', BoolField),
            ('subsample_seed', NumberField),
            ('analyze_dataset', BoolField),
            ('segmentation_masks_source', PathField),
            ('additional_data_source', PathField),
            ('batch', NumberField),
        )
        for param_name, field_type in field_expectations:
            assert isinstance(scheme[param_name], field_type)

        # Nested sections whose scheme entry compares equal to a class.
        section_expectations = (
            ('reader', BaseReader),
            ('preprocessing', Preprocessor),
            ('postprocessing', Postprocessor),
            ('metrics', Metric),
            ('annotation_conversion', BaseFormatConverter),
        )
        for section_name, section_cls in section_expectations:
            assert scheme[section_name] == section_cls
 def test_preprocessing_validation_scheme(self):
     """Preprocessor schemes: a full list per provider and a per-name lookup.

     Without arguments, ``validation_scheme()`` yields a list with one entry
     per registered preprocessor. For ``'auto_resize'`` it yields a dict
     whose ``'type'`` entry is a StringField.
     """
     every_scheme = Preprocessor.validation_scheme()
     assert isinstance(every_scheme, list)
     assert len(every_scheme) == len(Preprocessor.providers)

     auto_resize = Preprocessor.validation_scheme('auto_resize')
     assert isinstance(auto_resize, dict)
     assert contains_all(auto_resize, ['type'])
     type_field = auto_resize['type']
     assert isinstance(type_field, StringField)
 def test_postprocessing_validation_scheme(self):
     """Postprocessor schemes: a full list per provider and a per-name lookup.

     Without arguments, ``validation_scheme()`` yields a list with one entry
     per registered postprocessor. The ``'resize_prediction_boxes'`` scheme
     is a dict with a StringField ``'type'`` and a BoolField ``'rescale'``.
     """
     every_scheme = Postprocessor.validation_scheme()
     assert isinstance(every_scheme, list)
     assert len(every_scheme) == len(Postprocessor.providers)

     boxes_scheme = Postprocessor.validation_scheme('resize_prediction_boxes')
     assert isinstance(boxes_scheme, dict)
     assert contains_all(boxes_scheme, ['type', 'rescale'])
     for key, expected_field in (('type', StringField), ('rescale', BoolField)):
         assert isinstance(boxes_scheme[key], expected_field)
 def test_metric_validation_scheme(self):
     """Metric schemes: a full list per provider and a per-name lookup.

     Without arguments, ``validation_scheme()`` yields a list with one entry
     per registered metric. The ``'accuracy'`` scheme is a dict with a
     StringField ``'type'`` and a NumberField ``'top_k'``.
     """
     every_scheme = Metric.validation_scheme()
     assert isinstance(every_scheme, list)
     assert len(every_scheme) == len(Metric.providers)

     accuracy = Metric.validation_scheme('accuracy')
     assert isinstance(accuracy, dict)
     assert contains_all(accuracy, ['type', 'top_k'])
     for key, expected_field in (('type', StringField), ('top_k', NumberField)):
         assert isinstance(accuracy[key], expected_field)
 def test_common_validation_scheme(self):
     """ModelEvaluator's top-level scheme wraps a single model description.

     The top-level dict has only a ``'models'`` key, holding one entry. That
     entry has a StringField ``'name'`` and ``'launchers'``/``'datasets'``
     entries. These are checked by ``__name__`` against Launcher/Dataset,
     and their own ``validation_scheme()`` must return a list.
     """
     top_level = ModelEvaluator.validation_scheme()
     assert isinstance(top_level, dict)
     assert len(top_level) == 1
     assert 'models' in top_level

     models_section = top_level['models']
     assert len(models_section) == 1
     model_scheme = models_section[0]
     assert contains_all(model_scheme, ['name', 'launchers', 'datasets'])
     assert isinstance(model_scheme['name'], StringField)

     launchers_entry = model_scheme['launchers']
     assert launchers_entry.__name__ == Launcher.__name__
     datasets_entry = model_scheme['datasets']
     assert datasets_entry.__name__ == Dataset.__name__
     assert isinstance(launchers_entry.validation_scheme(), list)
     assert isinstance(datasets_entry.validation_scheme(), list)
Ejemplo n.º 7
0
def test_contains_all():
    """contains_all is true only when every later argument fits in the first.

    Both ``[1, 2]`` alone and ``[1, 2]`` together with ``[3]`` are contained
    in ``[1, 2, 3]``. ``[1, 5]`` is not, because 5 is missing.
    """
    satisfied_calls = (
        ([1, 2, 3], [1, 2]),
        ([1, 2, 3], [1, 2], [3]),
    )
    for call_args in satisfied_calls:
        assert contains_all(*call_args)

    assert not contains_all([1, 2, 3], [1, 5])
 def test_annotation_conversion_validation_scheme(self):
     """BaseFormatConverter's full scheme has one entry per registered converter.

     The full scheme is a dict whose size matches
     ``BaseFormatConverter.providers`` and which contains every provider. The
     ``'imagenet'`` entry must list the same keys as the scheme returned for
     ``'imagenet'`` directly.
     """
     converter_schemes = BaseFormatConverter.validation_scheme()
     assert isinstance(converter_schemes, dict)
     assert len(converter_schemes) == len(BaseFormatConverter.providers)
     assert contains_all(converter_schemes, BaseFormatConverter.providers)

     imagenet_from_full = set(converter_schemes['imagenet'])
     imagenet_direct = set(BaseFormatConverter.validation_scheme('imagenet'))
     assert imagenet_from_full == imagenet_direct
 def test_adapter_validation_scheme(self):
     """Adapter's full scheme has one entry per registered adapter.

     The full scheme is a dict whose size matches ``Adapter.providers`` and
     which contains every provider. The ``'classification'`` entry must list
     the same keys as the scheme returned for ``'classification'`` directly.
     """
     adapter_schemes = Adapter.validation_scheme()
     assert isinstance(adapter_schemes, dict)
     assert len(adapter_schemes) == len(Adapter.providers)
     assert contains_all(adapter_schemes, Adapter.providers)

     classification_from_full = set(adapter_schemes['classification'])
     classification_direct = set(Adapter.validation_scheme('classification'))
     assert classification_from_full == classification_direct