def test_text_config():
    """TextConfig should echo its granularity and language via get_text_config()."""
    config = TextConfig(granularity="sentence", language="german")
    assert config.get_text_config() == {
        "granularity": "sentence",
        "language": "german",
    }
def test_invalid_text_config():
    """TextConfig must reject unknown granularity and language values with ValueError."""
    # Unknown granularity is rejected.
    with pytest.raises(ValueError) as error:
        TextConfig(granularity="invalid", language="english")
    assert (
        "Invalid granularity invalid. Please choose among "
        "['token', 'sentence', 'paragraph']" in str(error.value)
    )
    # Unknown language is rejected.
    with pytest.raises(ValueError) as error:
        TextConfig(granularity="token", language="invalid")
    assert "Invalid language invalid. Please choose among ['chinese'," in str(
        error.value
    )
def test_shap_with_text_config(
    name_from_base,
    clarify_processor,
    clarify_processor_with_job_name_prefix,
    data_config,
    model_config,
):
    """Running SHAP explainability with a TextConfig forwards the text settings
    (granularity/language) into the analysis config."""
    shap_config = SHAPConfig(
        baseline=[[0.26124998927116394, 0.2824999988079071, 0.06875000149011612]],
        num_samples=100,
        agg_method="mean_sq",
        text_config=TextConfig("paragraph", "ukrainian"),
    )
    _run_test_explain(
        name_from_base,
        clarify_processor,
        clarify_processor_with_job_name_prefix,
        data_config,
        model_config,
        shap_config,
        None,
        None,
        # Expected predictor config derived from the shared model fixture.
        {
            "model_name": "xgboost-model",
            "instance_type": "ml.c5.xlarge",
            "initial_instance_count": 1,
        },
        expected_text_config={
            "granularity": "paragraph",
            "language": "ukrainian",
        },
    )
def test_shap_config():
    """A fully-populated SHAPConfig (with text and image sub-configs) must
    serialize every field into the explainability config dict."""
    baseline = [[0.26124998927116394, 0.2824999988079071, 0.06875000149011612]]
    # Shared parameter dicts double as construction kwargs and expected output.
    text_params = {
        "granularity": "sentence",
        "language": "german",
    }
    image_params = {
        "model_type": "IMAGE_CLASSIFICATION",
        "num_segments": 2,
        "feature_extraction_method": "segmentation",
        "segment_compactness": 10,
        "max_objects": 4,
        "iou_threshold": 0.5,
        "context": 1.0,
    }
    shap_config = SHAPConfig(
        baseline=baseline,
        num_samples=100,
        agg_method="mean_sq",
        use_logit=True,
        seed=123,
        text_config=TextConfig(**text_params),
        image_config=ImageConfig(**image_params),
    )
    assert shap_config.get_explainability_config() == {
        "shap": {
            "baseline": baseline,
            "num_samples": 100,
            "agg_method": "mean_sq",
            "use_logit": True,
            # save_local_shap_values defaults to True when not supplied.
            "save_local_shap_values": True,
            "seed": 123,
            "text_config": text_params,
            "image_config": image_params,
        }
    }