import numpy as np
from unittest import TestCase

# SimpleModel is assumed to be importable from the module under test.


class SimpleModelTest(TestCase):
    def setUp(self):
        # Noise covariances (roles inferred from the tests below): R over the
        # 5-dimensional state, Q for angle corrections, Q_b for beacon
        # range corrections.
        R = np.diag([0.1, 0.1, 0.1, 0.1, 0.1])
        Q = np.array([[0.01]])
        Q_b = np.array([[10]])

        # 0.1 is the timestep between prediction steps.
        self.model = SimpleModel(0.1, R, Q, Q_b)

    def test_accessor(self):
        self.model.mu = [1, 2, 3, 4, 5]

        self.assertEqual(1, self.model.x)
        self.assertEqual(2, self.model.y)
        self.assertEqual(3, self.model.theta)
        self.assertEqual(4, self.model.vx)
        self.assertEqual(5, self.model.vy)

    def test_update(self):
        """
        Smoke test for the predictor.
        """
        self.model.predict(np.array([1, 0]))
        self.model.predict(np.array([1, 0]))
        self.assertAlmostEqual(self.model.x, 0.01, places=4)

    def test_correct_angle(self):
        """
        Smoke test for the angle corrector.
        """
        self.model.correct_angle(1)
        self.assertAlmostEqual(self.model.theta, 0.5, places=3)

    def test_correct_beacon(self):
        """
        Smoke test for the beacon corrector.
        """
        x, y = 1, 1  # ground truth passed to the model for computing beacon pos

        def distance(bx, by):
            return np.sqrt((bx - x) ** 2 + (by - y) ** 2)

        beacons = [
            (3, 3),
            (0, 3),
            (3, 0),
        ]

        for _ in range(100):
            # Each prediction step inflates the covariance so the beacon
            # corrections keep having an effect.
            self.model.predict(np.array([0, 0]))
            for bx, by in beacons:
                self.model.correct_beacon(bx, by, distance(bx, by))

        self.assertAlmostEqual(self.model.x, x, places=2)
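The tests above pin down part of the SimpleModel interface: mu is a five-element state [x, y, theta, vx, vy] exposed through read-only accessors. Below is a minimal sketch of that accessor layer, inferred purely from the tests (the filter math in predict / correct_angle / correct_beacon is omitted and the real implementation may differ):

class SimpleModel:
    """Accessor sketch only; predict / correct_* are left out."""

    def __init__(self, dt, R, Q, Q_b):
        self.dt = dt
        self.R, self.Q, self.Q_b = R, Q, Q_b
        # State layout inferred from test_accessor: [x, y, theta, vx, vy].
        self.mu = [0, 0, 0, 0, 0]

    @property
    def x(self):
        return self.mu[0]

    @property
    def y(self):
        return self.mu[1]

    @property
    def theta(self):
        return self.mu[2]

    @property
    def vx(self):
        return self.mu[3]

    @property
    def vy(self):
        return self.mu[4]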
Example #2
# Assumed setup, following the standard MediaPipe Hands example; `model`
# (a trained gesture classifier), `parse_landmark`, and `action_classes`
# are defined elsewhere in the project.
import cv2 as cv
import mediapipe as mp
import numpy as np

mp_hands = mp.solutions.hands
mp_drawing = mp.solutions.drawing_utils

cap = cv.VideoCapture(0)
with mp_hands.Hands() as hands:
  while cap.isOpened():
    success, image = cap.read()
    if not success:
      continue

    # Flip the image horizontally for a later selfie-view display, and convert
    # the BGR image to RGB.
    image = cv.cvtColor(cv.flip(image, 1), cv.COLOR_BGR2RGB)
    # To improve performance, optionally mark the image as not writeable to
    # pass by reference.
    image.flags.writeable = False
    results = hands.process(image)

    # Draw the hand annotations on the image.
    image.flags.writeable = True
    image = cv.cvtColor(image, cv.COLOR_RGB2BGR)
    # Draw the landmarks on a blank canvas so they are shown separately
    # from the camera frame.
    nimage = np.zeros_like(image)
    if results.multi_hand_landmarks:
      for hand_landmarks in results.multi_hand_landmarks:
        # Classify the hand pose and overlay the predicted action label.
        hand = parse_landmark(hand_landmarks.landmark)
        act = model.predict(hand)[0]
        text = action_classes[act]
        cv.putText(image, text, (50, 50), cv.FONT_HERSHEY_SIMPLEX, 2,
                   (0, 255, 0), 2, cv.LINE_AA)

        mp_drawing.draw_landmarks(
            nimage, hand_landmarks, mp_hands.HAND_CONNECTIONS)
    cv.imshow('MediaPipe Hands', nimage)
    cv.imshow('raw', image)
    # Exit on the Esc key.
    if cv.waitKey(5) & 0xFF == 27:
      break
cap.release()
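parse_landmark is not shown in this snippet. A plausible sketch, assuming it flattens the 21 MediaPipe hand landmarks into a single feature row so that model.predict(hand)[0] yields one class index:

import numpy as np

def parse_landmark(landmarks):
    # Hypothetical helper: flatten the 21 (x, y, z) landmarks into a
    # (1, 63) row vector for a scikit-learn-style classifier.
    return np.array([[coord for lm in landmarks
                      for coord in (lm.x, lm.y, lm.z)]])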
Example #3
# Assumed imports; get_data, get_train_val_test_splits, SimpleModel,
# LimeExplainer, LdaExplainer, and Explainer are project-local helpers.
import os
from random import randint
from string import punctuation

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from lime.lime_text import LimeTextExplainer
from sklearn.metrics import confusion_matrix
from wordcloud import WordCloud


def main(args):
    data = get_data(args.data_path, args.is_new_data)

    plt.clf()
    data["opinion"].apply(len).hist()
    plt.xlabel("document lenght")
    plt.ylabel("frequency")
    plt.tight_layout()
    plt.savefig(os.path.join(args.result_dir, 'data_len_hist.png'))
    translator = str.maketrans('', '', punctuation)

    def unigrams_distinct(text):
        return len(set(text.replace('\r', ' ').replace('\n', ' ').lower().translate(translator).split()))

    plt.clf()
    data["opinion"].apply(unigrams_distinct).hist()
    plt.xlabel("distinct unigrams")
    plt.ylabel("frequency")
    plt.tight_layout()
    plt.savefig(os.path.join(args.result_dir, 'data_unigrams_distinct.png'))

    data_train, data_val_model, data_val_interpretation, data_test = get_train_val_test_splits(data)
    data_val = pd.concat([data_val_model, data_val_interpretation])
    model = SimpleModel(data_train, data_val_model)
    # Fit once and cache to CSV; later runs reload the cached fit instead.
    if not os.path.exists(os.path.join(args.result_dir, "simple_model_fit.csv")):
        fit = model.fit()
        pd.DataFrame(fit).to_csv(os.path.join(args.result_dir, "simple_model_fit.csv"), index=False, header=True)
    else:
        fit = pd.read_csv(os.path.join(args.result_dir, "simple_model_fit.csv"))
        model.load(fit)
    print("Model validation f1 score:")
    print(model.evaluate(data_val_model))
    print("Model test f1 score:")
    print(model.evaluate(data_test))

    print("Model test confusion matrix:")
    print(confusion_matrix(data_test["outcome"].values, np.argmax(model.predict(data_test["opinion"]), axis=1)))

    explainer = LimeTextExplainer(class_names=[0, 1])
    model_predict_fn = model.predict
    e = data_test["opinion"].head(1).apply(explainer.explain_instance, classifier_fn=model_predict_fn).iloc[0]
    plt.clf()
    e.as_pyplot_figure()
    plt.xlabel("LIME score")
    plt.ylabel("word")
    plt.tight_layout()
    plt.savefig(os.path.join(args.result_dir, 'lime_example.png'))

    lime_explainer = LimeExplainer()
    print("lime on validation data:")
    if not os.path.exists(os.path.join(args.result_dir, "lime_data_val.csv")):
        lime_dataset = lime_explainer.build_explanations(model.predict, data_val_interpretation)
        lime_dataset.to_csv("../data/lime_data_val.csv", header=True, index=False)
    else:
        lime_dataset = pd.read_csv(os.path.join(args.result_dir, "lime_data_val.csv"))
        lime_explainer.load(lime_dataset)
    print("lime on test data:")
    if not os.path.exists(os.path.join(args.result_dir, "lime_data_test.csv")):
        lime_dataset = lime_explainer.build_explanations(model.predict, data_test)
        # File name must match the one read in the else-branch below.
        lime_dataset.to_csv(os.path.join(args.result_dir, "lime_data_test.csv"), header=True, index=False)
    else:
        lime_dataset = pd.read_csv(os.path.join(args.result_dir, "lime_data_test.csv"))
        lime_explainer.load(lime_dataset)

    lda_explainer = LdaExplainer(ngrams=model.get_model_max_ngrams(), dictionary=model.get_model_vocabulary())

    if not os.path.exists(os.path.join(args.result_dir, "lda_search_result.txt")):
        min_topics = 5
        max_topics = 20
        _, coherences = lda_explainer.search_lda(data_train, data_val, min_topics, max_topics)
        # Persist the best topic count (argmax of coherence) for later runs.
        with open(os.path.join(args.result_dir, "lda_search_result.txt"), "w") as f:
            f.write(str(np.argmax(coherences) + min_topics))

        plt.clf()
        plt.plot(range(min_topics, max_topics + 1), coherences)
        plt.xlabel("num topics")
        plt.ylabel("coherences")
        plt.tight_layout()
        plt.savefig(os.path.join(args.result_dir, "lda_coherences.png"))
    else:
        lda_explainer.load_config(os.path.join(args.result_dir, "lda_search_result.txt"), data_train)

    # Spread topic hues over the color wheel; endpoint=False keeps the first
    # and last topics from landing on the same hue (0 == 360).
    cols = np.linspace(0, 360, lda_explainer.lda.num_topics, endpoint=False)
    for i, weights in lda_explainer.lda.show_topics(num_topics=-1,
                                                    num_words=100,
                                                    formatted=False):
        maincol = cols[i]

        def colorfunc(word=None, font_size=None,
                      position=None, orientation=None,
                      font_path=None, random_state=None):
            # Jitter the hue around this topic's base hue; randint needs ints.
            color = randint(int(maincol) - 10, int(maincol) + 10)
            if color < 0:
                color += 360
            return "hsl(%d, %d%%, %d%%)" % (color, randint(65, 75) + font_size / 7, randint(35, 45) - font_size / 10)

        wordcloud = WordCloud(background_color="white",
                              ranks_only=False,
                              max_font_size=120,
                              color_func=colorfunc,
                              height=600, width=800).generate_from_frequencies(dict(weights))

        plt.clf()
        plt.imshow(wordcloud, interpolation="bilinear")
        plt.axis("off")
        plt.savefig(os.path.join(args.result_dir, "word_clouds{}.png".format(i)))

    plt.clf()
    lime_explainer.get_word_importances().head(20)["score"].plot.bar()
    plt.ylabel("LIME score")
    plt.tight_layout()
    plt.savefig(os.path.join(args.result_dir, "best_scores.png"))

    explainer = Explainer(data_test, model, lime_explainer, lda_explainer)
    affirmed_topic_importance, reversed_topic_importance = explainer.get_aggregated_topic_importance()
    # Affirmed and reversed importances get the same bar-chart treatment.
    for name, importance in (("affirmed", affirmed_topic_importance),
                             ("reversed", reversed_topic_importance)):
        plt.clf()
        pd.DataFrame({"topic_importance": importance}).plot.bar()
        plt.xlabel("Topic")
        plt.ylabel("Normalized score")
        plt.tight_layout()
        plt.savefig(os.path.join(args.result_dir, "{}_topic_importance.png".format(name)))