Esempio n. 1
0
 def test_data_block_dumper(self):
     """Dump synced data blocks on the leader and compare with the follower.

     The test feeds every meta in ``self.dumped_metas`` to a
     ``DataBlockDumperManager`` built on ``self.data_source_l``, dumps the
     data blocks, and then checks that for partition 0 both data sources
     expose the same metas (through ``DataBlockManager`` and through the
     meta files on disk) and that each data block file contains exactly
     the example ids listed in its meta, in order.

     NOTE(review): ``self.dumped_metas`` is presumably populated by
     ``generate_follower_data_block()`` -- confirm against the fixture.
     """
     self.generate_follower_data_block()
     self.generate_leader_raw_data()
     options = customized_options.CustomizedOptions()
     options.set_raw_data_iter('TF_RECORD')
     dbd = data_block_dumper.DataBlockDumperManager(
         self.etcd, self.data_source_l, 0, options
     )
     self.assertEqual(dbd.get_partition_id(), 0)
     self.assertEqual(dbd.get_next_data_block_index(), 0)
     for (idx, meta) in enumerate(self.dumped_metas):
         success, next_index = dbd.append_synced_data_block_meta(meta)
         self.assertTrue(success)
         self.assertEqual(next_index, idx + 1)
     self.assertTrue(dbd.need_dump())
     self.assertEqual(dbd.get_next_data_block_index(),
                      len(self.dumped_metas))
     dbd.dump_data_blocks()
     dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
     dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
     self.assertEqual(dbm_f.get_dumped_data_block_num(),
                      len(self.dumped_metas))
     self.assertEqual(dbm_f.get_dumped_data_block_num(),
                      dbm_l.get_dumped_data_block_num())

     def read_meta(data_source, block_id):
         # Parse the first record of the block's meta file.
         meta_fpth = os.path.join(data_source.data_block_dir, 'partition_0',
                                  block_id + common.DataBlockMetaSuffix)
         mitr = tf.io.tf_record_iterator(meta_fpth)
         parsed = dj_pb.DataBlockMeta()
         parsed.ParseFromString(next(mitr))
         return parsed

     def assert_example_ids_match(data_source, block_id, expected_ids):
         # Count records explicitly instead of reusing the loop index after
         # the loop: with an empty file the index would be unbound (or
         # stale from an earlier loop) and the length check would either
         # crash with NameError or pass vacuously.
         data_fpth = os.path.join(data_source.data_block_dir, 'partition_0',
                                  block_id + common.DataBlockSuffix)
         record_count = 0
         for record in tf.io.tf_record_iterator(data_fpth):
             example = tf.train.Example()
             example.ParseFromString(record)
             feat = example.features.feature
             self.assertEqual(feat['example_id'].bytes_list.value[0],
                              expected_ids[record_count])
             record_count += 1
         self.assertEqual(len(expected_ids), record_count)

     for (idx, meta) in enumerate(self.dumped_metas):
         self.assertEqual(meta.data_block_index, idx)
         self.assertEqual(dbm_l.get_data_block_meta_by_index(idx)[0], meta)
         self.assertEqual(dbm_f.get_data_block_meta_by_index(idx)[0], meta)
         block_id = meta.block_id
         # Meta files on disk must match on both sides.
         self.assertEqual(read_meta(self.data_source_l, block_id), meta)
         self.assertEqual(read_meta(self.data_source_f, block_id), meta)
         # Data block files must hold exactly the meta's example ids.
         assert_example_ids_match(self.data_source_l, block_id,
                                  meta.example_ids)
         assert_example_ids_match(self.data_source_f, block_id,
                                  meta.example_ids)
Esempio n. 2
0
 def test_data_block_dumper(self):
     """Dump synced data blocks on the leader and compare with the follower.

     The test feeds every meta in ``self.dumped_metas`` to a
     ``DataBlockDumperManager`` built on ``self.data_source_l``, runs the
     dumper, and then checks that for partition 0 both data sources expose
     the same metas (through ``DataBlockManager`` and through the meta
     files on disk) and that each data block file contains exactly the
     example ids listed in its meta, in order.

     NOTE(review): ``self.dumped_metas`` is presumably populated by
     ``generate_follower_data_block()`` -- confirm against the fixture.
     """
     self.generate_follower_data_block()
     self.generate_leader_raw_data()
     dbd = data_block_dumper.DataBlockDumperManager(
         self.etcd,
         self.data_source_l,
         0,
         dj_pb.RawDataOptions(raw_data_iter='TF_RECORD'),
         dj_pb.DataBlockBuilderOptions(
             data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
     )
     self.assertEqual(dbd.get_next_data_block_index(), 0)
     for (idx, meta) in enumerate(self.dumped_metas):
         success, next_index = dbd.add_synced_data_block_meta(meta)
         self.assertTrue(success)
         self.assertEqual(next_index, idx + 1)
     self.assertTrue(dbd.need_dump())
     self.assertEqual(dbd.get_next_data_block_index(),
                      len(self.dumped_metas))
     with dbd.make_data_block_dumper() as dumper:
         dumper()
     dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
     dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
     self.assertEqual(dbm_f.get_dumped_data_block_count(),
                      len(self.dumped_metas))
     self.assertEqual(dbm_f.get_dumped_data_block_count(),
                      dbm_l.get_dumped_data_block_count())

     def read_meta(data_source, data_block_index):
         # Parse the first record of the block's meta file.
         meta_fpth = os.path.join(
             data_source.data_block_dir, common.partition_repr(0),
             common.encode_data_block_meta_fname(
                 data_source.data_source_meta.name, 0, data_block_index))
         mitr = tf.io.tf_record_iterator(meta_fpth)
         return text_format.Parse(next(mitr), dj_pb.DataBlockMeta())

     def assert_example_ids_match(data_source, parsed_meta, expected_ids):
         # The file name is derived from the same data source whose
         # directory is read; the original follower check mixed the
         # follower directory with the leader's data source name.
         data_fpth = os.path.join(
             data_source.data_block_dir, common.partition_repr(0),
             common.encode_data_block_fname(
                 data_source.data_source_meta.name, parsed_meta))
         # Count records explicitly instead of reusing the loop index after
         # the loop: with an empty file the index would be unbound (or
         # stale from an earlier loop) and the length check would either
         # crash with NameError or pass vacuously.
         record_count = 0
         for record in tf.io.tf_record_iterator(data_fpth):
             example = tf.train.Example()
             example.ParseFromString(record)
             feat = example.features.feature
             self.assertEqual(feat['example_id'].bytes_list.value[0],
                              expected_ids[record_count])
             record_count += 1
         self.assertEqual(len(expected_ids), record_count)

     for (idx, meta) in enumerate(self.dumped_metas):
         self.assertEqual(meta.data_block_index, idx)
         self.assertEqual(dbm_l.get_data_block_meta_by_index(idx), meta)
         self.assertEqual(dbm_f.get_data_block_meta_by_index(idx), meta)
         # Meta files on disk must match on both sides.
         meta_l = read_meta(self.data_source_l, meta.data_block_index)
         self.assertEqual(meta_l, meta)
         meta_f = read_meta(self.data_source_f, meta.data_block_index)
         self.assertEqual(meta_f, meta)
         # Data block files must hold exactly the meta's example ids.
         assert_example_ids_match(self.data_source_l, meta_l,
                                  meta.example_ids)
         assert_example_ids_match(self.data_source_f, meta_f,
                                  meta.example_ids)