示例#1
0
 def test_raw_data_manager(self):
     """Index two raw-data reps one after another and verify each step.

     Both reps must start out with the 'unindexed' field set.  After each
     ``index_raw_data_rep`` call the list of indexed reps must grow by one,
     and the newest entry must keep the original path and carry the
     requested start index.
     """
     # Not referenced again in this test; presumably constructed for its
     # side effects on the store -- TODO confirm.
     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     manager = raw_data_visitor.RawDataManager(self.etcd, self.data_source,
                                               0)
     self.assertEqual(len(manager.get_indexed_raw_data_reps()), 0)

     # Fetch both reps up front, before any of them is indexed.
     pending_reps = [manager.get_raw_data_rep_by_index(position)
                     for position in (0, 1)]
     for rep in pending_reps:
         self.assertTrue(rep.HasField('unindexed'))
     for rep, file_name in zip(pending_reps, ('raw_data_0', 'raw_data_1')):
         self.assertEqual(ntpath.basename(rep.raw_data_path), file_name)

     # Index rep 0 at start index 0, then rep 1 at start index 100.
     for position, start_index in enumerate((0, 100)):
         manager.index_raw_data_rep(position, start_index)
         self.assertEqual(len(manager.get_indexed_raw_data_reps()),
                          position + 1)
         indexed_rep = manager.get_indexed_raw_data_reps()[position]
         self.assertEqual(indexed_rep.raw_data_path,
                          pending_reps[position].raw_data_path)
         self.assertTrue(indexed_rep.HasField('index'))
         self.assertEqual(indexed_rep.index.start_index, start_index)
示例#2
0
 def test_raw_data_manager(self):
     """Check index-meta bookkeeping while raw data files are generated.

     The assertions pin down three observations:

     * ``check_index_meta_by_process_index`` is False for a process index
       until ``_gen_raw_data_file`` has been called for a range covering it;
     * ``get_index_metas`` stays empty until ``get_index_meta_by_index`` is
       called, and then grows by one entry per call;
     * each returned meta echoes the process index and start index that
       were passed in.
     """
     manager = raw_data_visitor.RawDataManager(self.etcd, self.data_source,
                                               0)

     def assert_meta_count(expected_count):
         # Number of index metas currently reported by the manager.
         self.assertEqual(len(manager.get_index_metas()), expected_count)

     def assert_available(process_index):
         self.assertTrue(
             manager.check_index_meta_by_process_index(process_index))

     def assert_unavailable(process_index):
         self.assertFalse(
             manager.check_index_meta_by_process_index(process_index))

     def fetch_and_verify(process_index, start_index, expected_count):
         # Fetch one index meta and verify it echoes the request.
         index_meta = manager.get_index_meta_by_index(process_index,
                                                      start_index)
         self.assertEqual(index_meta.start_index, start_index)
         self.assertEqual(index_meta.process_index, process_index)
         assert_meta_count(expected_count)

     assert_meta_count(0)
     assert_unavailable(0)

     # presumably generates raw data files for process indices 0 and 1 --
     # verify against _gen_raw_data_file
     self._gen_raw_data_file(0, 2)
     assert_meta_count(0)
     assert_available(0)
     assert_available(1)
     assert_meta_count(0)

     # Computed but not used by any assertion below.
     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  common.partition_repr(0))

     fetch_and_verify(0, 0, 1)
     fetch_and_verify(1, 100, 2)

     assert_unavailable(2)
     self._gen_raw_data_file(2, 4)
     assert_available(2)
     assert_available(3)

     fetch_and_verify(2, 200, 3)
     fetch_and_verify(3, 300, 4)
示例#3
0
 def test_csv_raw_data_visitor(self):
     """Visit a CSV raw-data file and check the indices it yields.

     Registers ``test_raw_data.csv`` from the partition-0 directory under
     ``../csv_raw_data`` (relative to this file), iterates it with a
     'CSV_DICT' visitor, and expects the indices 0, 1, 2, ... with 4999
     items in total.
     """
     source = common_pb.DataSource()
     source.data_source_meta.name = 'fclh_test'
     source.data_source_meta.partition_num = 1
     self.data_source = source

     this_dir = path.dirname(path.abspath(__file__))
     self.raw_data_dir = path.join(this_dir, "../csv_raw_data")

     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     # Drop whatever is stored under this data source's base dir.
     etcd_base_dir = common.data_source_etcd_base_dir(
         self.data_source.data_source_meta.name)
     self.etcd.delete_prefix(etcd_base_dir)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)

     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))

     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     csv_meta = dj_pb.RawDataMeta(
         file_path=path.join(partition_dir, "test_raw_data.csv"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [csv_meta], True)

     visitor_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                            read_ahead_size=1 << 20)
     manager = raw_data_visitor.RawDataManager(self.etcd, self.data_source,
                                               0)
     self.assertTrue(manager.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                               0, visitor_options)

     # Indices must come back contiguous, starting at 0.
     visited = 0
     for index, item in visitor:
         # Progress output once every 1024 items.
         if index > 0 and index % 1024 == 0:
             print("{} {}".format(index, item.raw_id))
         self.assertEqual(index, visited)
         visited += 1
     self.assertEqual(visited, 4999)
 def test_raw_data_visitor(self):
     """Visit a GZIP-compressed raw-data file via the 'TF_DATASET' iterator.

     Registers ``0-0.idx`` from the partition-0 directory under
     ``./test/compressed_raw_data`` and checks that the visitor yields
     contiguous indices starting at 0, with at least one item.
     """
     source = common_pb.DataSource()
     source.data_source_meta.name = 'fclh_test'
     source.data_source_meta.partition_num = 1
     source.raw_data_dir = "./test/compressed_raw_data"
     self.data_source = source

     self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                        'fedlearner', True)
     # The prefix deleted here is the bare data source name.
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)

     partition_dir = os.path.join(self.data_source.raw_data_dir,
                                  common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))

     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.etcd, self.data_source)
     compressed_meta = dj_pb.RawDataMeta(
         file_path=os.path.join(partition_dir, "0-0.idx"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [compressed_meta], True)

     visitor_options = dj_pb.RawDataOptions(raw_data_iter='TF_DATASET',
                                            compressed_type='GZIP')
     manager = raw_data_visitor.RawDataManager(self.etcd, self.data_source,
                                               0)
     self.assertTrue(manager.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                               0, visitor_options)

     # Indices must come back contiguous, starting at 0.
     visited = 0
     for index, item in visitor:
         # Progress output once every 32 items.
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, visited)
         visited += 1
     self.assertGreater(visited, 0)
 def test_compressed_raw_data_visitor(self):
     """Visit a GZIP-compressed raw-data file via the 'TF_RECORD' iterator.

     Registers ``0-0.idx`` from the partition-0 directory under
     ``../compressed_raw_data`` (relative to this file) in a fresh kvstore
     namespace and checks that the visitor yields contiguous indices
     starting at 0, with at least one item.
     """
     source = common_pb.DataSource()
     source.data_source_meta.name = 'fclh_test'
     source.data_source_meta.partition_num = 1
     self.data_source = source

     this_dir = path.dirname(path.abspath(__file__))
     self.raw_data_dir = path.join(this_dir, "../compressed_raw_data")

     self.kvstore = DBClient('etcd', True)
     # Drop whatever is stored under this data source's base dir.
     kvstore_base_dir = common.data_source_kvstore_base_dir(
         self.data_source.data_source_meta.name)
     self.kvstore.delete_prefix(kvstore_base_dir)
     self.assertEqual(self.data_source.data_source_meta.partition_num, 1)

     partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
     self.assertTrue(gfile.Exists(partition_dir))

     manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
         self.kvstore, self.data_source)
     compressed_meta = dj_pb.RawDataMeta(
         file_path=path.join(partition_dir, "0-0.idx"),
         timestamp=timestamp_pb2.Timestamp(seconds=3))
     manifest_manager.add_raw_data(0, [compressed_meta], True)

     visitor_options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                            compressed_type='GZIP',
                                            read_ahead_size=1 << 20,
                                            read_batch_size=128)
     manager = raw_data_visitor.RawDataManager(self.kvstore,
                                               self.data_source, 0)
     self.assertTrue(manager.check_index_meta_by_process_index(0))
     visitor = raw_data_visitor.RawDataVisitor(self.kvstore,
                                               self.data_source, 0,
                                               visitor_options)

     # Indices must come back contiguous, starting at 0.
     visited = 0
     for index, item in visitor:
         # Progress output once every 32 items.
         if index > 0 and index % 32 == 0:
             print("{} {}".format(index, item.example_id))
         self.assertEqual(index, visited)
         visited += 1
     self.assertGreater(visited, 0)
示例#6
0
 def generate_leader_raw_data(self):
     """Write the leader's raw data for partition 0 as TF-record data blocks.

     One record is produced for every ``i`` in
     ``[0, self.leader_end_index + 2]``.  The current block is finished
     whenever ``i`` is a positive multiple of 2048 and once more at the last
     ``i``; each finished block that yields a meta is registered through
     ``self.manifest_manager.add_raw_data``.  Example ids are ``i // 3``
     when ``i`` is divisible by 3 and ``(i + 1) << 30`` otherwise.

     The record for the last ``i`` is appended to a builder that this
     method never finishes.  At the end, every plain file in the partition
     directory that lacks the data block suffix is removed.
     """
     # Neither `dbm` nor `rdm` is used below; construction is kept in case
     # it has side effects -- TODO confirm.
     dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
     partition_raw_dir = os.path.join(self.data_source_l.raw_data_dir,
                                      common.partition_repr(0))
     # Recreate the partition directory from scratch.
     if gfile.Exists(partition_raw_dir):
         gfile.DeleteRecursively(partition_raw_dir)
     gfile.MakeDirs(partition_raw_dir)
     rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)

     def open_builder(index):
         # Fresh builder for the data block numbered `index`.
         return create_data_block_builder(
             dj_pb.DataBlockBuilderOptions(
                 data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
             self.data_source_l.raw_data_dir,
             self.data_source_l.data_source_meta.name, 0, index, None)

     block_index = 0
     builder = open_builder(block_index)
     # Bookkeeping counters; nothing in this method reads them back.
     finished_blocks = 0
     written_examples = 0
     final_index = self.leader_end_index + 2
     for i in range(0, final_index + 1):
         if (i > 0 and i % 2048 == 0) or i == final_index:
             meta = builder.finish_data_block()
             if meta is not None:
                 block_fname = common.encode_data_block_fname(
                     self.data_source_l.data_source_meta.name, meta)
                 block_path = os.path.join(partition_raw_dir, block_fname)
                 raw_data_meta = dj_pb.RawDataMeta(
                     file_path=block_path,
                     timestamp=timestamp_pb2.Timestamp(seconds=3))
                 self.manifest_manager.add_raw_data(0, [raw_data_meta],
                                                    False)
                 finished_blocks += 1
                 written_examples += len(meta.example_ids)
             # A new builder is opened even when no meta came back.
             block_index += 1
             builder = open_builder(block_index)

         # `+` binds tighter than `<<`, so the shifted value is (i + 1).
         # NOTE(review): if i + (1 << 30) was intended, confirm and fix.
         pt = i // 3 if i % 3 == 0 else (i + 1) << 30
         example_id = '{}'.format(pt).encode()
         event_time = 150000000 + pt
         features = {
             'example_id': tf.train.Feature(
                 bytes_list=tf.train.BytesList(value=[example_id])),
             'event_time': tf.train.Feature(
                 int64_list=tf.train.Int64List(value=[event_time])),
         }
         example = tf.train.Example(
             features=tf.train.Features(feature=features))
         builder.append_record(example.SerializeToString(), example_id,
                               event_time, i, i)

     # Collect the plain files first, then delete the non-data-block ones.
     plain_files = []
     for entry in gfile.ListDirectory(partition_raw_dir):
         entry_path = os.path.join(partition_raw_dir, entry)
         if not gfile.IsDirectory(entry_path):
             plain_files.append(entry_path)
     for entry_path in plain_files:
         if not entry_path.endswith(common.DataBlockSuffix):
             gfile.Remove(entry_path)
示例#7
0
 def generate_raw_data(self, begin_index, item_count):
     """Write ``item_count // 2048`` locally shuffled blocks of raw data.

     Each block covers 2048 consecutive example indices starting at
     ``begin_index + block_index * 2048``.  Within a block, entries are
     swapped at random, but only when both swapped values stay within 32 of
     the index that belongs to the current position.  Every example gets an
     ``example_id`` and an ``event_time``; a ``label`` feature is attached
     with probability 0.8.  All finished blocks are registered in a single
     ``self.manifest_manager.add_raw_data`` call at the end.

     Args:
         begin_index: example index of the first item of the first block.
         item_count: number of items requested; the full amount is added to
             ``self.total_raw_data_count`` even though only whole blocks of
             2048 are written.
     """
     partition_raw_dir = os.path.join(self.raw_data_dir,
                                      common.partition_repr(0))
     if not gfile.Exists(partition_raw_dir):
         gfile.MakeDirs(partition_raw_dir)
     self.total_raw_data_count += item_count
     # Running counter passed to append_item as both index arguments.
     next_item_index = 0
     # Not used below; construction kept in case it has side effects --
     # TODO confirm.
     rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source,
                                           0)
     raw_data_metas = []
     for block_index in range(0, item_count // 2048):
         # NOTE(review): block_index restarts at 0 on every call while
         # self.g_data_block_index keeps counting -- confirm that repeated
         # calls cannot collide on block file names.
         builder = DataBlockBuilder(
             self.raw_data_dir,
             self.data_source.data_source_meta.name, 0, block_index,
             dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
         block_begin = begin_index + block_index * 2048
         cands = list(range(block_begin, block_begin + 2048))
         last_pos = len(cands) - 1

         # Bounded shuffle.  The order of the random calls below is kept
         # exactly as-is so a seeded run produces the same data.
         for pos in range(len(cands)):
             if random.randint(1, 4) <= 2:
                 left = min(max(random.randint(pos - 32, pos + 32), 0),
                            last_pos)
                 right = min(max(random.randint(pos - 32, pos + 32), 0),
                             last_pos)
                 # Example index that sits at `pos` before any shuffling.
                 home = pos + block_begin
                 if (abs(cands[left] - home) <= 32
                         and abs(cands[right] - home) <= 32):
                     cands[left], cands[right] = cands[right], cands[left]

         for example_idx in cands:
             example_id = '{}'.format(example_idx).encode()
             event_time = 150000000 + example_idx
             features = {
                 'example_id': tf.train.Feature(
                     bytes_list=tf.train.BytesList(value=[example_id])),
                 'event_time': tf.train.Feature(
                     int64_list=tf.train.Int64List(value=[event_time])),
             }
             # The label is always drawn, even when it ends up unused.
             label = random.choice([1, 0])
             if random.random() < 0.8:
                 features['label'] = tf.train.Feature(
                     int64_list=tf.train.Int64List(value=[label]))
             example = tf.train.Example(
                 features=tf.train.Features(feature=features))
             builder.append_item(TfExampleItem(example.SerializeToString()),
                                 next_item_index, next_item_index)
             next_item_index += 1

         meta = builder.finish_data_block()
         block_fname = common.encode_data_block_fname(
             self.data_source.data_source_meta.name, meta)
         raw_data_metas.append(
             dj_pb.RawDataMeta(
                 file_path=os.path.join(partition_raw_dir, block_fname),
                 timestamp=timestamp_pb2.Timestamp(seconds=3)))
         self.g_data_block_index += 1

     # Collect the plain files first, then delete the non-data-block ones.
     plain_files = []
     for entry in gfile.ListDirectory(partition_raw_dir):
         entry_path = os.path.join(partition_raw_dir, entry)
         if not gfile.IsDirectory(entry_path):
             plain_files.append(entry_path)
     for entry_path in plain_files:
         if not entry_path.endswith(common.DataBlockSuffix):
             gfile.Remove(entry_path)
     self.manifest_manager.add_raw_data(0, raw_data_metas, False)