def test_raw_data_visitor(self):
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self.etcd, self.data_source
        )
    manifest, finished = manifest_manager.alloc_unallocated_partition(2)
    self.assertFalse(finished)
    self.assertEqual(manifest.partition_id, 0)
    self.assertEqual(manifest.state, dj_pb.RawDataState.Syncing)
    self.assertEqual(manifest.allocated_rank_id, 2)
    options = customized_options.CustomizedOptions()
    options.set_raw_data_iter('TF_RECORD')
    rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                          0, options)
    # Iterate from the beginning; items must come back in index order.
    expected_index = 0
    for (index, item) in rdv:
        self.assertEqual(index, expected_index)
        expected_index += 1
        self.assertEqual(item.example_id, '{}'.format(index).encode())
    # Seeking past the last record (index 199) must raise StopIteration
    # and leave the visitor at the last valid index.
    try:
        rdv.seek(200)
    except StopIteration:
        self.assertEqual(rdv.get_current_index(), 199)
    else:
        self.assertTrue(False)
    # Seek back into the middle and resume iteration from there.
    index, item = rdv.seek(50)
    self.assertEqual(index, 50)
    self.assertEqual(item.example_id, '{}'.format(index).encode())
    expected_index = index + 1
    for (index, item) in rdv:
        self.assertEqual(index, expected_index)
        expected_index += 1
        self.assertEqual(item.example_id, '{}'.format(index).encode())
def test_raw_data_visitor(self):
    self.data_source = common_pb.DataSource()
    self.data_source.data_source_meta.name = 'fclh_test'
    self.data_source.data_source_meta.partition_num = 1
    self.data_source.raw_data_dir = "./test/compressed_raw_data"
    self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                       'fedlearner')
    self.etcd.delete_prefix(self.data_source.data_source_meta.name)
    self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
    partition_dir = os.path.join(self.data_source.raw_data_dir, 'partition_0')
    self.assertTrue(gfile.Exists(partition_dir))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self.etcd, self.data_source
        )
    # Read the gzip-compressed raw data through the TF_DATASET iterator.
    options = customized_options.CustomizedOptions()
    options.set_raw_data_iter('TF_DATASET')
    options.set_compressed_type('GZIP')
    rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source,
                                          0, options)
    expected_index = 0
    for (index, item) in rdv:
        if index > 0 and index % 1024 == 0:
            print("{} {} {}".format(index, item.example_id, item.event_time))
        self.assertEqual(index, expected_index)
        expected_index += 1
    self.assertGreater(expected_index, 0)
def test_data_block_dumper(self):
    self.generate_follower_data_block()
    self.generate_leader_raw_data()
    options = customized_options.CustomizedOptions()
    options.set_raw_data_iter('TF_RECORD')
    dbd = data_block_dumper.DataBlockDumperManager(
            self.etcd, self.data_source_l, 0, options
        )
    self.assertEqual(dbd.get_partition_id(), 0)
    self.assertEqual(dbd.get_next_data_block_index(), 0)
    # Replay the follower's dumped metas into the leader-side dumper.
    for (idx, meta) in enumerate(self.dumped_metas):
        success, next_index = dbd.append_synced_data_block_meta(meta)
        self.assertTrue(success)
        self.assertEqual(next_index, idx + 1)
    self.assertTrue(dbd.need_dump())
    self.assertEqual(dbd.get_next_data_block_index(), len(self.dumped_metas))
    dbd.dump_data_blocks()
    dbm_f = data_block_manager.DataBlockManager(self.data_source_f, 0)
    dbm_l = data_block_manager.DataBlockManager(self.data_source_l, 0)
    self.assertEqual(dbm_f.get_dumped_data_block_num(), len(self.dumped_metas))
    self.assertEqual(dbm_f.get_dumped_data_block_num(),
                     dbm_l.get_dumped_data_block_num())
    for (idx, meta) in enumerate(self.dumped_metas):
        self.assertEqual(meta.data_block_index, idx)
        self.assertEqual(dbm_l.get_data_block_meta_by_index(idx)[0], meta)
        self.assertEqual(dbm_f.get_data_block_meta_by_index(idx)[0], meta)
        block_id = meta.block_id
        # The dumped meta files on both sides must round-trip to the
        # original meta.
        meta_fpth_l = os.path.join(self.data_source_l.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockMetaSuffix)
        mitr = tf.io.tf_record_iterator(meta_fpth_l)
        meta_l = dj_pb.DataBlockMeta()
        meta_l.ParseFromString(next(mitr))
        self.assertEqual(meta_l, meta)
        meta_fpth_f = os.path.join(self.data_source_f.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockMetaSuffix)
        mitr = tf.io.tf_record_iterator(meta_fpth_f)
        meta_f = dj_pb.DataBlockMeta()
        meta_f.ParseFromString(next(mitr))
        self.assertEqual(meta_f, meta)
        # Every record in a dumped data block must match the example ids
        # recorded in its meta, on both the leader and follower side.
        data_fpth_l = os.path.join(self.data_source_l.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockSuffix)
        for (iidx, record) in enumerate(tf.io.tf_record_iterator(data_fpth_l)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        self.assertEqual(len(meta.example_ids), iidx + 1)
        data_fpth_f = os.path.join(self.data_source_f.data_block_dir,
                                   'partition_0',
                                   block_id + common.DataBlockSuffix)
        for (iidx, record) in enumerate(tf.io.tf_record_iterator(data_fpth_f)):
            example = tf.train.Example()
            example.ParseFromString(record)
            feat = example.features.feature
            self.assertEqual(feat['example_id'].bytes_list.value[0],
                             meta.example_ids[iidx])
        self.assertEqual(len(meta.example_ids), iidx + 1)
def test_all_assembly(self):
    for i in range(self.data_source_l.data_source_meta.partition_num):
        self.generate_raw_data(
                self.data_source_l, i, 2048, 64,
                'leader_key_partition_{}'.format(i) + ':{}',
                'leader_value_partition_{}'.format(i) + ':{}'
            )
        self.generate_raw_data(
                self.data_source_f, i, 4096, 128,
                'follower_key_partition_{}'.format(i) + ':{}',
                'follower_value_partition_{}'.format(i) + ':{}'
            )
    worker_addr_l = 'localhost:4161'
    worker_addr_f = 'localhost:4162'
    options = customized_options.CustomizedOptions()
    options.set_raw_data_iter('TF_RECORD')
    options.set_example_joiner('STREAM_JOINER')
    worker_l = data_join_worker.DataJoinWorkerService(
            int(worker_addr_l.split(':')[1]), worker_addr_f,
            self.master_addr_l, 0, self.etcd_name,
            self.etcd_base_dir_l, self.etcd_addrs, options
        )
    worker_f = data_join_worker.DataJoinWorkerService(
            int(worker_addr_f.split(':')[1]), worker_addr_l,
            self.master_addr_f, 0, self.etcd_name,
            self.etcd_base_dir_f, self.etcd_addrs, options
        )
    worker_l.start()
    worker_f.start()
    # Poll both masters until the whole data source has been joined.
    while True:
        rsp_l = self.master_client_l.GetDataSourceState(
                self.data_source_l.data_source_meta
            )
        rsp_f = self.master_client_f.GetDataSourceState(
                self.data_source_f.data_source_meta
            )
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Finished and
                rsp_f.state == common_pb.DataSourceState.Finished):
            break
        time.sleep(2)
    worker_l.stop()
    worker_f.stop()
    self.master_l.stop()
    self.master_f.stop()
def test_example_join(self):
    self.generate_raw_data()
    self.generate_example_id()
    options = customized_options.CustomizedOptions()
    options.set_example_joiner('STREAM_JOINER')
    sei = joiner_impl.create_example_joiner(
            options, self.etcd, self.data_source, 0
        )
    sei.join_example()
    self.assertTrue(sei.join_finished())
    # Count the joined examples across all dumped data blocks.
    dbm = data_block_manager.DataBlockManager(self.data_source, 0)
    data_block_num = dbm.get_dumped_data_block_num()
    join_count = 0
    for data_block_index in range(data_block_num):
        meta = dbm.get_data_block_meta_by_index(data_block_index)[0]
        self.assertIsNotNone(meta)
        join_count += len(meta.example_ids)
    print("join rate {}/{}({}), min_matching_window {}, "
          "max_matching_window {}".format(
              join_count, self.total_index,
              (join_count + .0) / self.total_index,
              self.data_source.data_source_meta.min_matching_window,
              self.data_source.data_source_meta.max_matching_window))
parser.add_argument('--raw_data_iter', type=str, default='TF_RECORD',
                    help='the type for raw data file')
parser.add_argument('--example_joiner', type=str, default='STREAM_JOINER',
                    help='the method for example joiner')
parser.add_argument('--compressed_type', type=str,
                    choices=['', 'ZLIB', 'GZIP'],
                    help='the compressed type for raw data')
parser.add_argument('--tf_eager_mode', action='store_true',
                    help='use the eager mode for tf')
args = parser.parse_args()
cst_options = customized_options.CustomizedOptions()
if args.raw_data_iter is not None:
    cst_options.set_raw_data_iter(args.raw_data_iter)
if args.example_joiner is not None:
    cst_options.set_example_joiner(args.example_joiner)
if args.compressed_type is not None:
    cst_options.set_compressed_type(args.compressed_type)
if args.tf_eager_mode:
    # Only import tensorflow when eager mode is requested.
    import tensorflow
    tensorflow.compat.v1.enable_eager_execution()
worker_srv = DataJoinWorkerService(args.listen_port, args.peer_addr,
                                   args.master_addr, args.rank_id,
                                   args.etcd_name, args.etcd_base_dir,
                                   args.etcd_addrs, cst_options)
worker_srv.run()