Example #1
def test_explainer(n_explainer_runs, at_defaults, rf_classifier, explainer,
                   test_instance_idx):
    """
    Convergence test on Adult and Iris datasets.
    """

    # fixture returns a fitted AnchorTabular explainer
    X_test, explainer, predict_fn, predict_type = explainer
    if predict_type == 'proba':
        instance_label = np.argmax(
            predict_fn(X_test[test_instance_idx, :].reshape(1, -1)), axis=1)
    else:
        instance_label = predict_fn(
            X_test[test_instance_idx, :].reshape(1, -1))[0]

    explainer.instance_label = instance_label

    explain_defaults = at_defaults
    threshold = explain_defaults['desired_confidence']
    n_covered_ex = explain_defaults['n_covered_ex']

    for _ in range(n_explainer_runs):
        explanation = explainer.explain(X_test[test_instance_idx],
                                        threshold=threshold,
                                        **explain_defaults)
        assert explainer.instance_label == instance_label
        assert explanation.precision >= threshold
        assert explanation.coverage >= 0.05
        assert explanation.meta.keys() == DEFAULT_META_ANCHOR.keys()
        assert explanation.data.keys() == DEFAULT_DATA_ANCHOR.keys()

    sampler = explainer.samplers[0]
    assert sampler.instance_label == instance_label
    assert sampler.n_covered_ex == n_covered_ex
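For reference, a minimal standalone sketch of the kind of setup the explainer fixture provides: an AnchorTabular explainer fitted on Iris with a scikit-learn random forest. The dataset, model and parameter values below are illustrative assumptions, not the suite's actual conftest.

# Sketch only (assumed setup): fit AnchorTabular and explain one test instance.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from alibi.explainers import AnchorTabular

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, random_state=0)

clf = RandomForestClassifier(n_estimators=50).fit(X_train, y_train)
predict_fn = clf.predict_proba  # 'proba'-type predictor, as in the test above

explainer = AnchorTabular(predict_fn, feature_names=data.feature_names)
explainer.fit(X_train, disc_perc=(25, 50, 75))

explanation = explainer.explain(X_test[0], threshold=0.95)
print(explanation.anchor, explanation.precision, explanation.coverage)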
Example #2
def test_explainer(n_explainer_runs, at_defaults, rf_classifier, explainer,
                   test_instance_idx, caplog):
    """
    Convergence test on Adult and Iris datasets.
    """

    # fixture returns a fitted AnchorTabular explainer
    X_test, explainer, predict_fn, predict_type = explainer
    if predict_type == 'proba':
        instance_label = np.argmax(
            predict_fn(X_test[test_instance_idx, :].reshape(1, -1)), axis=1)
    else:
        instance_label = predict_fn(
            X_test[test_instance_idx, :].reshape(1, -1))[0]

    explainer.instance_label = instance_label

    explain_defaults = at_defaults
    threshold = explain_defaults['desired_confidence']
    n_covered_ex = explain_defaults['n_covered_ex']

    run_precisions = []
    for _ in range(n_explainer_runs):
        explanation = explainer.explain(X_test[test_instance_idx],
                                        threshold=threshold,
                                        **explain_defaults)
        assert explainer.instance_label == instance_label
        if not "Could not find" in caplog.text:
            assert explanation.precision >= threshold
        assert explanation.coverage >= 0.01
        assert explanation.meta.keys() == DEFAULT_META_ANCHOR.keys()
        assert explanation.data.keys() == DEFAULT_DATA_ANCHOR.keys()
        run_precisions.append(explanation.precision)

    # check that 80% of runs returned a valid anchor
    assert (np.asarray(run_precisions) > threshold).sum() / n_explainer_runs >= 0.80

    sampler = explainer.samplers[0]
    assert sampler.instance_label == instance_label
    assert sampler.n_covered_ex == n_covered_ex
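The extra caplog argument is pytest's built-in log-capture fixture; at_defaults, rf_classifier, explainer and n_explainer_runs come from the suite's conftest. A hedged sketch of what such a conftest fixture might look like (the values are assumptions, not the real defaults):

# Sketch only: hypothetical conftest fixtures mirroring the names used above.
import pytest

@pytest.fixture
def at_defaults():
    # assumed values; the real defaults live in the test suite's conftest
    return {'desired_confidence': 0.9, 'n_covered_ex': 10}

@pytest.fixture
def n_explainer_runs():
    # assumed number of repeated explain() calls per test
    return 10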
Example #3
def test_anchor_text(lr_classifier, text, n_punctuation_marks, n_unique_words,
                     predict_type, anchor, use_similarity_proba, use_unk,
                     threshold):
    # test parameters
    num_samples = 100
    sample_proba = .5
    top_n = 500
    temperature = 1.
    n_covered_ex = 5  # number of covered examples (where the anchor applies) to return

    # fit and initialise predictor
    clf, preprocessor = lr_classifier
    predictor = predict_fcn(predict_type, clf, preproc=preprocessor)

    # test explainer initialization
    explainer = AnchorText(nlp, predictor)
    assert explainer.predictor(['book']).shape == (1, )

    # setup explainer
    perturb_opts = {
        'use_similarity_proba': use_similarity_proba,
        'sample_proba': sample_proba,
        'temperature': temperature,
    }
    explainer.n_covered_ex = n_covered_ex
    explainer.set_words_and_pos(text)
    explainer.set_sampler_perturbation(use_unk, perturb_opts, top_n)
    explainer.set_data_type(use_unk)
    if predict_type == 'proba':
        label = np.argmax(predictor([text])[0])
    elif predict_type == 'class':
        label = predictor([text])[0]
    explainer.instance_label = label

    assert isinstance(explainer.dtype, str)
    assert len(explainer.punctuation) == n_punctuation_marks
    assert len(explainer.words) == len(explainer.positions)

    # test sampler
    cov_true, cov_false, labels, data, coverage, _ = explainer.sampler(
        (0, anchor), num_samples)
    if not anchor:
        assert coverage == -1
    # check that the words in the proposed anchor are present in every sample
    if use_similarity_proba and len(anchor) > 0:
        assert len(anchor) * data.shape[0] == data[:, anchor].sum()

    if use_unk:
        # get list of unique words
        all_words = explainer.words
        # unique words = words in text + UNK
        assert len(np.unique(all_words)) == n_unique_words

    # test explanation
    explanation = explainer.explain(
        text,
        use_unk=use_unk,
        threshold=threshold,
        use_similarity_proba=use_similarity_proba,
    )
    assert explanation.precision >= threshold
    assert explanation.raw['prediction'].item() == label
    assert explanation.meta.keys() == DEFAULT_META_ANCHOR.keys()
    assert explanation.data.keys() == DEFAULT_DATA_ANCHOR.keys()

    # check if sampled sentences are not cut short
    keys = ['covered_true', 'covered_false']
    for i in range(len(explanation.raw['feature'])):
        example_dict = explanation.raw['examples'][i]
        for k in keys:
            for example in example_dict[k]:
                # check that we have perturbed the sentences
                if use_unk:
                    assert ('UNK' in example
                            or example.replace(' ', '') == text.replace(' ', ''))
                else:
                    assert 'UNK' not in example
                assert example[-1] in ['.', 'K']
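Outside the test harness the same flow reduces to constructing the explainer and calling explain. Below is a sketch against the older AnchorText API exercised above (nlp and predictor passed positionally); the spaCy model name and the toy stand-in predictor are assumptions for illustration.

# Sketch only: end-to-end AnchorText usage with a stand-in classifier.
import numpy as np
import spacy
from alibi.explainers import AnchorText

nlp = spacy.load('en_core_web_md')  # assumed spaCy model with word vectors

def predictor(texts):
    # stand-in for a fitted sentiment classifier: class 1 iff 'good' appears
    return np.array([int('good' in t) for t in texts])

explainer = AnchorText(nlp, predictor)
explanation = explainer.explain(
    'This is a good book .',
    threshold=0.95,
    use_unk=True,                # perturb by replacing words with 'UNK'
    use_similarity_proba=False,  # alternatively, sample similar words via word vectors
)
print(explanation.anchor, explanation.precision)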
Example #4
def test_anchor_image(conv_net):

    segmentation_fn = 'slic'
    segmentation_kwargs = {'n_segments': 10, 'compactness': 10, 'sigma': .5}
    image_shape = (28, 28, 1)
    p_sample = 0.5  # probability of perturbing a superpixel
    num_samples = 10
    # img scaling settings
    scaling_offset = 260
    min_val = 0
    max_val = 255
    eps = 0.0001  # tolerance for tensor comparisons
    n_covered_ex = 3  # nb of examples where the anchor applies that are saved

    # conv_net fixture returns a fitted model
    clf = conv_net
    predict_fn = clf.predict

    explainer = AnchorImage(
        predict_fn,
        image_shape,
        segmentation_fn=segmentation_fn,
        segmentation_kwargs=segmentation_kwargs,
    )
    # test explainer initialization
    assert explainer.predictor(np.zeros((1,) + image_shape)).shape == (1,)
    assert not explainer.custom_segmentation

    # test sampling and segmentation functions
    image = x_train[0]
    explainer.instance_label = predict_fn(image[np.newaxis, ...])[0]
    explainer.image = image
    explainer.n_covered_ex = n_covered_ex
    explainer.p_sample = p_sample
    segments = explainer.generate_superpixels(image)
    explainer.segments = segments
    image_preproc = explainer._preprocess_img(image)
    explainer.segment_labels = list(np.unique(segments))
    superpixels_mask = explainer._choose_superpixels(num_samples=num_samples)

    # grayscale image should be replicated across channel dim before segmentation
    assert image_preproc.shape[-1] == 3
    for channel in range(image_preproc.shape[-1]):
        assert (image.squeeze() - image_preproc[..., channel] <= eps).all()
    # check superpixels mask
    assert superpixels_mask.shape[0] == num_samples
    assert superpixels_mask.shape[1] == len(list(np.unique(segments)))
    assert (superpixels_mask.sum(axis=1) <= segmentation_kwargs['n_segments']).all()
    assert superpixels_mask.max() <= 1

    cov_true, cov_false, labels, data, coverage, _ = explainer.sampler((0, ()), num_samples)
    assert data.shape[0] == labels.shape[0]
    assert data.shape[1] == len(np.unique(segments))
    assert coverage == -1

    # test explanation
    threshold = .95
    explanation = explainer.explain(image, threshold=threshold)

    if explanation.raw['feature']:
        assert explanation.raw['examples'][-1]['covered_true'].shape[0] <= explainer.n_covered_ex
        assert explanation.raw['examples'][-1]['covered_false'].shape[0] <= explainer.n_covered_ex
    else:
        assert not explanation.raw['examples']
    assert explanation.anchor.shape == image_shape
    assert explanation.precision >= threshold
    assert len(np.unique(explanation.segments)) == len(np.unique(segments))
    assert explanation.meta.keys() == DEFAULT_META_ANCHOR.keys()
    assert explanation.data.keys() == DEFAULT_DATA_ANCHOR_IMG.keys()

    # test scaling
    fake_img = np.random.random(size=image_shape) + scaling_offset
    scaled_img = explainer._scale(fake_img, scale=(min_val, max_val))
    assert (scaled_img <= max_val).all()
    assert (scaled_img >= min_val).all()
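For comparison, a minimal sketch of the same AnchorImage flow outside the test, with a dummy predictor standing in for the trained conv net; the stand-in model and the synthetic input image are assumptions for illustration.

# Sketch only: AnchorImage on a single 28x28 grayscale image.
import numpy as np
from alibi.explainers import AnchorImage

image_shape = (28, 28, 1)

def predict_fn(x):
    # stand-in for a fitted CNN (e.g. clf.predict): scores based on mean intensity
    brightness = x.mean(axis=(1, 2, 3))
    return np.stack([1 - brightness, brightness], axis=1)

explainer = AnchorImage(
    predict_fn,
    image_shape,
    segmentation_fn='slic',
    segmentation_kwargs={'n_segments': 10, 'compactness': 10, 'sigma': .5},
)

image = np.zeros(image_shape, dtype=np.float32)
image[10:18, 10:18, 0] = 1.0  # synthetic stand-in for x_train[0]

explanation = explainer.explain(image, threshold=0.95, p_sample=0.5)
print(explanation.precision)
print(explanation.anchor.shape)  # anchor is returned as an image of the same shape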