def test_lowercase_tokens(self):
        """Compares explanations computed with and without token lowercasing."""
        def _predict_fn(sentences):
            batch_size = len(list(sentences))
            return np.random.uniform(0., 1., [batch_size, 2])

        sentence = 'It is a great movie but it is also somewhat bad .'
        counterfactuals = [
            'It is an ok movie but its also somewhat bad .',
            'It is a terrible movie but it is also somewhat bad .',
            'It is a good movie but it is also somewhat bad .',
            'It was a good movie but it is also somewhat bad .',
            'It was a great film but it is also somewhat bad .',
            'It was a great show but it is bad also somewhat bad .',
            'It was the great movie but it is also somewhat bad .',
            'It was a movie but is somewhat bad .',
            'It was a movie and also it is somewhat bad .',
            'It was a movie but also it is very bad .',
            'It was a great but also it is bad .',
            'There is a good movie but also is somewhat bad .',
            'is a great movie but also it is somewhat bad .',
            'is a great movie but also it is somewhat .',
            'is a great movie also it is somewhat bad .',
            'is a great also it is somewhat .'
        ]

        # With lowercasing enabled, 'It' and 'it' collapse into one feature.
        explanation_lowercase = lemon.explain(sentence,
                                              counterfactuals,
                                              _predict_fn,
                                              class_to_explain=1,
                                              lowercase_tokens=True,
                                              return_model=True)

        # The model should have one coefficient per unique lowercased token.
        num_unique_lowercased = len({t.lower() for t in sentence.split()})
        self.assertLen(explanation_lowercase.model.coef_,
                       num_unique_lowercased)

        # 'It' (index 0) and 'it' (index 6) share a single coefficient, so
        # their importance values must match.
        self.assertEqual(explanation_lowercase.feature_importance[0],
                         explanation_lowercase.feature_importance[6])

        # Without lowercasing, case-distinct tokens remain separate features.
        explanation_not_lowercase = lemon.explain(sentence,
                                                  counterfactuals,
                                                  _predict_fn,
                                                  class_to_explain=1,
                                                  lowercase_tokens=False,
                                                  return_model=True)

        # One coefficient per unique (case-sensitive) token.
        num_unique_cased = len(set(sentence.split()))
        self.assertLen(explanation_not_lowercase.model.coef_,
                       num_unique_cased)

        # 'It' and 'it' now have independent importance values.
        self.assertNotEqual(explanation_not_lowercase.feature_importance[0],
                            explanation_not_lowercase.feature_importance[6])
    def test_explain_returns_explanation_with_intercept(self):
        """Tests if the explanation contains an intercept value."""
        def _predict_fn(sentences):
            batch_size = len(list(sentences))
            return np.random.uniform(0., 1., [batch_size, 2])

        explanation = lemon.explain('Test sentence',
                                    ['Test counterfactual'],
                                    _predict_fn,
                                    class_to_explain=1)
        # A fitted surrogate model should carry a nonzero bias term.
        self.assertNotEqual(explanation.intercept, 0.)
    def test_explain_returns_explanation_with_prediction(self):
        """Tests if the explanation contains a prediction."""
        def _predict_fn(sentences):
            batch_size = len(list(sentences))
            return np.random.uniform(0., 1., [batch_size, 2])

        explanation = lemon.explain('Test sentence',
                                    ['Test counterfactual'],
                                    _predict_fn,
                                    class_to_explain=1,
                                    return_prediction=True)
        # With return_prediction=True the explanation must expose the model's
        # prediction for the original sentence.
        self.assertIsNotNone(explanation.prediction)
    def test_duplicate_tokens(self):
        """Checks the explanation for a sentence with duplicate tokens.

        Duplicate tokens ('it' at positions 0 and 6, 'is' at positions 1
        and 7) should map to the same model coefficient, and therefore get
        identical importance values.
        """
        def _predict_fn(sentences):
            return np.random.uniform(0., 1., [len(list(sentences)), 2])

        sentence = 'it is a great movie but it is also somewhat bad .'
        counterfactuals = [
            'it is an ok movie but its also somewhat bad .',
            'it is a terrible movie but it is also somewhat bad .',
            'it is a good movie but it is also somewhat bad .',
            'it was a good movie but it also somewhat bad .',
            'it was a great film but it is also somewhat bad .',
            'it was a great show but it is bad also somewhat bad .',
            'it was the great movie but it is also somewhat bad .',
            'it was a movie but is somewhat bad .',
            'it was a movie and also it is somewhat bad .',
            'it was a movie but also it is very bad .',
            'it was a great but also it is bad .',
            'There is a good movie but also is somewhat bad .',
            'is a great movie but also it somewhat bad .',
            'is a great movie but also it is somewhat .',
            'is a great movie also it is somewhat bad .',
            'is a great also it is somewhat .'
        ]
        explanation = lemon.explain(sentence,
                                    counterfactuals,
                                    _predict_fn,
                                    class_to_explain=1,
                                    return_model=True)

        # Check that the number of model coefficients is equal to the number of
        # unique tokens in the original sentence.
        tokens = sentence.split()
        unique_tokens = set(tokens)
        self.assertLen(explanation.model.coef_, len(unique_tokens))

        # Check that the importance value for 'it' and 'it' are the same.
        self.assertEqual(explanation.feature_importance[0],
                         explanation.feature_importance[6])

        # Check that the importance value for 'is' and 'is' are the same.
        self.assertEqual(explanation.feature_importance[1],
                         explanation.feature_importance[7])
    def test_explain(self, sentence, counterfactuals, positive_token,
                     negative_token, num_classes, class_to_explain):
        """Tests explaining text classifiers with various output dimensions."""
        def _predict_fn(sentences):
            """Mock prediction function.

            Scores are random except that the class to explain is pushed up
            when the positive token is present and pushed down when the
            negative token is present, so LEMON has a signal to recover.
            """
            predictions = []
            np.random.seed(0)  # Deterministic scores across calls.
            # Renamed loop variable (was 'sentence') so it no longer shadows
            # the outer test parameter of the same name.
            for input_sentence in sentences:
                probs = np.random.uniform(0., 1., num_classes)
                # To check if LEMON finds the right positive/negative correlations.
                if negative_token in input_sentence:
                    probs[class_to_explain] = probs[class_to_explain] - 1.
                if positive_token in input_sentence:
                    probs[class_to_explain] = probs[class_to_explain] + 1.
                predictions.append(probs)

            predictions = np.stack(predictions, axis=0)
            # A single-output classifier uses a sigmoid; multi-class outputs
            # are normalized with a softmax.
            if num_classes == 1:
                return special.expit(predictions)
            else:
                return special.softmax(predictions, axis=-1)

        explanation = lemon.explain(sentence,
                                    counterfactuals,
                                    _predict_fn,
                                    class_to_explain,
                                    tokenizer=str.split)

        # One importance value per token of the original sentence.
        self.assertLen(explanation.feature_importance, len(sentence.split()))

        # The positive word should have the highest attribution score.
        positive_token_idx = sentence.split().index(positive_token)
        self.assertEqual(positive_token_idx,
                         np.argmax(explanation.feature_importance))

        # The negative word should have the lowest attribution score.
        negative_token_idx = sentence.split().index(negative_token)
        self.assertEqual(negative_token_idx,
                         np.argmin(explanation.feature_importance))