コード例 #1
0
    def test_gazetteer_tensor(self):
        """Exercise GazetteerTensorizer on the dict-features fixture.

        Three stages are checked in order: the vocab built by driving
        ``initialize()``, the per-row output of ``numberize()``, and the
        batched output of ``tensorize()``.
        """
        tensorizer = GazetteerTensorizer()

        data = TSVDataSource(
            train_file=SafeFileWrapper(
                tests_module.test_file("train_dict_features.tsv")
            ),
            test_file=None,
            eval_file=None,
            field_names=["text", "dict"],
            schema={"text": str, "dict": Gazetteer},
        )

        # The object returned by initialize() is driven by hand:
        # send(None) kicks it off, each training row is sent in, and
        # close() finishes it.
        init = tensorizer.initialize()
        init.send(None)  # kick
        for row in data.train:
            init.send(row)
        init.close()
        # UNK + PAD + 5 labels
        self.assertEqual(7, len(tensorizer.vocab))

        # only two rows in the fixture file:
        # "Order coffee from Starbucks please"
        # "Order some fries from McDonalds please"
        for i, row in enumerate(data.train):
            if i == 0:
                idx, weights, lens = tensorizer.numberize(row)
                self.assertEqual([1, 1, 2, 3, 1, 1, 4, 1, 1, 1], idx)
                self.assertEqual(
                    [0.0, 0.0, 0.8, 0.2, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
                    weights,
                )
                self.assertEqual([1, 2, 1, 1, 1], lens)
            elif i == 1:
                idx, weights, lens = tensorizer.numberize(row)
                self.assertEqual([1, 1, 5, 1, 6, 1], idx)
                self.assertEqual([0.0, 0.0, 1.0, 0.0, 1.0, 0.0], weights)
                self.assertEqual([1, 1, 1, 1, 1, 1], lens)

        # Batch both rows; the expected values below show the shorter row
        # padded out to the length of the longer one.
        feats, weights, lens = tensorizer.tensorize(
            tensorizer.numberize(row) for row in data.train
        )
        self.assertEqual(
            [
                [1, 1, 2, 3, 1, 1, 4, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 5, 1, 1, 1, 6, 1, 1, 1],
            ],
            feats.numpy().tolist(),
        )
        # Convert to plain Python floats with tolist() BEFORE rounding and
        # compare the lists directly. Rounding NumPy scalars and comparing
        # str() of the lists ties the test to NumPy's scalar repr, which
        # NumPy 2 changed (e.g. "np.float32(0.8)" instead of "0.8").
        self.assertEqual(
            [
                [0.0, 0.0, 0.8, 0.2, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
            ],
            [
                [round(w, 2) for w in utt_weights]
                for utt_weights in weights.numpy().tolist()
            ],
        )
        self.assertEqual(
            [[1, 2, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]],
            lens.numpy().tolist(),
        )
コード例 #2
0
ファイル: tensorizers_test.py プロジェクト: jjw-megha/pytext
    def test_gazetteer_tensor(self):
        """Build the gazetteer vocab from the fixture and check numberize().

        The vocab is built by driving ``initialize()`` over the training
        rows; every training row is then numberized and compared against
        the expected indices, weights and lengths.
        """
        tensorizer = GazetteerTensorizer()

        fixture_path = tests_module.test_file("train_dict_features.tsv")
        data = TSVDataSource(
            train_file=SafeFileWrapper(fixture_path),
            test_file=None,
            eval_file=None,
            field_names=["text", "dict"],
            schema={"text": str, "dict": Gazetteer},
        )

        # Drive the initializer by hand: send(None) kicks it off, each
        # training row is sent in, and close() finishes it.
        vocab_builder = tensorizer.initialize()
        vocab_builder.send(None)
        for example in data.train:
            vocab_builder.send(example)
        vocab_builder.close()

        # UNK + PAD + 3 labels
        vocab_size = len(tensorizer.vocab)
        self.assertEqual(5, vocab_size)

        # The fixture holds a single row:
        # "Order coffee from Starbucks please"
        want_idx = [1, 1, 2, 3, 1, 1, 4, 1, 1, 1]
        want_weights = [0.0, 0.0, 0.8, 0.2, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
        want_lens = [1, 2, 1, 1, 1]
        for example in data.train:
            got_idx, got_weights, got_lens = tensorizer.numberize(example)
            self.assertEqual(want_idx, got_idx)
            self.assertEqual(want_weights, got_weights)
            self.assertEqual(want_lens, got_lens)