예제 #1
0
    def test_read_and_writer_pb(self):
        """Round-trip 20 WordTopicHistogramPB messages through RecordIO.

        Message i carries word == i and 20 non-zero entries whose j-th entry
        has topic j and count j + 1. The reader must hand back exactly those
        20 messages, in write order, before returning None.
        """
        # `with` closes (and flushes) the file even when an assertion fails,
        # so a failing run neither leaks the handle nor leaves a partial file.
        with open('../testdata/recordio.dat', 'wb') as fp:
            record_writer = RecordWriter(fp)
            for i in range(20):
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.word = i
                for j in range(20):
                    non_zero = word_topic_hist.sparse_topic_hist.non_zeros.add()
                    non_zero.topic = j
                    non_zero.count = j + 1
                self.assertTrue(
                    record_writer.write(word_topic_hist.SerializeToString()))

        with open('../testdata/recordio.dat', 'rb') as fp:
            record_reader = RecordReader(fp)
            i = 0
            while True:
                blob = record_reader.read()
                # read() yields None once the records are exhausted.
                if blob is None:
                    break
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.ParseFromString(blob)
                self.assertEqual(i, word_topic_hist.word)
                non_zeros = word_topic_hist.sparse_topic_hist.non_zeros
                self.assertEqual(20, len(non_zeros))
                for j, non_zero in enumerate(non_zeros):
                    self.assertEqual(j, non_zero.topic)
                    self.assertEqual(j + 1, non_zero.count)
                i += 1
            # Every written record must have been read back.
            self.assertEqual(20, i)
예제 #2
0
    def test_read_and_writer_pb(self):
        """Round-trip 20 WordTopicHistogramPB messages through RecordIO.

        Message i carries word == i and 20 non-zero entries whose j-th entry
        has topic j and count j + 1. The reader must hand back exactly those
        20 messages, in write order, before returning None.
        """
        # `with` closes (and flushes) the file even when an assertion fails,
        # so a failing run neither leaks the handle nor leaves a partial file.
        with open('../testdata/recordio.dat', 'wb') as fp:
            record_writer = RecordWriter(fp)
            for i in range(20):
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.word = i
                for j in range(20):
                    non_zero = word_topic_hist.sparse_topic_hist.non_zeros.add()
                    non_zero.topic = j
                    non_zero.count = j + 1
                self.assertTrue(
                    record_writer.write(word_topic_hist.SerializeToString()))

        with open('../testdata/recordio.dat', 'rb') as fp:
            record_reader = RecordReader(fp)
            i = 0
            while True:
                blob = record_reader.read()
                # read() yields None once the records are exhausted.
                if blob is None:
                    break
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.ParseFromString(blob)
                self.assertEqual(i, word_topic_hist.word)
                non_zeros = word_topic_hist.sparse_topic_hist.non_zeros
                self.assertEqual(20, len(non_zeros))
                for j, non_zero in enumerate(non_zeros):
                    self.assertEqual(j, non_zero.topic)
                    self.assertEqual(j + 1, non_zero.count)
                i += 1
            # Every written record must have been read back.
            self.assertEqual(20, i)
예제 #3
0
    def _load_hyper_params(self, filename):
        """Load the hyper-parameters record from `filename`.

        Only the first record of the RecordIO file is read; it is passed to
        self.hyper_params.parse_from_string.

        Args:
            filename: path of the RecordIO file holding the serialized
                hyper-parameters as its first record.

        Returns:
            True if a record was read and parsed, False if the file contained
            no record (an error is logged in that case).
        """
        logging.info('Loading hyper_params topic_prior and word_prior.')
        # `with` guarantees the file is closed even if reading raises.
        with open(filename, "rb") as fp:
            blob = RecordReader(fp).read()
        if blob is None:
            logging.error('HyperParams is nil, file %s', filename)
            return False

        self.hyper_params.parse_from_string(blob)
        return True
예제 #4
0
    def _load_hyper_params(self, filename):
        """Load the hyper-parameters record from `filename`.

        Only the first record of the RecordIO file is read; it is passed to
        self.hyper_params.parse_from_string.

        Args:
            filename: path of the RecordIO file holding the serialized
                hyper-parameters as its first record.

        Returns:
            True if a record was read and parsed, False if the file contained
            no record (an error is logged in that case).
        """
        logging.info('Loading hyper_params topic_prior and word_prior.')
        # `with` guarantees the file is closed even if reading raises.
        with open(filename, "rb") as fp:
            blob = RecordReader(fp).read()
        if blob is None:
            logging.error('HyperParams is nil, file %s', filename)
            return False

        self.hyper_params.parse_from_string(blob)
        return True
예제 #5
0
    def testReader(self):
        """Exercise RecordReader on an empty path and on the fixture file.

        Building a reader for "" must raise IOError. On self.test_file,
        read_record() must return (True, self.record_one), then
        (True, self.record_two), and then a falsy status flag.
        """
        self.assertRaises(IOError, RecordReader, "")
        self.assertTrue(os.path.exists(self.test_file))
        reader = RecordReader(self.test_file)

        for expected in (self.record_one, self.record_two):
            ok, payload = reader.read_record()
            self.assertTrue(ok)
            self.assertEqual(payload, expected)

        # Past the last record the status flag must be falsy.
        ok, _ = reader.read_record()
        self.assertFalse(ok)
예제 #6
0
파일: model.py 프로젝트: springbarley/mltk
    def _load_global_topic_hist(self, filename):
        """Load the global topic histogram vector N(z) from `filename`.

        The first record of the RecordIO file is parsed as a
        GlobalTopicHistogramPB, and its topic_counts are stored in
        self.global_topic_hist as an int64 numpy array.

        Args:
            filename: path of the RecordIO file holding the histogram.

        Returns:
            True on success. False if the file contained no record; in that
            case an error is logged and self.global_topic_hist is not touched.
        """
        logging.info('Loading global_topic_hist vector N(z).')

        # `with` guarantees the file is closed even if reading raises.
        with open(filename, "rb") as fp:
            blob = RecordReader(fp).read()
        if blob is None:
            logging.error('GlobalTopicHist is nil, file %s', filename)
            return False

        global_topic_hist_pb = GlobalTopicHistogramPB()
        global_topic_hist_pb.ParseFromString(blob)
        self.global_topic_hist = np.array(global_topic_hist_pb.topic_counts,
                                          dtype='int64')

        return True
예제 #7
0
    def _load_global_topic_hist(self, filename):
        """Load the global topic histogram vector N(z) from `filename`.

        self.global_topic_hist is reset to an empty list first. The first
        record of the RecordIO file is then parsed as a GlobalTopicHistogramPB,
        and its topic_counts are copied into self.global_topic_hist as a list.

        Args:
            filename: path of the RecordIO file holding the histogram.

        Returns:
            True on success. False if the file contained no record; in that
            case an error is logged and self.global_topic_hist stays empty.
        """
        logging.info('Loading global_topic_hist vector N(z).')
        self.global_topic_hist = []

        # `with` guarantees the file is closed even if reading raises.
        with open(filename, "rb") as fp:
            blob = RecordReader(fp).read()
        if blob is None:
            logging.error('GlobalTopicHist is nil, file %s', filename)
            return False

        global_topic_hist_pb = GlobalTopicHistogramPB()
        global_topic_hist_pb.ParseFromString(blob)
        self.global_topic_hist = list(global_topic_hist_pb.topic_counts)
        return True
예제 #8
0
    def _load_word_topic_hist(self, filename):
        """Load the word-topic histogram matrix N(w|z) from `filename`.

        self.word_topic_hist is cleared first. Then every record of the
        RecordIO file is parsed as a WordTopicHistogramPB, and its sparse
        topic histogram is stored as an OrderedSparseTopicHistogram under the
        record's word. A later record for the same word overwrites an earlier
        one.

        Args:
            filename: path of the RecordIO file holding one record per word.

        Returns:
            True if at least one word was loaded, False otherwise.
        """
        logging.info('Loading word_topic_hist matrix N(w|z).')
        self.word_topic_hist.clear()

        # `with` guarantees the file is closed even if parsing raises mid-way.
        with open(filename, "rb") as fp:
            record_reader = RecordReader(fp)
            # Two-argument iter() keeps calling read() until it returns None.
            for blob in iter(record_reader.read, None):
                word_topic_hist_pb = WordTopicHistogramPB()
                word_topic_hist_pb.ParseFromString(blob)

                ordered_sparse_topic_hist = \
                        OrderedSparseTopicHistogram(self.num_topics)
                # The sparse histogram is handed over as serialized bytes,
                # which parse_from_string consumes.
                ordered_sparse_topic_hist.parse_from_string(
                    word_topic_hist_pb.sparse_topic_hist.SerializeToString())
                self.word_topic_hist[word_topic_hist_pb.word] = \
                        ordered_sparse_topic_hist
        return len(self.word_topic_hist) > 0
예제 #9
0
    def _load_word_topic_hist(self, filename):
        """Load the word-topic histogram matrix N(w|z) from `filename`.

        self.word_topic_hist is cleared first. Then every record of the
        RecordIO file is parsed as a WordTopicHistogramPB, and its sparse
        topic histogram is stored as an OrderedSparseTopicHistogram under the
        record's word. A later record for the same word overwrites an earlier
        one.

        Args:
            filename: path of the RecordIO file holding one record per word.

        Returns:
            True if at least one word was loaded, False otherwise.
        """
        logging.info('Loading word_topic_hist matrix N(w|z).')
        self.word_topic_hist.clear()

        # `with` guarantees the file is closed even if parsing raises mid-way.
        with open(filename, "rb") as fp:
            record_reader = RecordReader(fp)
            # Two-argument iter() keeps calling read() until it returns None.
            for blob in iter(record_reader.read, None):
                word_topic_hist_pb = WordTopicHistogramPB()
                word_topic_hist_pb.ParseFromString(blob)

                ordered_sparse_topic_hist = \
                        OrderedSparseTopicHistogram(self.num_topics)
                # The sparse histogram is handed over as serialized bytes,
                # which parse_from_string consumes.
                ordered_sparse_topic_hist.parse_from_string(
                        word_topic_hist_pb.sparse_topic_hist.SerializeToString())
                self.word_topic_hist[word_topic_hist_pb.word] = \
                        ordered_sparse_topic_hist
        return len(self.word_topic_hist) > 0
예제 #10
0
    def test_read_and_write_normal(self):
        """Check RecordWriter/RecordReader with plain string payloads.

        write() must return False for an int, a float and a bool, and True
        for each string. The reader must then return the strings in write
        order, followed by None.
        """
        records = ['111', '89', 'apple', 'ipad']

        # `with` closes (and flushes) the file even when an assertion fails.
        with open('../testdata/recordio.dat', 'wb') as fp:
            record_writer = RecordWriter(fp)
            for rejected in (111, 111.89, True):
                self.assertFalse(record_writer.write(rejected))
            for record in records:
                self.assertTrue(record_writer.write(record))

        with open('../testdata/recordio.dat', 'rb') as fp:
            record_reader = RecordReader(fp)
            for record in records:
                self.assertEqual(record, record_reader.read())
            self.assertIsNone(record_reader.read())
예제 #11
0
    def test_read_and_write_normal(self):
        """Check RecordWriter/RecordReader with plain string payloads.

        write() must return False for an int, a float and a bool, and True
        for each string. The reader must then return the strings in write
        order, followed by None.
        """
        records = ['111', '89', 'apple', 'ipad']

        # `with` closes (and flushes) the file even when an assertion fails.
        with open('../testdata/recordio.dat', 'wb') as fp:
            record_writer = RecordWriter(fp)
            for rejected in (111, 111.89, True):
                self.assertFalse(record_writer.write(rejected))
            for record in records:
                self.assertTrue(record_writer.write(record))

        with open('../testdata/recordio.dat', 'rb') as fp:
            record_reader = RecordReader(fp)
            for record in records:
                self.assertEqual(record, record_reader.read())
            self.assertIsNone(record_reader.read())
예제 #12
0
    def test_read_and_write_normal(self):
        """Check RecordWriter/RecordReader with plain string payloads.

        write() must return False for an int, a float and a bool, and True
        for each string. The reader must then return the strings in write
        order, followed by None.
        """
        records = ["111", "89", "apple", "ipad"]

        # `with` closes (and flushes) the file even when an assertion fails.
        with open("../testdata/recordio.dat", "wb") as fp:
            record_writer = RecordWriter(fp)
            for rejected in (111, 111.89, True):
                self.assertFalse(record_writer.write(rejected))
            for record in records:
                self.assertTrue(record_writer.write(record))

        with open("../testdata/recordio.dat", "rb") as fp:
            record_reader = RecordReader(fp)
            for record in records:
                self.assertEqual(record, record_reader.read())
            self.assertIsNone(record_reader.read())