class DocumentFrequenciesTests(unittest.TestCase):
    """Tests for the ``DocumentFrequencies`` model loaded from a test fixture."""

    def setUp(self):
        # The fixture path (paths.DOCFREQ) is resolved relative to the
        # directory containing this test file.
        self.model = DocumentFrequencies().load(
            source=os.path.join(os.path.dirname(__file__), paths.DOCFREQ))

    def test_docs(self):
        docs = self.model.docs
        self.assertIsInstance(docs, int)
        self.assertEqual(docs, 1000)

    def test_get(self):
        self.assertEqual(self.model["aaaaaaa"], 341)
        # Subscripting with an unknown token must raise KeyError; the bare
        # lookup is the assertion, nothing needs to be done with its result.
        with self.assertRaises(KeyError):
            self.model["xaaaaaa"]
        # .get() returns the supplied default instead of raising.
        self.assertEqual(self.model.get("aaaaaaa", 0), 341)
        self.assertEqual(self.model.get("xaaaaaa", 100500), 100500)

    def test_tokens(self):
        tokens = self.model.tokens()
        # tokens() must already be in sorted order.
        self.assertEqual(sorted(tokens), tokens)
        for t in tokens:
            self.assertGreater(self.model[t], 0)

    def test_len(self):
        # the remaining 18 are not unique - the model was generated badly
        self.assertEqual(len(self.model), 982)

    def test_iter(self):
        # Iterating the model yields (token, frequency) pairs; find one token
        # containing "aaaaaaa" and check its frequency converts to int.
        for tok, freq in self.model:
            if "aaaaaaa" in tok:
                int(freq)
                break
        else:
            self.fail("no token containing 'aaaaaaa' was yielded by iteration")
def test_preprocess(self):
    """End-to-end check of ``preprocess`` on the fixture co-occurrence models.

    Runs ``preprocess`` into a temporary directory and verifies that:
    * nothing is printed to stdout/stderr and the broken inputs
      ("error.asdf", "empty_coocc.asdf") are reported as skipped in the log;
    * exactly the expected output files are produced;
    * the document-frequency model, the row/column sums and the row/column
      vocabularies agree with each other;
    * a random 128x128 sub-block of the sharded matrix equals the same
      sub-block accumulated directly from the input co-occurrence models.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        args = default_preprocess_params(tmpdir, self.VOCAB)
        with captured_output() as (out, err, log):
            preprocess(args)
        self.assertFalse(out.getvalue())
        self.assertFalse(err.getvalue())
        self.assertIn("Skipped", log.getvalue())
        self.assertIn("error.asdf", log.getvalue())
        self.assertIn("empty_coocc.asdf", log.getvalue())
        self.assertEqual(sorted(os.listdir(tmpdir)), [
            "col_sums.txt", "col_vocab.txt", "docfreq.asdf", "row_sums.txt",
            "row_vocab.txt", "shard-000-000.pb"
        ])
        df = DocumentFrequencies().load(
            source=os.path.join(tmpdir, "docfreq.asdf"))
        self.assertEqual(len(df), self.VOCAB)
        # NOTE(review): exactly one input file is expected not to count as a
        # document -- confirm which one against preprocess().
        self.assertEqual(df.docs, len(os.listdir(args.input[0])) - 1)

        def read_output(name):
            # Return the full text of one of the files preprocess wrote.
            with open(os.path.join(tmpdir, name)) as fin:
                return fin.read()

        col_sums = read_output("col_sums.txt")
        row_sums = read_output("row_sums.txt")
        self.assertEqual(col_sums, row_sums)
        col_vocab = read_output("col_vocab.txt")
        row_vocab = read_output("row_vocab.txt")
        self.assertEqual(col_vocab, row_vocab)
        all_tokens = row_vocab.split("\n")
        self.assertEqual(all_tokens, df.tokens())
        for word in all_tokens:
            self.assertGreater(df[word], 0)

        with open(os.path.join(tmpdir, "shard-000-000.pb"), "rb") as fin:
            features = tf.parse_single_example(
                fin.read(),
                features={
                    "global_row": tf.FixedLenFeature([self.VOCAB],
                                                     dtype=tf.int64),
                    "global_col": tf.FixedLenFeature([self.VOCAB],
                                                     dtype=tf.int64),
                    "sparse_local_row": tf.VarLenFeature(dtype=tf.int64),
                    "sparse_local_col": tf.VarLenFeature(dtype=tf.int64),
                    "sparse_value": tf.VarLenFeature(dtype=tf.float32)
                })
        with tf.Session() as session:
            global_row, global_col, local_row, local_col, value = session.run(
                [
                    features[n]
                    for n in ("global_row", "global_col", "sparse_local_row",
                              "sparse_local_col", "sparse_value")
                ])
        self.assertEqual(set(range(self.VOCAB)), set(global_row))
        self.assertEqual(set(range(self.VOCAB)), set(global_col))
        nnz = 1421193
        self.assertEqual(value.values.shape, (nnz, ))
        self.assertEqual(local_row.values.shape, (nnz, ))
        self.assertEqual(local_col.values.shape, (nnz, ))

        # Pick 128 distinct vocabulary indices reproducibly. An int population
        # is sampled as if it were numpy.arange(self.VOCAB), so these are the
        # same draws as passing list(range(self.VOCAB)).
        numpy.random.seed(0)
        chosen_indices = numpy.random.choice(self.VOCAB, 128, replace=False)
        chosen = [all_tokens[i] for i in chosen_indices]
        # Token -> position within the chosen sub-block; its keys double as
        # the membership set of chosen tokens.
        index = {w: i for i, w in enumerate(chosen)}
        freqs = numpy.zeros((len(chosen), ) * 2, dtype=int)
        # Accumulate the expected sub-block straight from the inputs, skipping
        # files whose meta model is not "co-occurrences".
        for path in os.listdir(args.input[0]):
            with asdf.open(os.path.join(args.input[0], path)) as model:
                if model.tree["meta"]["model"] != "co-occurrences":
                    continue
                source_matrix = assemble_sparse_matrix(
                    model.tree["matrix"]).tocsr()
                tokens = split_strings(model.tree["tokens"])
                interesting = {
                    i
                    for i, t in enumerate(tokens) if t in index
                }
                for y in interesting:
                    row = source_matrix[y]
                    yi = index[tokens[y]]
                    for x, v in zip(row.indices, row.data):
                        if x in interesting:
                            freqs[yi, index[tokens[x]]] += v

        # Map shard-local row/col ids to global vocabulary ids with numpy
        # fancy indexing (session.run returns numpy arrays) rather than a
        # Python-level loop over all nnz entries.
        shard = coo_matrix(
            (value.values, (global_row[local_row.values],
                            global_col[local_col.values])),
            shape=(self.VOCAB, self.VOCAB))
        shard_block = shard.tocsr()[chosen_indices][:, chosen_indices] \
            .todense().astype(int)
        self.assertTrue((shard_block == freqs).all())