def test_raw_data_manager(self):
    """Check that RawDataManager moves raw-data reps from unindexed to indexed.

    Two reps (``raw_data_0`` and ``raw_data_1``) must start out with the
    ``unindexed`` field set.  After ``index_raw_data_rep(pos, start)`` each
    one must appear at position ``pos`` of ``get_indexed_raw_data_reps()``
    with the same path and an ``index`` whose ``start_index`` is ``start``.
    """
    # Only constructed, never used below.
    # NOTE(review): presumably its constructor prepares state that the
    # RawDataManager reads — confirm before removing it.
    _manifest = raw_data_manifest_manager.RawDataManifestManager(
        self.etcd, self.data_source)
    manager = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
    self.assertEqual(len(manager.get_indexed_raw_data_reps()), 0)

    reps = [manager.get_raw_data_rep_by_index(pos) for pos in range(2)]
    for rep in reps:
        self.assertTrue(rep.HasField('unindexed'))
    for rep, expected_name in zip(reps, ('raw_data_0', 'raw_data_1')):
        self.assertEqual(ntpath.basename(rep.raw_data_path), expected_name)

    # Index the reps one at a time; the indexed view must grow by one each time.
    for pos, (rep, start) in enumerate(zip(reps, (0, 100))):
        manager.index_raw_data_rep(pos, start)
        self.assertEqual(len(manager.get_indexed_raw_data_reps()), pos + 1)
        indexed = manager.get_indexed_raw_data_reps()[pos]
        self.assertEqual(indexed.raw_data_path, rep.raw_data_path)
        self.assertTrue(indexed.HasField('index'))
        self.assertEqual(indexed.index.start_index, start)
def test_raw_data_manager(self):
    """Check RawDataManager's lazy creation of index metas.

    The test expects that:

    * ``check_index_meta_by_process_index(i)`` turns true only once the raw
      data file for process index ``i`` has been generated;
    * ``get_index_metas()`` stays empty until ``get_index_meta_by_index`` is
      called, and then grows by one meta per call;
    * each returned meta echoes the requested process and start index.
    """
    rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)

    def assert_index_meta(process_index, start_index, expected_count):
        # Fetch (creating if needed) the meta and check it plus the total count.
        meta = rdm.get_index_meta_by_index(process_index, start_index)
        self.assertEqual(meta.start_index, start_index)
        self.assertEqual(meta.process_index, process_index)
        self.assertEqual(len(rdm.get_index_metas()), expected_count)

    # Before any raw data file exists there is nothing to index.
    self.assertEqual(len(rdm.get_index_metas()), 0)
    self.assertFalse(rdm.check_index_meta_by_process_index(0))

    # Presumably generates the files for process indices 0 and 1. They become
    # checkable, but no meta is materialised just by checking.
    self._gen_raw_data_file(0, 2)
    self.assertEqual(len(rdm.get_index_metas()), 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    self.assertTrue(rdm.check_index_meta_by_process_index(1))
    self.assertEqual(len(rdm.get_index_metas()), 0)

    assert_index_meta(0, 0, 1)
    assert_index_meta(1, 100, 2)

    # A later batch of files extends the set of checkable process indices.
    self.assertFalse(rdm.check_index_meta_by_process_index(2))
    self._gen_raw_data_file(2, 4)
    self.assertTrue(rdm.check_index_meta_by_process_index(2))
    self.assertTrue(rdm.check_index_meta_by_process_index(3))

    assert_index_meta(2, 200, 3)
    assert_index_meta(3, 300, 4)
def test_csv_raw_data_visitor(self):
    """Iterate a one-partition CSV raw-data source and check the visit order.

    Registers ``test_raw_data.csv`` from partition 0 of the ``../csv_raw_data``
    directory (resolved relative to this file) in the raw-data manifest. It
    then walks the file with a ``RawDataVisitor`` using the ``CSV_DICT``
    iterator. Indices must be consecutive from 0, and exactly 4999 items must
    be visited.
    """
    ds = common_pb.DataSource()
    self.data_source = ds
    ds.data_source_meta.name = 'fclh_test'
    ds.data_source_meta.partition_num = 1
    self.raw_data_dir = path.join(path.dirname(path.abspath(__file__)),
                                  "../csv_raw_data")
    self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                       'fedlearner', True)
    # Drop every key under this data source's etcd base directory.
    self.etcd.delete_prefix(
        common.data_source_etcd_base_dir(ds.data_source_meta.name))
    self.assertEqual(ds.data_source_meta.partition_num, 1)
    partition0 = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition0))

    csv_meta = dj_pb.RawDataMeta(
        file_path=path.join(partition0, "test_raw_data.csv"),
        timestamp=timestamp_pb2.Timestamp(seconds=3))
    manifest = raw_data_manifest_manager.RawDataManifestManager(self.etcd, ds)
    manifest.add_raw_data(0, [csv_meta], True)

    options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                   read_ahead_size=1 << 20)
    manager = raw_data_visitor.RawDataManager(self.etcd, ds, 0)
    self.assertTrue(manager.check_index_meta_by_process_index(0))

    visitor = raw_data_visitor.RawDataVisitor(self.etcd, ds, 0, options)
    visited = 0
    for visited, (idx, item) in enumerate(visitor, start=1):
        if idx > 0 and idx % 1024 == 0:
            print("{} {}".format(idx, item.raw_id))
        self.assertEqual(idx, visited - 1)
    self.assertEqual(visited, 4999)
def test_raw_data_visitor(self):
    """Iterate a gzip-compressed raw-data file via TF_DATASET and check indices.

    Registers ``0-0.idx`` from partition 0 of ``./test/compressed_raw_data``
    in the raw-data manifest. It then walks the file with a
    ``RawDataVisitor`` (``TF_DATASET`` iterator, ``GZIP`` compression).
    Indices must be consecutive from 0, and at least one item must be visited.
    """
    ds = common_pb.DataSource()
    self.data_source = ds
    ds.data_source_meta.name = 'fclh_test'
    ds.data_source_meta.partition_num = 1
    # NOTE(review): relative to the current working directory, so the test
    # presumably has to be launched from the repository root — confirm.
    ds.raw_data_dir = "./test/compressed_raw_data"
    self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                       'fedlearner', True)
    # NOTE(review): deletes keys under the bare data-source name; confirm this
    # is the prefix this library version stores data-source state under.
    self.etcd.delete_prefix(ds.data_source_meta.name)
    self.assertEqual(ds.data_source_meta.partition_num, 1)
    partition0 = os.path.join(ds.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition0))

    idx_meta = dj_pb.RawDataMeta(
        file_path=os.path.join(partition0, "0-0.idx"),
        timestamp=timestamp_pb2.Timestamp(seconds=3))
    manifest = raw_data_manifest_manager.RawDataManifestManager(self.etcd, ds)
    manifest.add_raw_data(0, [idx_meta], True)

    options = dj_pb.RawDataOptions(raw_data_iter='TF_DATASET',
                                   compressed_type='GZIP')
    manager = raw_data_visitor.RawDataManager(self.etcd, ds, 0)
    self.assertTrue(manager.check_index_meta_by_process_index(0))

    visitor = raw_data_visitor.RawDataVisitor(self.etcd, ds, 0, options)
    visited = 0
    for visited, (idx, item) in enumerate(visitor, start=1):
        if idx > 0 and idx % 32 == 0:
            print("{} {}".format(idx, item.example_id))
        self.assertEqual(idx, visited - 1)
    self.assertGreater(visited, 0)
def test_compressed_raw_data_visitor(self):
    """Iterate a gzip-compressed TF_RECORD raw-data file and check indices.

    Registers ``0-0.idx`` from partition 0 of ``../compressed_raw_data``
    (resolved relative to this file) in the raw-data manifest. It then walks
    the file with a ``RawDataVisitor`` (``TF_RECORD`` iterator, ``GZIP``
    compression, 1 MiB read-ahead, batches of 128). Indices must be
    consecutive from 0, and at least one item must be visited.
    """
    ds = common_pb.DataSource()
    self.data_source = ds
    ds.data_source_meta.name = 'fclh_test'
    ds.data_source_meta.partition_num = 1
    here = path.dirname(path.abspath(__file__))
    self.raw_data_dir = path.join(here, "../compressed_raw_data")
    self.kvstore = DBClient('etcd', True)
    # Drop every key under this data source's kvstore base directory.
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(ds.data_source_meta.name))
    self.assertEqual(ds.data_source_meta.partition_num, 1)
    partition0 = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition0))

    idx_meta = dj_pb.RawDataMeta(
        file_path=path.join(partition0, "0-0.idx"),
        timestamp=timestamp_pb2.Timestamp(seconds=3))
    manifest = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, ds)
    manifest.add_raw_data(0, [idx_meta], True)

    options = dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                   compressed_type='GZIP',
                                   read_ahead_size=1 << 20,
                                   read_batch_size=128)
    manager = raw_data_visitor.RawDataManager(self.kvstore, ds, 0)
    self.assertTrue(manager.check_index_meta_by_process_index(0))

    visitor = raw_data_visitor.RawDataVisitor(self.kvstore, ds, 0, options)
    visited = 0
    for visited, (idx, item) in enumerate(visitor, start=1):
        if idx > 0 and idx % 32 == 0:
            print("{} {}".format(idx, item.example_id))
        self.assertEqual(idx, visited - 1)
    self.assertGreater(visited, 0)
def generate_leader_raw_data(self):
    """Rebuild partition 0 of the leader's raw data as TF_RECORD data blocks.

    Deletes and recreates the partition-0 directory under
    ``self.data_source_l.raw_data_dir``. Then generates records for
    ``i = 0 .. self.leader_end_index + 2``:

    * A block is flushed every 2048 records and once more when
      ``i == self.leader_end_index + 2``.
    * Every flushed block that yields metadata is registered with
      ``self.manifest_manager`` for partition 0.
    * Records with ``i % 3 == 0`` get example id ``i // 3``; all others get
      ``(i + 1) << 30``. Event time is ``150000000 + id``.

    Finally, regular files in the partition directory whose names do not end
    with ``common.DataBlockSuffix`` are removed.
    """
    # NOTE(review): neither manager is used below; both are kept in case
    # their constructors have side effects this test relies on.
    _dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
    raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                common.partition_repr(0))
    if gfile.Exists(raw_data_dir):
        gfile.DeleteRecursively(raw_data_dir)
    gfile.MakeDirs(raw_data_dir)
    _rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)

    ds_name = self.data_source_l.data_source_meta.name

    def new_builder(index):
        # Fresh TF_RECORD data-block builder for partition 0, block ``index``.
        return create_data_block_builder(
            dj_pb.DataBlockBuilderOptions(
                data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
            self.data_source_l.raw_data_dir, ds_name, 0, index, None)

    last = self.leader_end_index + 2
    block_index = 0
    builder = new_builder(block_index)
    for i in range(last + 1):
        # The flush check runs before record ``i`` is appended, so record
        # ``last`` lands in a builder that is never finished.
        # NOTE(review): that final builder is also never closed — confirm
        # whether the builder API needs an explicit close here.
        if (i > 0 and i % 2048 == 0) or i == last:
            meta = builder.finish_data_block()
            if meta is not None:
                fpath = os.path.join(
                    raw_data_dir,
                    common.encode_data_block_fname(ds_name, meta))
                self.manifest_manager.add_raw_data(0, [
                    dj_pb.RawDataMeta(
                        file_path=fpath,
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ], False)
            block_index += 1
            builder = new_builder(block_index)

        # Every third record carries a small id; the others get a large id
        # obtained by shifting (i + 1) left by 30 bits.
        pt = i // 3 if i % 3 == 0 else (i + 1) << 30
        example_id = '{}'.format(pt).encode()
        event_time = 150000000 + pt
        feat = {
            'example_id': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[example_id])),
            'event_time': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[event_time])),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feat))
        builder.append_record(example.SerializeToString(), example_id,
                              event_time, i, i)

    # Keep only finished data blocks. Presumably this also removes whatever
    # the last, unfinished builder wrote.
    regular_files = [
        os.path.join(raw_data_dir, name)
        for name in gfile.ListDirectory(raw_data_dir)
        if not gfile.IsDirectory(os.path.join(raw_data_dir, name))
    ]
    for stale in regular_files:
        if not stale.endswith(common.DataBlockSuffix):
            gfile.Remove(stale)
def generate_raw_data(self, begin_index, item_count):
    """Write locally shuffled TF_RECORD blocks to partition 0 and register them.

    For each of the ``item_count // 2048`` blocks:

    * Example ids ``begin_index + k * 2048`` through
      ``begin_index + (k + 1) * 2048 - 1`` are permuted by random swaps.
      Swap candidates are positions within +/-32 of the current slot.
    * The ids are written as ``tf.train.Example`` records.
    * Each record gets ``event_time = 150000000 + id`` and, when
      ``random.random() < 0.8``, a ``label`` drawn from ``[1, 0]``.

    Afterwards, regular files not ending in ``common.DataBlockSuffix`` are
    removed from the partition directory. All new blocks are then added to
    ``self.manifest_manager`` in a single call.
    """
    block_size = 2048
    window = 32
    part_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
    if not gfile.Exists(part_dir):
        gfile.MakeDirs(part_dir)
    # NOTE(review): the full item_count is counted, but only whole blocks of
    # 2048 are written below; callers presumably pass multiples of 2048.
    self.total_raw_data_count += item_count
    dummy_index = 0  # passed as both trailing append_item arguments
    # Constructed but unused; kept in case the constructor has side effects.
    _manager = raw_data_visitor.RawDataManager(self.kvstore, self.data_source,
                                               0)
    ds_name = self.data_source.data_source_meta.name
    new_metas = []
    for blk in range(item_count // block_size):
        builder = DataBlockBuilder(
            self.raw_data_dir, ds_name, 0, blk,
            dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
        base = begin_index + blk * block_size
        order = list(range(base, base + block_size))
        top = len(order) - 1
        for pos in range(len(order)):
            # Half of the positions (randint yields 3 or 4) skip swapping.
            if random.randint(1, 4) > 2:
                continue
            x = min(max(random.randint(pos - window, pos + window), 0), top)
            y = min(max(random.randint(pos - window, pos + window), 0), top)
            near_x = abs(order[x] - pos - base) <= window
            near_y = abs(order[y] - pos - base) <= window
            if near_x and near_y:
                order[x], order[y] = order[y], order[x]

        for eid in order:
            eid_bytes = '{}'.format(eid).encode()
            ts = 150000000 + eid
            features = {
                'example_id': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[eid_bytes])),
                'event_time': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[ts])),
            }
            # The label is drawn for every record, even when it ends up omitted.
            lbl = random.choice([1, 0])
            if random.random() < 0.8:
                features['label'] = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[lbl]))
            record = tf.train.Example(
                features=tf.train.Features(feature=features))
            builder.append_item(TfExampleItem(record.SerializeToString()),
                                dummy_index, dummy_index)
            dummy_index += 1

        meta = builder.finish_data_block()
        new_metas.append(dj_pb.RawDataMeta(
            file_path=os.path.join(
                part_dir, common.encode_data_block_fname(ds_name, meta)),
            timestamp=timestamp_pb2.Timestamp(seconds=3)))
        self.g_data_block_index += 1

    regular_files = [
        full for full in (os.path.join(part_dir, name)
                          for name in gfile.ListDirectory(part_dir))
        if not gfile.IsDirectory(full)
    ]
    for stale in regular_files:
        if not stale.endswith(common.DataBlockSuffix):
            gfile.Remove(stale)
    self.manifest_manager.add_raw_data(0, new_metas, False)