def test_csv_raw_data_visitor(self):
    """Walk a CSV_DICT raw-data partition end to end and verify that the
    visitor yields items with strictly sequential indexes (4999 rows)."""
    data_source = common_pb.DataSource()
    data_source.data_source_meta.name = 'fclh_test'
    data_source.data_source_meta.partition_num = 1
    self.data_source = data_source
    self.raw_data_dir = path.join(path.dirname(path.abspath(__file__)),
                                  "../csv_raw_data")
    self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                       'fedlearner', True)
    # start from a clean etcd namespace for this data source
    self.etcd.delete_prefix(
        common.data_source_etcd_base_dir(data_source.data_source_meta.name))
    self.assertEqual(data_source.data_source_meta.partition_num, 1)
    partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.etcd, data_source)
    manifest_manager.add_raw_data(
        0,
        [dj_pb.RawDataMeta(
            file_path=path.join(partition_dir, "test_raw_data.csv"),
            timestamp=timestamp_pb2.Timestamp(seconds=3))],
        True)
    raw_data_options = dj_pb.RawDataOptions(raw_data_iter='CSV_DICT',
                                            read_ahead_size=1 << 20)
    rdm = raw_data_visitor.RawDataManager(self.etcd, data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    rdv = raw_data_visitor.RawDataVisitor(self.etcd, data_source, 0,
                                          raw_data_options)
    expected_index = 0
    for index, item in rdv:
        if index > 0 and index % 1024 == 0:
            print("{} {}".format(index, item.raw_id))
        self.assertEqual(index, expected_index)
        expected_index += 1
    self.assertEqual(expected_index, 4999)
def test_compressed_raw_data_visitor(self):
    """Walk a GZIP-compressed TF_RECORD raw-data partition and verify the
    visitor yields items with strictly sequential indexes."""
    data_source = common_pb.DataSource()
    data_source.data_source_meta.name = 'fclh_test'
    data_source.data_source_meta.partition_num = 1
    self.data_source = data_source
    self.raw_data_dir = path.join(path.dirname(path.abspath(__file__)),
                                  "../compressed_raw_data")
    self.kvstore = DBClient('etcd', True)
    # start from a clean kvstore namespace for this data source
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(
            data_source.data_source_meta.name))
    self.assertEqual(data_source.data_source_meta.partition_num, 1)
    partition_dir = path.join(self.raw_data_dir, common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, data_source)
    manifest_manager.add_raw_data(
        0,
        [dj_pb.RawDataMeta(
            file_path=path.join(partition_dir, "0-0.idx"),
            timestamp=timestamp_pb2.Timestamp(seconds=3))],
        True)
    raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_RECORD',
        compressed_type='GZIP',
        read_ahead_size=1 << 20,
        read_batch_size=128)
    rdm = raw_data_visitor.RawDataManager(self.kvstore, data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    rdv = raw_data_visitor.RawDataVisitor(self.kvstore, data_source, 0,
                                          raw_data_options)
    expected_index = 0
    for index, item in rdv:
        if index > 0 and index % 32 == 0:
            print("{} {}".format(index, item.example_id))
        self.assertEqual(index, expected_index)
        expected_index += 1
    self.assertGreater(expected_index, 0)
def test_raw_data_visitor(self):
    """Walk a GZIP-compressed TF_DATASET raw-data partition and verify the
    visitor yields items with strictly sequential indexes."""
    self.data_source = common_pb.DataSource()
    self.data_source.data_source_meta.name = 'fclh_test'
    self.data_source.data_source_meta.partition_num = 1
    self.data_source.raw_data_dir = "./test/compressed_raw_data"
    self.etcd = etcd_client.EtcdClient('test_cluster', 'localhost:2379',
                                       'fedlearner', True)
    # Clean the data source's etcd namespace. Use the shared helper so the
    # deleted prefix matches the keys RawDataManifestManager actually
    # writes (passing the bare name missed the real key prefix and left
    # stale state behind; the sibling tests already use the helper).
    self.etcd.delete_prefix(
        common.data_source_etcd_base_dir(
            self.data_source.data_source_meta.name))
    self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
    partition_dir = os.path.join(self.data_source.raw_data_dir,
                                 common.partition_repr(0))
    self.assertTrue(gfile.Exists(partition_dir))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.etcd, self.data_source)
    manifest_manager.add_raw_data(
        0,
        [dj_pb.RawDataMeta(
            file_path=os.path.join(partition_dir, "0-0.idx"),
            timestamp=timestamp_pb2.Timestamp(seconds=3))],
        True)
    raw_data_options = dj_pb.RawDataOptions(
        raw_data_iter='TF_DATASET',
        compressed_type='GZIP'
    )
    rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source, 0)
    self.assertTrue(rdm.check_index_meta_by_process_index(0))
    rdv = raw_data_visitor.RawDataVisitor(self.etcd, self.data_source, 0,
                                          raw_data_options)
    expected_index = 0
    for (index, item) in rdv:
        if index > 0 and index % 32 == 0:
            print("{} {}".format(index, item.example_id))
        self.assertEqual(index, expected_index)
        expected_index += 1
    self.assertGreater(expected_index, 0)
def add_raw_data(self, partition_id, fpaths, dedup, timestamps=None):
    """Register raw-data files for one partition with the data-join master.

    Args:
        partition_id: partition the files belong to.
        fpaths: non-empty list of raw-data file paths; every path must exist.
        dedup: whether the master should skip already-registered paths.
        timestamps: optional list of Timestamp messages, parallel to fpaths.

    Returns:
        The master's AddRawData response.

    Raises:
        RuntimeError: fpaths is empty, or timestamps length mismatches.
        ValueError: a file path does not exist.
    """
    self._check_partition_id(partition_id)
    if not fpaths:
        raise RuntimeError("no files input")
    if timestamps is not None and len(fpaths) != len(timestamps):
        raise RuntimeError("the number of raw data file "
                           "and timestamp mismatch")
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        partition_id=partition_id,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=dedup
        )
    )
    for index, fpath in enumerate(fpaths):
        if not gfile.Exists(fpath):
            # Bug fix: was "'{} is not existed' % format(fpath)", which
            # raises TypeError ("not all arguments converted") instead of
            # the intended ValueError. Use str.format as the sibling
            # publish_raw_data implementations do.
            raise ValueError('{} is not existed'.format(fpath))
        raw_data_meta = dj_pb.RawDataMeta(
            file_path=fpath,
            start_index=-1  # -1: start index not yet assigned
        )
        if timestamps is not None:
            raw_data_meta.timestamp.MergeFrom(timestamps[index])
        rdreq.added_raw_data_metas.raw_data_metas.append(raw_data_meta)
    return self._master_client.AddRawData(rdreq)
def _preload_raw_data_meta(self):
    """Load all RawDataMeta records of this partition from the kvstore.

    Returns:
        (all_metas, index_metas):
          all_metas -- list of (process_index, RawDataMeta), sorted and
          verified to be a dense 0..n-1 sequence (the process aborts via
          os._exit on a gap, since a hole makes the visitor state invalid);
          index_metas -- IndexMeta entries for metas whose start_index is
          already assigned (!= -1).
    """
    manifest_kvstore_key = common.partition_manifest_kvstore_key(
        self._data_source.data_source_meta.name,
        self._partition_id)
    all_metas = []
    index_metas = []
    for key, val in self._kvstore.get_prefix_kvs(manifest_kvstore_key, True):
        # Decode once up front instead of mixing bytes (slicing/int parsing)
        # with str (startswith on a decoded copy) as before; behavior is
        # identical but the key is handled as text consistently.
        bkey = os.path.basename(key).decode()
        if not bkey.startswith(common.RawDataMetaPrefix):
            continue
        index = int(bkey[len(common.RawDataMetaPrefix):])
        meta = text_format.Parse(val, dj_pb.RawDataMeta())
        all_metas.append((index, meta))
        if meta.start_index != -1:
            index_meta = visitor.IndexMeta(index, meta.start_index,
                                           meta.file_path)
            index_metas.append(index_meta)
    all_metas = sorted(all_metas, key=lambda meta: meta[0])
    for process_index, meta in enumerate(all_metas):
        if process_index != meta[0]:
            logging.fatal("process_index mismatch with index %d != %d "
                          "for file path %s", process_index, meta[0],
                          meta[1].file_path)
            traceback.print_stack()
            os._exit(-1)  # pylint: disable=protected-access
    return all_metas, index_metas
def _gen_raw_data_file(self, start_index, end_index, no_data=False):
    """Write TFRecord raw-data files [start_index, end_index) into
    partition 0 and register them with the manifest manager.

    When no_data is True the files are written empty and carry a
    '.no_data' suffix instead of the standard raw-data suffix.
    """
    partition_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
    fpaths = []
    for findex in range(start_index, end_index):
        if no_data:
            fname = "{}.no_data".format(findex)
        else:
            fname = "{}{}".format(findex, common.RawDataFileSuffix)
        fpath = os.path.join(partition_dir, fname)
        fpaths.append(dj_pb.RawDataMeta(
            file_path=fpath,
            timestamp=timestamp_pb2.Timestamp(seconds=3)))
        writer = tf.io.TFRecordWriter(fpath)
        if not no_data:
            # 100 sequential example ids per file
            for offset in range(100):
                example_id = '{}'.format(findex * 100 + offset).encode()
                feat = {
                    'example_id': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[example_id])),
                }
                example = tf.train.Example(
                    features=tf.train.Features(feature=feat))
                writer.write(example.SerializeToString())
        writer.close()
    self.manifest_manager.add_raw_data(0, fpaths, True)
def publish_raw_data(self, partition_id, fpaths, timestamps=None):
    # Publish raw-data file paths for one partition into etcd so downstream
    # workers can discover them. Each file becomes a RawDatePub message
    # written at the next free pub index via compare-and-swap, so concurrent
    # publishers never overwrite each other's entries.
    #
    # Args:
    #   partition_id: partition the files belong to.
    #   fpaths: raw-data file paths to publish; every path must exist.
    #   timestamps: optional list of Timestamp messages, parallel to fpaths.
    if not fpaths:
        logging.warning("no raw data will be published")
        return
    if timestamps is not None and len(fpaths) != len(timestamps):
        raise RuntimeError("the number of raw data file "\
                           "and timestamp mismatch")
    new_raw_data_pubs = []
    for index, fpath in enumerate(fpaths):
        if not gfile.Exists(fpath):
            raise ValueError('{} is not existed'.format(fpath))
        # start_index == -1: index assignment is deferred to the consumer
        raw_data_pub = dj_pb.RawDatePub(raw_data_meta=dj_pb.RawDataMeta(
            file_path=fpath, start_index=-1))
        if timestamps is not None:
            raw_data_pub.raw_data_meta.timestamp.MergeFrom(
                timestamps[index])
        new_raw_data_pubs.append(raw_data_pub)
    next_pub_index = None
    item_index = 0
    data = text_format.MessageToString(new_raw_data_pubs[item_index])
    while item_index < len(new_raw_data_pubs):
        # probe forward for the next unused pub index; the CAS below only
        # succeeds when the key is still absent (expected old value None)
        next_pub_index = self._forward_pub_index(partition_id,
                                                 next_pub_index)
        etcd_key = common.raw_data_pub_etcd_key(self._raw_data_pub_dir,
                                                partition_id,
                                                next_pub_index)
        if self._etcd.cas(etcd_key, None, data):
            logging.info("Success publish %s at index %d for partition"\
                         "%d", data, next_pub_index, partition_id)
            next_pub_index += 1
            item_index += 1
            if item_index < len(new_raw_data_pubs):
                # pre-serialize the next item for the following CAS attempt
                raw_data_pub = new_raw_data_pubs[item_index]
                data = text_format.MessageToString(raw_data_pub)
def generate_leader_raw_data(self):
    # Generate the leader side's raw data as TF_RECORD data blocks: records
    # are appended one by one, a block is sealed every 2048 records (and once
    # more at the tail), and each sealed block file is registered with the
    # manifest manager. Non-data-block leftovers are removed at the end.
    dbm = data_block_manager.DataBlockManager(self.data_source_l, 0)
    raw_data_dir = os.path.join(self.data_source_l.raw_data_dir,
                                common.partition_repr(0))
    if gfile.Exists(raw_data_dir):
        gfile.DeleteRecursively(raw_data_dir)
    gfile.MakeDirs(raw_data_dir)
    rdm = raw_data_visitor.RawDataManager(self.etcd, self.data_source_l, 0)
    block_index = 0
    builder = create_data_block_builder(
        dj_pb.DataBlockBuilderOptions(
            data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
        self.data_source_l.raw_data_dir,
        self.data_source_l.data_source_meta.name,
        0, block_index, None)
    process_index = 0
    start_index = 0
    for i in range(0, self.leader_end_index + 3):
        # seal the current block every 2048 records, plus a final flush
        # two records past leader_end_index
        if (i > 0 and i % 2048 == 0) or (i == self.leader_end_index + 2):
            meta = builder.finish_data_block()
            if meta is not None:
                ofname = common.encode_data_block_fname(
                    self.data_source_l.data_source_meta.name, meta)
                fpath = os.path.join(raw_data_dir, ofname)
                self.manifest_manager.add_raw_data(0, [
                    dj_pb.RawDataMeta(
                        file_path=fpath,
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ], False)
                process_index += 1
                start_index += len(meta.example_ids)
            block_index += 1
            builder = create_data_block_builder(
                dj_pb.DataBlockBuilderOptions(
                    data_block_builder='TF_RECORD_DATABLOCK_BUILDER'),
                self.data_source_l.raw_data_dir,
                self.data_source_l.data_source_meta.name,
                0, block_index, None)
        feat = {}
        # NOTE(review): parses as (i + 1) << 30 due to operator precedence;
        # if i + (1 << 30) was intended this is a latent bug -- confirm
        # against the follower-side generator before changing.
        pt = i + 1 << 30
        if i % 3 == 0:
            # every third record gets a small id that can match the peer
            pt = i // 3
        example_id = '{}'.format(pt).encode()
        feat['example_id'] = tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[example_id]))
        event_time = 150000000 + pt
        feat['event_time'] = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[event_time]))
        example = tf.train.Example(features=tf.train.Features(
            feature=feat))
        builder.append_record(example.SerializeToString(), example_id,
                              event_time, i, i)
    # drop anything that is not a sealed data block file
    fpaths = [
        os.path.join(raw_data_dir, f)
        for f in gfile.ListDirectory(raw_data_dir)
        if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
    ]
    for fpath in fpaths:
        if not fpath.endswith(common.DataBlockSuffix):
            gfile.Remove(fpath)
def _sync_raw_data_meta(self, process_index):
    """Fetch the RawDataMeta stored in etcd for process_index.

    Returns the parsed message, or None when no meta has been stored yet.
    """
    etcd_key = common.raw_data_meta_etcd_key(
        self._data_source.data_source_meta.name,
        self._partition_id, process_index)
    data = self._etcd.get_data(etcd_key)
    if data is None:
        return None
    return text_format.Parse(data, dj_pb.RawDataMeta())
def _sync_raw_data_meta(self, process_index):
    """Fetch the RawDataMeta stored in the kvstore for process_index.

    Returns the parsed message (unknown fields tolerated for forward
    compatibility), or None when no meta has been stored yet.
    """
    kvstore_key = common.raw_data_meta_kvstore_key(
        self._data_source.data_source_meta.name,
        self._partition_id, process_index)
    data = self._kvstore.get_data(kvstore_key)
    if data is None:
        return None
    return text_format.Parse(data, dj_pb.RawDataMeta(),
                             allow_unknown_field=True)
def _new_index_meta(self, process_index, start_index):
    # Materialize the IndexMeta for process_index, assigning start_index to
    # the raw-data meta (via etcd CAS) if no start index was assigned yet.
    # Returns None when the manifest has not advanced to process_index.
    # Aborts the process on unrecoverable inconsistencies, since a wrong
    # start index would corrupt the visitor's global ordering.
    if self._manifest.next_process_index <= process_index:
        return None
    raw_data_meta = None
    if process_index < len(self._all_metas):
        # meta already cached locally; sanity-check the dense ordering
        assert process_index == self._all_metas[process_index][0], \
            "process index should equal {} != {}".format(
                process_index, self._all_metas[process_index][0]
            )
        raw_data_meta = self._all_metas[process_index][1]
    else:
        # cache can only grow one entry at a time
        assert process_index == len(self._all_metas), \
            "the process index should be the next all metas "\
            "{}(process_index) != {}(size of all_metas)".format(
                process_index, len(self._all_metas)
            )
        raw_data_meta = self._sync_raw_data_meta(process_index)
        if raw_data_meta is None:
            logging.fatal("the raw data of partition %d index with "\
                          "%d must in etcd", self._partition_id,
                          process_index)
            traceback.print_stack()
            os._exit(-1)  # pylint: disable=protected-access
        self._all_metas.append((process_index, raw_data_meta))
    if raw_data_meta.start_index == -1:
        # start index not assigned yet: try to claim it with a CAS so only
        # one writer wins; on CAS failure re-read and verify the winner
        # assigned the same start_index we expected
        new_meta = dj_pb.RawDataMeta()
        new_meta.MergeFrom(raw_data_meta)
        new_meta.start_index = start_index
        odata = text_format.MessageToString(raw_data_meta)
        ndata = text_format.MessageToString(new_meta)
        etcd_key = common.raw_data_meta_etcd_key(
            self._data_source.data_source_meta.name,
            self._partition_id, process_index
        )
        if not self._etcd.cas(etcd_key, odata, ndata):
            raw_data_meta = self._sync_raw_data_meta(process_index)
            assert raw_data_meta is not None, \
                "the raw data meta of process index {} "\
                "must not None".format(process_index)
            if raw_data_meta.start_index != start_index:
                logging.fatal("raw data of partition %d index with "\
                              "%d must start with %d",
                              self._partition_id, process_index,
                              start_index)
                traceback.print_stack()
                os._exit(-1)  # pylint: disable=protected-access
    return visitor.IndexMeta(process_index, start_index,
                             raw_data_meta.file_path)
def __init__(self, etcd, raw_data_options, mock_data_source_name,
             input_fpaths):
    """Build a RawDataVisitor over an ad-hoc single-partition data source.

    A mock data source named mock_data_source_name is created on the fly;
    if its partition-0 manifest is not finished yet, input_fpaths are
    registered (dedup enabled) and the raw data is marked finished before
    delegating to the parent visitor.
    """
    mock_data_source = common_pb.DataSource(
        state=common_pb.DataSourceState.Processing,
        data_source_meta=common_pb.DataSourceMeta(
            name=mock_data_source_name,
            partition_num=1))
    mock_rd_manifest_manager = RawDataManifestManager(etcd,
                                                      mock_data_source)
    manifest = mock_rd_manifest_manager.get_manifest(0)
    if not manifest.finished:
        metas = [dj_pb.RawDataMeta(file_path=fpath, start_index=-1)
                 for fpath in input_fpaths]
        mock_rd_manifest_manager.add_raw_data(0, metas, True)
        mock_rd_manifest_manager.finish_raw_data(0)
    super(MockRawDataVisitor, self).__init__(etcd, mock_data_source, 0,
                                             raw_data_options)
def publish_raw_data(self, partition_id, fpaths, timestamps=None):
    # Publish raw-data file paths for one partition into the kvstore, one
    # RawDatePub message per file, at consecutive pub indexes claimed via
    # compare-and-swap. Publishing stops early if a finish tag has already
    # been published for the partition; the skipped files are logged.
    #
    # Args:
    #   partition_id: partition the files belong to.
    #   fpaths: raw-data file paths to publish; every path must exist.
    #   timestamps: optional list of Timestamp messages, parallel to fpaths.
    if not fpaths:
        logging.warning("no raw data will be published")
        return
    if timestamps is not None and len(fpaths) != len(timestamps):
        raise RuntimeError("the number of raw data file "\
                           "and timestamp mismatch")
    new_raw_data_pubs = []
    for index, fpath in enumerate(fpaths):
        if not gfile.Exists(fpath):
            raise ValueError('{} is not existed'.format(fpath))
        # start_index == -1: index assignment is deferred to the consumer
        raw_data_pub = dj_pb.RawDatePub(raw_data_meta=dj_pb.RawDataMeta(
            file_path=fpath, start_index=-1))
        if timestamps is not None:
            raw_data_pub.raw_data_meta.timestamp.MergeFrom(
                timestamps[index])
        new_raw_data_pubs.append(raw_data_pub)
    next_pub_index = None
    item_index = 0
    data = text_format.MessageToString(new_raw_data_pubs[item_index])
    while item_index < len(new_raw_data_pubs):
        # probe forward for the next unused pub index
        next_pub_index = self._forward_pub_index(partition_id,
                                                 next_pub_index)
        # a finish tag just before this index means the partition's raw
        # data stream is closed: nothing more may be published
        if self._check_finish_tag(partition_id, next_pub_index - 1):
            logging.warning("partition %d has been published finish tag "\
                            "at index %d",
                            partition_id, next_pub_index-1)
            break
        kvstore_key = common.raw_data_pub_kvstore_key(
            self._raw_data_pub_dir, partition_id, next_pub_index)
        # CAS only succeeds when the key is still absent, so concurrent
        # publishers never overwrite each other
        if self._kvstore.cas(kvstore_key, None, data):
            logging.info("Success publish %s at index %d for partition"\
                         "%d", data, next_pub_index, partition_id)
            next_pub_index += 1
            item_index += 1
            if item_index < len(new_raw_data_pubs):
                raw_data_pub = new_raw_data_pubs[item_index]
                data = text_format.MessageToString(raw_data_pub)
    # report files left unpublished because of the finish tag
    if item_index < len(new_raw_data_pubs) - 1:
        logging.warning("%d files are not published since meet finish "\
                        "tag for partition %d. list following",
                        len(new_raw_data_pubs) - item_index,
                        partition_id)
        for idx, pub in enumerate(new_raw_data_pubs[item_index:]):
            logging.warning("%d. %s", idx, pub.raw_data_meta.file_path)
def _process_next_process_index(self, partition_id, manifest):
    """Advance manifest.next_process_index past every raw-data meta already
    present in etcd for this partition, recording each file path and its
    timestamp along the way. Persists the manifest only when the index
    actually moved; otherwise just refreshes the local cache.
    """
    assert manifest is not None and manifest.partition_id == partition_id
    probe_index = manifest.next_process_index
    while True:
        meta_etcd_key = common.raw_data_meta_etcd_key(
            self._data_source.data_source_meta.name,
            partition_id,
            probe_index)
        data = self._etcd.get_data(meta_etcd_key)
        if data is None:
            # first missing index: everything before it is accounted for
            break
        meta = text_format.Parse(data, dj_pb.RawDataMeta())
        self._existed_fpath[meta.file_path] = (partition_id, probe_index)
        self._update_raw_data_latest_timestamp(partition_id,
                                               meta.timestamp)
        probe_index += 1
    if probe_index == manifest.next_process_index:
        self._local_manifest[partition_id] = manifest
    else:
        manifest.next_process_index = probe_index
        self._update_manifest(manifest)
def generate_raw_data(self, begin_index, item_count):
    # Generate item_count raw-data examples for partition 0 as TF_RECORD
    # data blocks of 2048 records each. Within every block the example ids
    # are lightly shuffled (random swaps bounded to a +/-32 window) so the
    # stream is "almost sorted" -- exercising the joiner's reordering logic.
    # Non-data-block leftovers are removed and all blocks registered with
    # the manifest manager.
    raw_data_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
    if not gfile.Exists(raw_data_dir):
        gfile.MakeDirs(raw_data_dir)
    self.total_raw_data_count += item_count
    useless_index = 0
    rdm = raw_data_visitor.RawDataManager(self.kvstore, self.data_source, 0)
    fpaths = []
    for block_index in range(0, item_count // 2048):
        builder = DataBlockBuilder(
            self.raw_data_dir,
            self.data_source.data_source_meta.name,
            0, block_index,
            dj_pb.WriterOptions(output_writer='TF_RECORD'), None)
        cands = list(
            range(begin_index + block_index * 2048,
                  begin_index + (block_index + 1) * 2048))
        start_index = cands[0]
        for i in range(len(cands)):
            # ~50% of positions attempt a swap
            if random.randint(1, 4) > 2:
                continue
            # pick two nearby positions, clamped into range
            a = random.randint(i - 32, i + 32)
            b = random.randint(i - 32, i + 32)
            if a < 0:
                a = 0
            if a >= len(cands):
                a = len(cands) - 1
            if b < 0:
                b = 0
            if b >= len(cands):
                b = len(cands) - 1
            # only swap while both values stay within the 32-slot window
            # of position i, keeping the disorder bounded
            if (abs(cands[a] - i - start_index) <= 32 and
                    abs(cands[b] - i - start_index) <= 32):
                cands[a], cands[b] = cands[b], cands[a]
        for example_idx in cands:
            feat = {}
            example_id = '{}'.format(example_idx).encode()
            feat['example_id'] = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[example_id]))
            event_time = 150000000 + example_idx
            feat['event_time'] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[event_time]))
            # ~80% of examples carry a random binary label
            label = random.choice([1, 0])
            if random.random() < 0.8:
                feat['label'] = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label]))
            example = tf.train.Example(features=tf.train.Features(
                feature=feat))
            builder.append_item(TfExampleItem(example.SerializeToString()),
                                useless_index, useless_index)
            useless_index += 1
        meta = builder.finish_data_block()
        fname = common.encode_data_block_fname(
            self.data_source.data_source_meta.name, meta)
        fpath = os.path.join(raw_data_dir, fname)
        fpaths.append(
            dj_pb.RawDataMeta(
                file_path=fpath,
                timestamp=timestamp_pb2.Timestamp(seconds=3)))
        self.g_data_block_index += 1
    # drop anything that is not a sealed data block file
    all_files = [
        os.path.join(raw_data_dir, f)
        for f in gfile.ListDirectory(raw_data_dir)
        if not gfile.IsDirectory(os.path.join(raw_data_dir, f))
    ]
    for fpath in all_files:
        if not fpath.endswith(common.DataBlockSuffix):
            gfile.Remove(fpath)
    self.manifest_manager.add_raw_data(0, fpaths, False)
def test_api(self):
    # End-to-end exercise of the DataJoinMaster gRPC API with a leader and
    # a follower master: partition allocation (join / sync-example-id),
    # idempotency of each RPC, raw-data registration with and without
    # dedup, latest-timestamp tracking, finish-raw-data / finish-join
    # ordering constraints, and final transition to the Finished state.
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    # start from clean etcd namespaces on both sides
    etcd_l.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    # each master points at the other as its peer
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    # wait until both masters have handshaken into the Processing state
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)
    # follower asks for any partition (-1) to join examples on
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=-1,
        join_example=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # leader requests the same partition explicitly
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        join_example=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # leader allocates a sync-example-id partition for rank 1
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1, partition_id=-1,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # follower mirrors the sync allocation on partition 0
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=1, partition_id=0,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # finishing a partition is rejected while raw data is still open
    rdreq1 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1, partition_id=0,
        sync_example_id=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq1)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rdreq2 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        join_example=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # leader manifest starts empty
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # leader: add file 'a' (ts=3) -> index advances, timestamp tracks it
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 1)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 3)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # leader: add file 'b' (ts=5)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # leader: re-adding 'b' and 'a' with dedup=True is a no-op
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5)),
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # follower manifest starts empty
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
    )
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # follower: add file 'a' (ts=1)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 1)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 1)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # follower: add file 'b' (ts=2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # follower: re-adding 'a' and 'b' with dedup=True is a no-op
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1)),
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # leader closes its raw data stream (idempotently)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertTrue(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    # adding raw data after finish must fail
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=4))
            ]))
    try:
        rsp = client_l.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # follower's raw data is still open, so its join-finish must fail
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # sync-example-id finishing now succeeds on both sides (idempotently)
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # follower closes its raw data stream (idempotently)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertTrue(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    # adding raw data after finish must fail on the follower too
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    try:
        rsp = client_f.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # join-example finishing now succeeds on both sides (idempotently)
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    # wait until both masters converge to the Finished state
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Finished and \
                dss_f.state == common_pb.DataSourceState.Finished:
            break
        else:
            time.sleep(2)
    master_l.stop()
    master_f.stop()
def test_raw_data_manifest_manager(self):
    """Exercise the full manifest state machine of RawDataManifestManager.

    Drives one partition of a 4-partition data source through
    example-id sync (UnSynced -> Syncing -> Synced), example join
    (UnJoined -> Joining -> Joined) and raw-data registration/finish,
    asserting after every transition that the untouched partitions keep
    their initial state and that invalid transitions raise.
    """
    # NOTE(review): positional args presumably (cluster, addr, user,
    # password, base_dir, use_mock_backend) -- confirm against
    # mysql_client.DBClient's signature.
    cli = mysql_client.DBClient('test_cluster', 'localhost:2379',
                                'test_user', 'test_password',
                                'fedlearner', True)
    partition_num = 4
    rank_id = 2
    data_source = common_pb.DataSource()
    data_source.data_source_meta.name = "milestone-x"
    data_source.data_source_meta.partition_num = partition_num
    data_source.role = common_pb.FLRole.Leader
    # Wipe any leftover state so the test starts from a clean kvstore
    # namespace for this data source.
    cli.delete_prefix(
        common.data_source_kvstore_base_dir(
            data_source.data_source_meta.name))
    manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        cli, data_source)

    # Initial state: every partition unallocated (rank_id == -1) and
    # unfinished.
    manifest_map = manifest_manager.list_all_manifest()
    for i in range(partition_num):
        self.assertTrue(i in manifest_map)
        self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.UnSynced)
        self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id, -1)
        self.assertEqual(manifest_map[i].join_example_rep.state,
                         dj_pb.JoinExampleState.UnJoined)
        self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
        self.assertFalse(manifest_map[i].finished)

    # Allocate an example-id sync slot; the manager chooses the
    # partition.  ("alloc_sync_exampld_id" is the manager's actual
    # method spelling -- do not "correct" it here.)
    manifest = manifest_manager.alloc_sync_exampld_id(rank_id)
    self.assertNotEqual(manifest, None)
    partition_id = manifest.partition_id

    # Only the allocated partition switches to Syncing under rank_id.
    manifest_map = manifest_manager.list_all_manifest()
    for i in range(partition_num):
        self.assertTrue(i in manifest_map)
        if i != partition_id:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.UnSynced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             -1)
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.UnJoined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
        else:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.Syncing)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             rank_id)
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.UnJoined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
        self.assertFalse(manifest_map[i].finished)

    # Allocate a join slot on a *different* partition for another rank.
    partition_id2 = 3 - partition_id
    rank_id2 = 100
    manifest = manifest_manager.alloc_join_example(rank_id2, partition_id2)
    manifest_map = manifest_manager.list_all_manifest()
    for i in range(partition_num):
        self.assertTrue(i in manifest_map)
        if i == partition_id:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.Syncing)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             rank_id)
        else:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.UnSynced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             -1)
        if i == partition_id2:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.Joining)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                             rank_id2)
        else:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.UnJoined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
        self.assertFalse(manifest_map[i].finished)

    # Finishing with a wrong rank or in the wrong state must raise
    # (-rank_id is a deliberately bogus rank for this partition).
    self.assertRaises(Exception, manifest_manager.finish_join_example,
                      rank_id, partition_id)
    self.assertRaises(Exception, manifest_manager.finish_join_example,
                      rank_id2, partition_id2)
    self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                      -rank_id, partition_id)
    self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                      rank_id2, partition_id2)

    # A second join allocation may target the partition that is still
    # syncing; both join slots now coexist.
    rank_id3 = 0
    manifest = manifest_manager.alloc_join_example(rank_id3, partition_id)
    manifest_map = manifest_manager.list_all_manifest()
    for i in range(partition_num):
        self.assertTrue(i in manifest_map)
        if i == partition_id:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.Syncing)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             rank_id)
        else:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.UnSynced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             -1)
        if i == partition_id:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.Joining)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                             rank_id3)
        elif i == partition_id2:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.Joining)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                             rank_id2)
        else:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.UnJoined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
        self.assertFalse(manifest_map[i].finished)

    # Finishing sync at this point must still be rejected -- presumably
    # because raw data is not yet finished; confirm against the
    # manager's state-transition rules.
    self.assertRaises(Exception, manifest_manager.finish_sync_example_id,
                      rank_id, partition_id)

    # Register raw-data files.  'a' appears twice: with dedup=False the
    # duplicate must raise; with dedup=True the call succeeds and only
    # the two distinct files are indexed (next_process_index == 2).
    raw_data_metas = [
        dj_pb.RawDataMeta(file_path='a',
                          timestamp=timestamp_pb2.Timestamp(seconds=3)),
        dj_pb.RawDataMeta(file_path='a',
                          timestamp=timestamp_pb2.Timestamp(seconds=3)),
        dj_pb.RawDataMeta(file_path='c',
                          timestamp=timestamp_pb2.Timestamp(seconds=1))
    ]
    self.assertRaises(Exception, manifest_manager.add_raw_data,
                      partition_id, raw_data_metas, False)
    manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
    # Latest timestamp is the max over the added files (seconds == 3).
    # ("get_raw_date_latest_timestamp" is the manager's actual method
    # spelling.)
    latest_ts = manifest_manager.get_raw_date_latest_timestamp(
        partition_id)
    self.assertEqual(latest_ts.seconds, 3)
    self.assertEqual(latest_ts.nanos, 0)
    manifest = manifest_manager.get_manifest(partition_id)
    self.assertEqual(manifest.next_process_index, 2)

    # Adding again with dedup=True keeps the existing entries and
    # appends only the genuinely new files 'b' and 'd'.
    raw_data_metas = [
        dj_pb.RawDataMeta(file_path='a',
                          timestamp=timestamp_pb2.Timestamp(seconds=3)),
        dj_pb.RawDataMeta(file_path='a',
                          timestamp=timestamp_pb2.Timestamp(seconds=3)),
        dj_pb.RawDataMeta(file_path='b',
                          timestamp=timestamp_pb2.Timestamp(seconds=2)),
        dj_pb.RawDataMeta(file_path='c',
                          timestamp=timestamp_pb2.Timestamp(seconds=1)),
        dj_pb.RawDataMeta(file_path='d',
                          timestamp=timestamp_pb2.Timestamp(seconds=4))
    ]
    manifest_manager.add_raw_data(partition_id, raw_data_metas, True)
    latest_ts = manifest_manager.get_raw_date_latest_timestamp(
        partition_id)
    self.assertEqual(latest_ts.seconds, 4)
    self.assertEqual(latest_ts.nanos, 0)
    manifest_map = manifest_manager.list_all_manifest()
    for i in range(partition_num):
        self.assertTrue(i in manifest_map)
        if i == partition_id:
            self.assertEqual(manifest_map[i].next_process_index, 4)
        else:
            self.assertEqual(manifest_map[i].next_process_index, 0)

    # finish_raw_data is idempotent (called twice on purpose);
    # afterwards adding more raw data must fail.
    manifest_manager.finish_raw_data(partition_id)
    manifest_manager.finish_raw_data(partition_id)
    # NOTE(review): only two positional args here (no metas/dedup) --
    # assertRaises(Exception) would also pass on a mere TypeError from
    # the missing argument; confirm this asserts the intended
    # "add after finish" rejection rather than an arity error.
    self.assertRaises(Exception, manifest_manager.add_raw_data,
                      partition_id, 200)

    # finish_sync_example_id is idempotent; afterwards the partition's
    # `finished` flag is set and its sync state is Synced.
    manifest_manager.finish_sync_example_id(rank_id, partition_id)
    manifest_manager.finish_sync_example_id(rank_id, partition_id)
    manifest_map = manifest_manager.list_all_manifest()
    for i in range(data_source.data_source_meta.partition_num):
        self.assertTrue(i in manifest_map)
        if i == partition_id:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.Synced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             rank_id)
            self.assertTrue(manifest_map[i].finished)
        else:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.UnSynced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             -1)
        if i == partition_id:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.Joining)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                             rank_id3)
        elif i == partition_id2:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.Joining)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                             rank_id2)
        else:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.UnJoined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)

    # finish_join_example is idempotent; partition_id reaches Joined
    # while partition_id2 (rank_id2) stays Joining.
    manifest_manager.finish_join_example(rank_id3, partition_id)
    manifest_manager.finish_join_example(rank_id3, partition_id)
    manifest_map = manifest_manager.list_all_manifest()
    for i in range(data_source.data_source_meta.partition_num):
        self.assertTrue(i in manifest_map)
        if i == partition_id:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.Synced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             rank_id)
        else:
            self.assertEqual(manifest_map[i].sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.UnSynced)
            self.assertEqual(manifest_map[i].sync_example_id_rep.rank_id,
                             -1)
        if i == partition_id:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.Joined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                             rank_id3)
        elif i == partition_id2:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.Joining)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id,
                             rank_id2)
        else:
            self.assertEqual(manifest_map[i].join_example_rep.state,
                             dj_pb.JoinExampleState.UnJoined)
            self.assertEqual(manifest_map[i].join_example_rep.rank_id, -1)
    # Release the backing DB client pool.
    cli.destroy_client_pool()