def test_shap_config():
    """SHAPConfig echoes its explicit arguments under the "shap" key.

    ``save_local_shap_values`` is never passed here, so the serialized
    payload is expected to carry it with the value ``True``.
    """
    shap_kwargs = {
        "baseline": [[
            0.26124998927116394,
            0.2824999988079071,
            0.06875000149011612,
        ]],
        "num_samples": 100,
        "agg_method": "mean_sq",
        "use_logit": True,
        "seed": 123,
    }
    config = SHAPConfig(**shap_kwargs)
    # Expected payload = every constructor argument plus the implicit flag.
    expected = {"shap": dict(shap_kwargs, save_local_shap_values=True)}
    assert expected == config.get_explainability_config()


def test_shap_config_no_parameters():
    """A default-constructed SHAPConfig serializes only its two flag fields."""
    actual = SHAPConfig().get_explainability_config()
    expected = {"shap": {"use_logit": False, "save_local_shap_values": True}}
    assert expected == actual


def test_shap_config_no_baseline():
    """Without a baseline, the payload carries num_clusters and no "baseline" key."""
    shap_kwargs = dict(
        num_samples=100,
        agg_method="mean_sq",
        num_clusters=2,
        use_logit=True,
        seed=123,
    )
    config = SHAPConfig(**shap_kwargs)
    expected = {"shap": {**shap_kwargs, "save_local_shap_values": True}}
    assert expected == config.get_explainability_config()


# Distinct name on purpose: reusing ``test_shap_config`` would rebind the
# module-level name and silently drop the basic SHAPConfig test from collection.
def test_shap_config_with_text_and_image_config():
    """SHAPConfig nests TextConfig and ImageConfig payloads under the "shap" key.

    Each nested config is expected to serialize to exactly the keyword
    arguments it was constructed with, alongside the top-level SHAP fields
    and the implicit ``save_local_shap_values=True``.
    """
    baseline = [[
        0.26124998927116394,
        0.2824999988079071,
        0.06875000149011612,
    ]]
    text_kwargs = {"granularity": "sentence", "language": "german"}
    image_kwargs = {
        "model_type": "IMAGE_CLASSIFICATION",
        "num_segments": 2,
        "feature_extraction_method": "segmentation",
        "segment_compactness": 10,
        "max_objects": 4,
        "iou_threshold": 0.5,
        "context": 1.0,
    }
    shap_kwargs = {
        "baseline": baseline,
        "num_samples": 100,
        "agg_method": "mean_sq",
        "use_logit": True,
        "seed": 123,
    }
    shap_config = SHAPConfig(
        **shap_kwargs,
        text_config=TextConfig(**text_kwargs),
        image_config=ImageConfig(**image_kwargs),
    )
    expected_config = {
        "shap": {
            **shap_kwargs,
            "save_local_shap_values": True,
            "text_config": text_kwargs,
            "image_config": image_kwargs,
        }
    }
    assert expected_config == shap_config.get_explainability_config()