コード例 #1
0
ファイル: test_df.py プロジェクト: sniperkit/ml
class DocumentFrequenciesTests(unittest.TestCase):
    """Unit tests for the DocumentFrequencies model stored at ``paths.DOCFREQ``."""

    def setUp(self):
        # unittest calls setUp before every test, so each test gets a freshly loaded model.
        self.model = DocumentFrequencies().load(source=paths.DOCFREQ)

    def test_docs(self):
        """The ``docs`` attribute is an int equal to 1000 for the fixture model."""
        n_docs = self.model.docs
        self.assertIsInstance(n_docs, int)
        self.assertEqual(n_docs, 1000)

    def test_get(self):
        """Subscripting raises KeyError for unknown tokens; ``get`` returns the default."""
        known, unknown = "aaaaaaa", "xaaaaaa"
        self.assertEqual(self.model[known], 341)
        with self.assertRaises(KeyError):
            print(self.model[unknown])
        self.assertEqual(self.model.get(known, 0), 341)
        self.assertEqual(self.model.get(unknown, 100500), 100500)

    def test_tokens(self):
        """``tokens()`` lists the keys of the underlying mapping, in the same order."""
        expected = list(self.model._df)
        self.assertEqual(expected, self.model.tokens())

    def test_len(self):
        """``len`` counts unique tokens only."""
        # 18 entries of the generated model were duplicates (the model was
        # generated badly), so only 982 unique tokens remain.
        self.assertEqual(len(self.model), 982)

    def test_iter(self):
        """Iteration yields (token, frequency) pairs; some token contains "aaaaaaa"."""
        found = False
        for token, frequency in self.model:
            if "aaaaaaa" not in token:
                continue
            found = True
            int(frequency)  # the frequency of the matching token must be int-convertible
            break
        self.assertTrue(found)

    def test_prune(self):
        """``prune(4)`` keeps only tokens with frequency >= 4 (346 of them)."""
        threshold = 4
        kept = self.model.prune(threshold)
        for _, frequency in kept:
            self.assertGreaterEqual(frequency, threshold)
        self.assertEqual(len(kept), 346)

    def test_prune_self(self):
        """Pruning with threshold 1 returns the very same model object."""
        self.assertIs(self.model, self.model.prune(1))

    def test_greatest(self):
        """``greatest(100)`` keeps high-frequency tokens and is deterministic."""
        top = self.model.greatest(100)
        ranked = sorted(self.model._df.values(), reverse=True)
        # NOTE(review): ranked[100] is the 101st largest frequency, a slightly
        # looser bound than ranked[99]; kept as-is to preserve the original check.
        border = ranked[100]
        for value in top._df.values():
            self.assertGreaterEqual(value, border)
        first_run = top._df
        second_run = self.model.greatest(100)._df
        self.assertEqual(first_run, second_run)

    def test_write(self):
        """A save/load round trip through an in-memory stream preserves the model."""
        stream = BytesIO()
        self.model.save(stream)
        stream.seek(0)
        restored = DocumentFrequencies().load(stream)
        self.assertEqual(self.model._df, restored._df)
        self.assertEqual(self.model.docs, restored.docs)
コード例 #2
0
def projector_entry(args):
    """Export id2vec embeddings for the TensorFlow Projector and present them.

    At most ``MAX_TOKENS`` embeddings are exported. If the model has more
    tokens than that, the selection depends on ``args.docfreq``:

    * when given, tokens are ranked by descending document frequency; tokens
      absent from the docfreq model are skipped;
    * otherwise, a reproducible random sample (seed 777) is drawn.

    :param args: Parsed command line arguments. Reads ``log_level``, ``input``
                 (Id2Vec model source), ``docfreq`` (optional
                 DocumentFrequencies model source), ``output`` and
                 ``no_browser``.
    :raises ValueError: if no token is left to export, e.g. the id2vec model
                        is empty or none of its tokens is in the docfreq model.
    """
    MAX_TOKENS = 10000  # hardcoded in Tensorflow Projector

    log = logging.getLogger("id2vec_projector")
    id2vec = Id2Vec(log_level=args.log_level).load(source=args.input)
    if args.docfreq:
        from sourced.ml.models import DocumentFrequencies
        df = DocumentFrequencies(log_level=args.log_level).load(source=args.docfreq)
    else:
        df = None

    if len(id2vec) <= MAX_TOKENS:
        # Everything fits into the Projector: export the whole vocabulary in index order.
        tokens = numpy.arange(len(id2vec), dtype=int)
        if df is not None:
            freqs = [df.get(id2vec.tokens[i], 0) for i in tokens]
        else:
            freqs = None
    elif df is not None:
        log.info("Filtering tokens through docfreq")
        items = []
        for token, idx in id2vec.items():
            try:
                items.append((df[token], idx))
            except KeyError:
                continue  # token unknown to the docfreq model: drop it
        log.info("Sorting")
        # (freq, idx) tuples: highest frequency first, ties broken by higher index.
        items.sort(reverse=True)
        top = items[:MAX_TOKENS]
        tokens = [idx for _, idx in top]
        freqs = [freq for freq, _ in top]
    else:
        # NOTE(review): the option is read as ``args.docfreq`` but the message
        # says "--df"; confirm the actual CLI flag name.
        log.warning("You have not specified --df => picking random %d tokens", MAX_TOKENS)
        # A private RandomState seeded with 777 draws the same sample as seeding
        # numpy's global generator, without clobbering the global RNG state.
        rng = numpy.random.RandomState(777)
        tokens = rng.choice(numpy.arange(len(id2vec), dtype=int), MAX_TOKENS, replace=False)
        freqs = None

    if len(tokens) == 0:
        # numpy.vstack below would otherwise fail with an opaque error.
        raise ValueError("No tokens to project: the id2vec model is empty or none of "
                         "its tokens are present in the docfreq model")

    log.info("Gathering the embeddings")
    embeddings = numpy.vstack([id2vec.embeddings[i] for i in tokens])
    tokens = [id2vec.tokens[i] for i in tokens]
    labels = ["subtoken"]
    if freqs is not None:
        # With a docfreq model, each row is labeled (subtoken, docfreq as str).
        labels.append("docfreq")
        tokens = list(zip(tokens, (str(i) for i in freqs)))
    import sourced.ml.utils.projector as projector
    projector.present_embeddings(args.output, not args.no_browser, labels, tokens, embeddings)
    if not args.no_browser:
        projector.wait()