Exemplo n.º 1
0
    def _save_hyper_params(self, filename):
        """Write ``self.hyper_params`` to *filename* as a single record.

        The hyper-parameters are serialized with ``serialize_to_string()`` and
        written through a ``RecordWriter``. Any existing file is overwritten.
        """
        # `with` guarantees the file is closed even if serialization or the
        # write raises, which the manual open()/close() pair did not.
        with open(filename, 'wb') as fp:
            record_writer = RecordWriter(fp)
            # NOTE(review): the result of write() is ignored, as before.
            record_writer.write(self.hyper_params.serialize_to_string())
Exemplo n.º 2
0
    def test_read_and_writer_pb(self):
        """Round-trip 20 ``WordTopicHistogramPB`` records through RecordIO.

        Writes records whose ``word`` is 0..19, each holding 20 sparse entries
        (topic ``j``, count ``j + 1``). It then reads them back and checks
        every field and the total number of records read.
        """
        path = '../testdata/recordio.dat'
        num_records = 20
        num_topics = 20

        # `with` closes the file even if an assertion fails mid-write. The file
        # is closed before being reopened for reading.
        with open(path, 'wb') as fp:
            record_writer = RecordWriter(fp)
            for i in xrange(num_records):
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.word = i
                for j in xrange(num_topics):
                    non_zero = word_topic_hist.sparse_topic_hist.non_zeros.add()
                    non_zero.topic = j
                    non_zero.count = j + 1
                self.assertTrue(
                    record_writer.write(word_topic_hist.SerializeToString()))

        num_read = 0
        with open(path, 'rb') as fp:
            record_reader = RecordReader(fp)
            while True:
                blob = record_reader.read()
                # read() yields None once no records remain.
                if blob is None:
                    break
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.ParseFromString(blob)
                self.assertEqual(num_read, word_topic_hist.word)
                non_zeros = word_topic_hist.sparse_topic_hist.non_zeros
                self.assertEqual(num_topics, len(non_zeros))
                for j, non_zero in enumerate(non_zeros):
                    self.assertEqual(j, non_zero.topic)
                    self.assertEqual(j + 1, non_zero.count)
                num_read += 1
        self.assertEqual(num_records, num_read)
Exemplo n.º 3
0
    def test_read_and_writer_pb(self):
        """Round-trip 20 ``WordTopicHistogramPB`` records through RecordIO.

        Writes records whose ``word`` is 0..19, each holding 20 sparse entries
        (topic ``j``, count ``j + 1``). It then reads them back and checks
        every field and the total number of records read.
        """
        path = '../testdata/recordio.dat'
        num_records = 20
        num_topics = 20

        # `with` closes the file even if an assertion fails mid-write. The file
        # is closed before being reopened for reading.
        with open(path, 'wb') as fp:
            record_writer = RecordWriter(fp)
            for i in xrange(num_records):
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.word = i
                for j in xrange(num_topics):
                    non_zero = word_topic_hist.sparse_topic_hist.non_zeros.add()
                    non_zero.topic = j
                    non_zero.count = j + 1
                self.assertTrue(
                    record_writer.write(word_topic_hist.SerializeToString()))

        num_read = 0
        with open(path, 'rb') as fp:
            record_reader = RecordReader(fp)
            while True:
                blob = record_reader.read()
                # read() yields None once no records remain.
                if blob is None:
                    break
                word_topic_hist = WordTopicHistogramPB()
                word_topic_hist.ParseFromString(blob)
                self.assertEqual(num_read, word_topic_hist.word)
                non_zeros = word_topic_hist.sparse_topic_hist.non_zeros
                self.assertEqual(num_topics, len(non_zeros))
                for j, non_zero in enumerate(non_zeros):
                    self.assertEqual(j, non_zero.topic)
                    self.assertEqual(j + 1, non_zero.count)
                num_read += 1
        self.assertEqual(num_records, num_read)
Exemplo n.º 4
0
 def _save_global_topic_hist(self, filename):
     """Write ``self.global_topic_hist`` to *filename* as a single record.

     The per-topic counts are copied into the ``topic_counts`` field of a
     ``GlobalTopicHistogramPB``. The serialized message is written through a
     ``RecordWriter``, and any existing file is overwritten.
     """
     global_topic_hist_pb = GlobalTopicHistogramPB()
     # Repeated scalar protobuf fields support bulk extend() from an iterable.
     global_topic_hist_pb.topic_counts.extend(self.global_topic_hist)
     # `with` closes the file even if serialization or the write raises.
     with open(filename, 'wb') as fp:
         record_writer = RecordWriter(fp)
         # NOTE(review): the result of write() is ignored, as before.
         record_writer.write(global_topic_hist_pb.SerializeToString())
Exemplo n.º 5
0
 def _save_global_topic_hist(self, filename):
     """Write ``self.global_topic_hist`` to *filename* as a single record.

     The per-topic counts are copied into the ``topic_counts`` field of a
     ``GlobalTopicHistogramPB``. The serialized message is written through a
     ``RecordWriter``, and any existing file is overwritten.
     """
     global_topic_hist_pb = GlobalTopicHistogramPB()
     # Repeated scalar protobuf fields support bulk extend() from an iterable.
     global_topic_hist_pb.topic_counts.extend(self.global_topic_hist)
     # `with` closes the file even if serialization or the write raises.
     with open(filename, 'wb') as fp:
         record_writer = RecordWriter(fp)
         # NOTE(review): the result of write() is ignored, as before.
         record_writer.write(global_topic_hist_pb.SerializeToString())
Exemplo n.º 6
0
 def _save_word_topic_hist(self, filename):
     """Write every entry of ``self.word_topic_hist`` to *filename*.

     Each ``(word, ordered_sparse_topic_hist)`` pair becomes one
     ``WordTopicHistogramPB`` record. The sparse histogram is copied into the
     message by round-tripping its ``serialize_to_string()`` output through
     ``ParseFromString``. Any existing file is overwritten.
     """
     # `with` closes the file even if a serialization or write inside the
     # loop raises.
     with open(filename, 'wb') as fp:
         record_writer = RecordWriter(fp)
         for word, ordered_sparse_topic_hist in self.word_topic_hist.iteritems():
             word_topic_hist_pb = WordTopicHistogramPB()
             word_topic_hist_pb.word = word
             word_topic_hist_pb.sparse_topic_hist.ParseFromString(
                 ordered_sparse_topic_hist.serialize_to_string())
             # NOTE(review): the result of write() is ignored, as before.
             record_writer.write(word_topic_hist_pb.SerializeToString())
Exemplo n.º 7
0
 def _save_word_topic_hist(self, filename):
     """Write every entry of ``self.word_topic_hist`` to *filename*.

     Each ``(word, ordered_sparse_topic_hist)`` pair becomes one
     ``WordTopicHistogramPB`` record. The sparse histogram is copied into the
     message by round-tripping its ``serialize_to_string()`` output through
     ``ParseFromString``. Any existing file is overwritten.
     """
     # `with` closes the file even if a serialization or write inside the
     # loop raises.
     with open(filename, 'wb') as fp:
         record_writer = RecordWriter(fp)
         for word, ordered_sparse_topic_hist in self.word_topic_hist.iteritems():
             word_topic_hist_pb = WordTopicHistogramPB()
             word_topic_hist_pb.word = word
             word_topic_hist_pb.sparse_topic_hist.ParseFromString(
                 ordered_sparse_topic_hist.serialize_to_string())
             # NOTE(review): the result of write() is ignored, as before.
             record_writer.write(word_topic_hist_pb.SerializeToString())
Exemplo n.º 8
0
 def setUp(self):
     """Create a fresh 'test.txt' holding two records.

     Any leftover file is removed first. A plain-text record and a raw
     serialized protobuf ``Record`` are then written with ``write_record``,
     and each write must report success.
     """
     path = 'test.txt'
     self.test_file = path
     if os.path.exists(path):
         os.remove(path)
     self.record_one = "hello, world"
     # Raw bytes of a serialized protobuf `Record` message.
     self.record_two = '\x08\x05\x11\x00\x00\x00\x00\x00\x00\x0c@\x1a\x05hello'
     writer = RecordWriter(self.test_file)
     for payload in (self.record_one, self.record_two):
         self.assertTrue(writer.write_record(payload))
     # Release the writer before the tests run. Presumably its destructor
     # flushes/closes the file -- TODO confirm against RecordWriter.
     del writer
Exemplo n.º 9
0
    def test_read_and_write_normal(self):
        """Check that RecordWriter rejects non-string payloads and that strings round-trip.

        ``write`` must return False for an int, a float and a bool, and True
        for each string record. Reading the file back must yield the strings
        in order, followed by None.
        """
        path = '../testdata/recordio.dat'
        records = ['111', '89', 'apple', 'ipad']

        # `with` closes the file even if an assertion fails mid-test.
        with open(path, 'wb') as fp:
            record_writer = RecordWriter(fp)
            for rejected in (111, 111.89, True):
                self.assertFalse(record_writer.write(rejected))
            for record in records:
                self.assertTrue(record_writer.write(record))

        with open(path, 'rb') as fp:
            record_reader = RecordReader(fp)
            for record in records:
                self.assertEqual(record, record_reader.read())
            # Reading past the last record yields None.
            self.assertIsNone(record_reader.read())
Exemplo n.º 10
0
    def test_read_and_write_normal(self):
        """Check that RecordWriter rejects non-string payloads and that strings round-trip.

        ``write`` must return False for an int, a float and a bool, and True
        for each string record. Reading the file back must yield the strings
        in order, followed by None.
        """
        path = '../testdata/recordio.dat'
        records = ['111', '89', 'apple', 'ipad']

        # `with` closes the file even if an assertion fails mid-test.
        with open(path, 'wb') as fp:
            record_writer = RecordWriter(fp)
            for rejected in (111, 111.89, True):
                self.assertFalse(record_writer.write(rejected))
            for record in records:
                self.assertTrue(record_writer.write(record))

        with open(path, 'rb') as fp:
            record_reader = RecordReader(fp)
            for record in records:
                self.assertEqual(record, record_reader.read())
            # Reading past the last record yields None.
            self.assertIsNone(record_reader.read())
Exemplo n.º 11
0
    def test_read_and_write_normal(self):
        """Check that RecordWriter rejects non-string payloads and that strings round-trip.

        ``write`` must return False for an int, a float and a bool, and True
        for each string record. Reading the file back must yield the strings
        in order, followed by None.
        """
        path = "../testdata/recordio.dat"
        records = ["111", "89", "apple", "ipad"]

        # `with` closes the file even if an assertion fails mid-test.
        with open(path, "wb") as fp:
            record_writer = RecordWriter(fp)
            for rejected in (111, 111.89, True):
                self.assertFalse(record_writer.write(rejected))
            for record in records:
                self.assertTrue(record_writer.write(record))

        with open(path, "rb") as fp:
            record_reader = RecordReader(fp)
            for record in records:
                self.assertEqual(record, record_reader.read())
            # Reading past the last record yields None.
            self.assertIsNone(record_reader.read())
Exemplo n.º 12
0
 def _save_hyper_params(self, filename):
     """Write ``self.hyper_params`` to *filename* as a single record.

     The hyper-parameters are serialized with ``serialize_to_string()`` and
     written through a ``RecordWriter``. Any existing file is overwritten.
     """
     # `with` guarantees the file is closed even if serialization or the
     # write raises, which the manual open()/close() pair did not.
     with open(filename, 'wb') as fp:
         record_writer = RecordWriter(fp)
         # NOTE(review): the result of write() is ignored, as before.
         record_writer.write(self.hyper_params.serialize_to_string())