def _wait_timestamp(self, target_l, target_f):
    """Block until the raw data timestamps of both sides reach the targets.

    Each polling round asks the leader master and the follower master for
    the latest raw data timestamp of every partition, converts them with
    ``common.convert_timestamp_to_datetime`` and keeps the smallest value
    seen per side. The method returns once the smallest leader value is at
    least ``target_l`` and the smallest follower value is at least
    ``target_f``; otherwise it sleeps two seconds and polls again.
    """
    while True:
        oldest_l = None
        oldest_f = None
        partition_num = self._data_source_f.data_source_meta.partition_num
        for partition_id in range(partition_num):
            leader_req = dj_pb.RawDataRequest(
                partition_id=partition_id,
                data_source_meta=self._data_source_l.data_source_meta,
            )
            follower_req = dj_pb.RawDataRequest(
                partition_id=partition_id,
                data_source_meta=self._data_source_f.data_source_meta,
            )
            leader_rsp = \
                self._master_client_l.GetRawDataLatestTimeStamp(leader_req)
            follower_rsp = \
                self._master_client_f.GetRawDataLatestTimeStamp(follower_req)
            current_l = common.convert_timestamp_to_datetime(
                leader_rsp.timestamp)
            current_f = common.convert_timestamp_to_datetime(
                follower_rsp.timestamp)
            # Track the running minimum for each side.
            if oldest_l is None or current_l < oldest_l:
                oldest_l = current_l
            if oldest_f is None or current_f < oldest_f:
                oldest_f = current_f
        if oldest_l >= target_l and oldest_f >= target_f:
            return
        time.sleep(2)
def _make_raw_data_request(self, partition_id):
    """Build the RawDataRequest for ``partition_id`` matching the local role.

    A Leader data source gets a request with the ``join_example`` field
    set; a Follower gets one with ``sync_example_id`` set. Any other role
    fails the assertion.
    """
    role = self._data_source.role
    if role != common_pb.FLRole.Leader:
        assert role == common_pb.FLRole.Follower, \
            "if not Leader, otherwise, must be Follower"
        return dj_pb.RawDataRequest(
            data_source_meta=self._data_source.data_source_meta,
            rank_id=self._rank_id,
            partition_id=partition_id,
            sync_example_id=empty_pb2.Empty(),
        )
    return dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        partition_id=partition_id,
        join_example=empty_pb2.Empty(),
    )
def _allocate_sync_partition_fn(self):
    """Request a partition from the master for example id syncing.

    Sends ``RequestJoinPartition`` with the ``sync_example_id`` field set
    and ``partition_id=-1``. Depending on the response the method either
    clears ``self._state`` (``finished`` set), logs a warning and returns
    (no ``manifest``), or stores the manifest together with a
    ``RawDataVisitor`` for its partition and wakes the follower example id
    syncer.

    Raises:
        RuntimeError: if the master replies with a non-zero status code.
    """
    # Only valid while no manifest is currently being processed.
    assert self._processing_manifest is None
    # partition_id=-1: presumably lets the master pick the partition --
    # TODO confirm against the master implementation.
    req = dj_pb.RawDataRequest(
            data_source_meta=self._data_source.data_source_meta,
            rank_id=self._rank_id,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
    rsp = self._master_client.RequestJoinPartition(req)
    if rsp.status.code != 0:
        raise RuntimeError("Failed to Request partition for sync "\
                           "id to follower, error msg {}".format(
                               rsp.status.error_message))
    if rsp.HasField('finished'):
        # The master reported 'finished': drop the current state.
        with self._lock:
            self._state = None
            return
    if not rsp.HasField('manifest'):
        # Nothing was handed out; return without touching any state.
        logging.warning("no manifest is at state %d, wait and retry",
                        dj_pb.RawDataState.UnAllocated)
        return
    # Build the visitor before taking the lock; it is only published below.
    rdv = RawDataVisitor(self._etcd, self._data_source,
                         rsp.manifest.partition_id, self._options)
    with self._lock:
        self._processing_manifest = rsp.manifest
        self._raw_data_visitor = rdv
        self._check_manifest()
        self._wakeup_follower_example_id_syncer()
def _make_raw_data_request(self):
    """Build a RawDataRequest with ``join_example`` set and partition id -1.

    The request carries this object's data source meta and rank id.
    """
    request_fields = dict(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        partition_id=-1,
        join_example=empty_pb2.Empty(),
    )
    return dj_pb.RawDataRequest(**request_fields)
def _allocate_join_partition_fn(self):
    """Request a partition from the master for example joining.

    Sends ``RequestJoinPartition`` with the ``join_example`` field set and
    ``partition_id=-1``. Depending on the response the method either clears
    ``self._state`` (``finished`` set), logs a warning and returns (no
    ``manifest``), or stores the manifest together with an example joiner
    for its partition and wakes the leader example joiner and the data
    block meta syncer.

    Raises:
        RuntimeError: if the master replies with a non-zero status code.
    """
    # Only valid while no manifest is currently being processed.
    assert self._processing_manifest is None
    # partition_id=-1: presumably lets the master pick the partition --
    # TODO confirm against the master implementation.
    req = dj_pb.RawDataRequest(
            data_source_meta=self._data_source.data_source_meta,
            rank_id=self._rank_id,
            join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rsp = self._master_client.RequestJoinPartition(req)
    if rsp.status.code != 0:
        raise RuntimeError("Failed to Request partition for "\
                           "example intsesection, error msg {}".format(
                               rsp.status.error_message))
    if rsp.HasField('finished'):
        # The master reported 'finished': drop the current state.
        with self._lock:
            self._state = None
            return
    if not rsp.HasField('manifest'):
        # Nothing was handed out; return without touching any state.
        logging.warning("no manifest is at state %d, wait and retry",
                        dj_pb.RawDataState.Synced)
        return
    # Build the joiner before taking the lock; it is only published below.
    joiner = create_example_joiner(self._etcd, self._data_source,
                                   rsp.manifest.partition_id)
    with self._lock:
        self._processing_manifest = rsp.manifest
        self._joiner = joiner
        self._check_manifest()
        self._wakeup_leader_example_joiner()
        self._wakeup_data_block_meta_syncer()
def finish_raw_data(self, partition_id):
    """Send a FinishRawData request for ``partition_id`` to the master.

    The partition id is validated with ``_check_partition_id`` first. The
    master's response is returned unchanged.
    """
    self._check_partition_id(partition_id)
    finish_request = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        partition_id=partition_id,
        finish_raw_data=empty_pb2.Empty(),
    )
    response = self._master_client.FinishRawData(finish_request)
    return response
def _make_finish_raw_data_request(self, impl_ctx):
    """Build a RawDataRequest for the partition held by ``impl_ctx``.

    The request carries this object's data source meta and rank id, the
    partition id read from ``impl_ctx`` and an empty ``join_example`` field.
    """
    request_fields = dict(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        partition_id=impl_ctx.partition_id,
        join_example=empty_pb2.Empty(),
    )
    return dj_pb.RawDataRequest(**request_fields)
def add_raw_data(self, partition_id, fpaths, dedup, timestamps=None):
    """Register raw data files of a partition with the master.

    Args:
        partition_id: partition the files belong to; validated with
            ``_check_partition_id``.
        fpaths: non-empty sequence of raw data file paths. Every path must
            exist according to ``gfile.Exists``.
        dedup: value forwarded as the ``dedup`` flag of the request.
        timestamps: optional sequence parallel to ``fpaths``; entry ``i`` is
            merged into the ``timestamp`` field of the meta of file ``i``.

    Returns:
        The response of the master's ``AddRawData`` call.

    Raises:
        RuntimeError: if ``fpaths`` is empty, or ``timestamps`` is given
            with a different length than ``fpaths``.
        ValueError: if one of the files does not exist.
    """
    self._check_partition_id(partition_id)
    if not fpaths:
        raise RuntimeError("no files input")
    if timestamps is not None and len(fpaths) != len(timestamps):
        raise RuntimeError("the number of raw data file "\
                           "and timestamp mismatch")
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        partition_id=partition_id,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=dedup
        )
    )
    for index, fpath in enumerate(fpaths):
        if not gfile.Exists(fpath):
            # Use str.format: applying '%' to a '{}' template raises
            # TypeError and would hide the intended ValueError.
            raise ValueError('{} is not existed'.format(fpath))
        # start_index=-1 is passed through as-is; presumably the master
        # assigns the real index -- TODO confirm.
        raw_data_meta = dj_pb.RawDataMeta(
            file_path=fpath,
            start_index=-1
        )
        if timestamps is not None:
            raw_data_meta.timestamp.MergeFrom(timestamps[index])
        rdreq.added_raw_data_metas.raw_data_metas.append(raw_data_meta)
    return self._master_client.AddRawData(rdreq)
def _sniff_raw_data_finished(self, impl_ctx):
    """Mark ``impl_ctx`` as raw-data-finished if the master says so.

    Queries the raw data manifest of the context's partition and calls
    ``impl_ctx.set_raw_data_finished()`` when its ``finished`` flag is set.
    """
    assert isinstance(impl_ctx, ExampleIdSyncLeader.ImplContext)
    query = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        partition_id=impl_ctx.partition_id,
    )
    if self._master_client.QueryRawDataManifest(query).finished:
        impl_ctx.set_raw_data_finished()
def test_all_assembly(self):
    """Run a leader and a follower worker until both data sources finish.

    Raw data is generated for every partition on both sides, two
    ``DataJoinWorkerService`` instances are started against each other,
    raw data is finished on both masters, and the test then polls the data
    source status until both report ``Finished``.
    """
    for partition_id in range(
            self.data_source_l.data_source_meta.partition_num):
        self.generate_raw_data(
            self.etcd_l, self.raw_data_controller_l, self.data_source_l,
            partition_id, 2048, 64,
            'leader_key_partition_{}:{{}}'.format(partition_id),
            'leader_value_partition_{}:{{}}'.format(partition_id))
        self.generate_raw_data(
            self.etcd_f, self.raw_data_controller_f, self.data_source_f,
            partition_id, 4096, 128,
            'follower_key_partition_{}:{{}}'.format(partition_id),
            'follower_value_partition_{}:{{}}'.format(partition_id))

    worker_addr_l = 'localhost:4161'
    worker_addr_f = 'localhost:4162'
    listen_port_l = int(worker_addr_l.split(':')[1])
    worker_l = data_join_worker.DataJoinWorkerService(
        listen_port_l, worker_addr_f, self.master_addr_l, 0,
        self.etcd_name, self.etcd_base_dir_l, self.etcd_addrs,
        self.worker_options)
    listen_port_f = int(worker_addr_f.split(':')[1])
    worker_f = data_join_worker.DataJoinWorkerService(
        listen_port_f, worker_addr_l, self.master_addr_f, 0,
        self.etcd_name, self.etcd_base_dir_f, self.etcd_addrs,
        self.worker_options)
    worker_l.start()
    worker_f.start()

    # The same request (built from the leader's meta) goes to both masters.
    for partition_id in range(
            self.data_source_f.data_source_meta.partition_num):
        finish_req = dj_pb.RawDataRequest(
            data_source_meta=self.data_source_l.data_source_meta,
            partition_id=partition_id,
            finish_raw_data=empty_pb2.Empty())
        for master_client in (self.master_client_l, self.master_client_f):
            rsp = master_client.FinishRawData(finish_req)
            self.assertEqual(rsp.code, 0)

    while True:
        status_req_l = dj_pb.DataSourceRequest(
            data_source_meta=self.data_source_l.data_source_meta)
        status_req_f = dj_pb.DataSourceRequest(
            data_source_meta=self.data_source_f.data_source_meta)
        status_l = self.master_client_l.GetDataSourceStatus(status_req_l)
        status_f = self.master_client_f.GetDataSourceStatus(status_req_f)
        self.assertEqual(status_l.role, common_pb.FLRole.Leader)
        self.assertEqual(status_f.role, common_pb.FLRole.Follower)
        if status_l.state == common_pb.DataSourceState.Finished and \
                status_f.state == common_pb.DataSourceState.Finished:
            break
        time.sleep(2)

    worker_l.stop()
    worker_f.stop()
    self.master_l.stop()
    self.master_f.stop()
def get_raw_data_latest_timestamp(self, partition_id):
    """Return the latest raw data timestamp the master holds for a partition.

    The partition id is validated with ``_check_partition_id`` first.

    Raises:
        RuntimeError: if the master replies with a non-zero status code.
    """
    self._check_partition_id(partition_id)
    query = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        partition_id=partition_id,
    )
    reply = self._master_client.GetRawDataLatestTimeStamp(query)
    if reply.status.code == 0:
        return reply.timestamp
    failure = (
        "Failed to call GetRawDataLatestTimeStamp "
        "for partition {} of data source {}. reason {}"
    ).format(partition_id,
             self._data_source.data_source_meta.name,
             reply.status.error_message)
    raise RuntimeError(failure)
def add_raw_data(self, partition_id, fpaths, dedup):
    """Register raw data files of a partition with the master.

    Args:
        partition_id: partition the files belong to; validated with
            ``_check_partition_id``.
        fpaths: non-empty sequence of raw data file paths. Every path must
            exist according to ``gfile.Exists``.
        dedup: value forwarded as the ``dedup`` flag of the request.

    Returns:
        The response of the master's ``AddRawData`` call.

    Raises:
        RuntimeError: if ``fpaths`` is empty.
        ValueError: if one of the files does not exist.
    """
    self._check_partition_id(partition_id)
    if not fpaths:
        raise RuntimeError("no files input")
    # All paths are checked before the request is built.
    for fpath in fpaths:
        if not gfile.Exists(fpath):
            # Use str.format: applying '%' to a '{}' template raises
            # TypeError and would hide the intended ValueError.
            raise ValueError('{} is not existed'.format(fpath))
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        partition_id=partition_id,
        raw_data_fpaths=dj_pb.RawDataFilePaths(file_paths=fpaths,
                                               dedup=dedup))
    return self._master_client.AddRawData(rdreq)
def _sniff_join_data_finished(self, impl_ctx):
    """Refresh the finished flags of ``impl_ctx`` from the master's manifest.

    Does nothing when the context already reports both example id sync and
    raw data as finished. Otherwise the raw data manifest of the context's
    partition is queried, and each flag is set on the context when the
    manifest shows the corresponding condition.
    """
    assert isinstance(impl_ctx, ExampleJoinLeader.ImplContext)
    if impl_ctx.is_sync_example_id_finished() and \
            impl_ctx.is_raw_data_finished():
        return
    query = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        partition_id=impl_ctx.partition_id,
    )
    manifest = self._master_client.QueryRawDataManifest(query)
    synced = dj_pb.SyncExampleIdState.Synced
    if manifest.sync_example_id_rep.state == synced:
        impl_ctx.set_sync_example_id_finished()
    if manifest.finished:
        impl_ctx.set_raw_data_finished()
def _is_partition_finished(self, partition_id):
    """Tell whether the local role's work on ``partition_id`` is done.

    The raw data manifest is fetched from the master. For a Leader the
    result is whether ``join_example_rep.state`` is ``Joined``; for a
    Follower it is whether ``sync_example_id_rep.state`` is ``Synced``.
    Any other role fails the assertion.
    """
    query = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        partition_id=partition_id,
    )
    manifest = self._master_client.QueryRawDataManifest(query)
    assert manifest is not None and \
        manifest.partition_id == partition_id
    role = self._data_source.role
    if role != common_pb.FLRole.Leader:
        assert role == common_pb.FLRole.Follower, \
            "if not Leader, otherwise, must be Follower"
        follower_state = manifest.sync_example_id_rep.state
        return follower_state == dj_pb.SyncExampleIdState.Synced
    leader_state = manifest.join_example_rep.state
    return leader_state == dj_pb.JoinExampleState.Joined
def _update_peer_index(self, impl_ctx, peer_next_index, peer_dumped_index):
    """Store the peer's indices and forward a larger dumped index.

    The new indices are always written into ``impl_ctx``. Only when
    ``peer_dumped_index`` is greater than the previously stored dumped
    index is it forwarded to the master via ``ForwardPeerDumpedIndex``.

    Raises:
        RuntimeError: if the master replies with a non-zero code.
    """
    assert isinstance(impl_ctx, TransmitLeader.ImplContext)
    _, previous_dumped_index = impl_ctx.get_peer_index()
    impl_ctx.set_peer_index(peer_next_index, peer_dumped_index)
    if not previous_dumped_index < peer_dumped_index:
        return
    forward_req = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        partition_id=impl_ctx.partition_id,
        peer_dumped_index=dj_pb.PeerDumpedIndex(
            peer_dumped_index=peer_dumped_index),
    )
    status = self._master_client.ForwardPeerDumpedIndex(forward_req)
    if status.code == 0:
        return
    raise RuntimeError(
        "{} failed to forward peer dumped index "
        "to {} reason: {}".format(self._repr_str, peer_dumped_index,
                                  status.error_message))
def StartPartition(self, request, context):
    """Handle a StartPartition call for ``request.partition_id``.

    Returns a ``LeaderStartPartitionResponse``:
      * status code -1 when the request's data source meta fails
        ``_validate_data_source_meta``;
      * status code -2 when the partition id is negative;
      * ``finished=True`` when the partition's manifest state is already
        past ``Joining``;
      * the master's status when its ``RequestJoinPartition`` call fails;
      * otherwise ``finished=False`` and the next data block index
        returned by the example join follower.

    Raises:
        RuntimeError: if the master's successful response has no manifest.
    """
    reply = dj_pb.LeaderStartPartitionResponse()
    meta_matches = self._validate_data_source_meta(
        request.data_source_meta, self._data_source.data_source_meta)
    if not meta_matches:
        reply.status.code = -1
        reply.status.error_message = "data source meta mismtach"
        return reply
    if request.partition_id < 0:
        reply.status.code = -2
        reply.status.error_message = \
            "partition id {} illegal".format(request.partition_id)
        return reply

    manifest = self._query_raw_data_manifest(request.partition_id)
    if manifest.state > dj_pb.RawDataState.Joining:
        reply.finished = True
        return reply

    join_request = dj_pb.JoinExampleRequest(
        partition_id=request.partition_id)
    master_req = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        join_example=join_request,
    )
    master_rsp = self._master_client.RequestJoinPartition(master_req)
    if master_rsp.status.code != 0:
        reply.status.MergeFrom(master_rsp.status)
        return reply
    if not master_rsp.HasField("manifest"):
        raise RuntimeError(
            "unknow field for master raw data request response"
        )
    assert master_rsp.manifest.state == dj_pb.RawDataState.Joining

    next_index = self._example_join_follower.start_create_data_block(
        request.partition_id)
    reply.finished = False
    reply.next_data_block_index = next_index
    return reply
def test_api(self):
    """Drive a leader and a follower master through sync and join via gRPC.

    The test stores a leader and a follower data source (one partition) in
    etcd, starts one ``DataJoinMasterService`` per side, waits until both
    report ``Processing``, then walks partition 0 through the example id
    sync phase and the example join phase on both masters, and finally
    waits until both data sources report ``Finished``.

    Each request is sent to the client it was built for, and every
    assertion checks the response of the call directly above it.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
    # Start from a clean etcd namespace for this data source.
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)

    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    # Both sides share the same meta content.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 32
    data_source_meta.max_matching_window = 1024
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))

    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    # Wait until both masters report the Processing state.
    while True:
        rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
        rsp_f = client_f.GetDataSourceState(data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Processing and
                rsp_f.state == common_pb.DataSourceState.Processing):
            break
        else:
            time.sleep(2)

    # A join request before any partition was synced must yield neither a
    # manifest nor a 'finished' marker.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertFalse(rdrsp.HasField('manifest'))
    self.assertFalse(rdrsp.HasField('finished'))

    # Sync phase: the leader requests with partition_id=-1 and is expected
    # to receive partition 0.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # Repeating the request must return the same allocation.
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # The follower asks for partition 0 explicitly.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # Finish the sync phase on the leader; the second call sends the same
    # request again and must succeed as well.
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    # Finish the sync phase on the follower.
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    # Join phase: the leader requests with partition_id=-1 and is expected
    # to receive partition 0.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # The follower asks for partition 0 explicitly.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # Finish the join phase on the leader; the second call sends the same
    # request again and must succeed as well.
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    # Finish the join phase on the follower.
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    # Wait until both masters report the Finished state.
    while True:
        rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
        rsp_f = client_f.GetDataSourceState(data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Finished and
                rsp_f.state == common_pb.DataSourceState.Finished):
            break
        else:
            time.sleep(2)

    master_l.stop()
    master_f.stop()
def test_api(self):
    """Exercise the data join master gRPC API on a leader/follower pair.

    The test commits a leader and a follower data source (one partition)
    to etcd, starts one ``DataJoinMasterService`` per side, and then checks,
    in order: partition allocation for ``join_example`` and
    ``sync_example_id`` (each call repeated once), manifest and latest
    timestamp queries, ``AddRawData`` with and without ``dedup``,
    ``FinishRawData``, and ``FinishJoinPartition``. It ends by waiting
    until both data sources report ``Finished``.

    The statements depend on the state left behind by the preceding calls,
    so their order matters.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    # NOTE(review): the fourth argument is presumably the "use mock etcd"
    # switch (the masters below get use_mock_etcd=True) -- confirm.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    # Remove anything stored under this data source's etcd base dir.
    etcd_l.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    # Both sides are given the same meta content.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)

    # Start one master per side; each gets the other's address as peer.
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    # Wait until both masters report the Processing state.
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)

    # join_example allocation on the follower master: requested with
    # partition_id=-1, expected to come back as partition 0 for rank 0.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=-1,
        join_example=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # join_example allocation on the leader master, partition 0 explicitly.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        join_example=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # sync_example_id allocation on the leader master for rank 1:
    # requested with partition_id=-1, expected to come back as partition 0.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1,
        partition_id=-1,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # sync_example_id allocation on the follower master, partition 0
    # explicitly.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=1,
        partition_id=0,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)

    # rdreq1 (finish sync, rank 1) and rdreq2 (finish join, rank 0) are
    # reused further down. At this point both calls are expected to raise;
    # presumably because raw data is not finished yet -- TODO confirm.
    rdreq1 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1,
        partition_id=0,
        sync_example_id=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq1)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    rdreq2 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        join_example=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Leader: before any raw data is added the manifest is unfinished with
    # next_process_index 0, and the latest timestamp is zero.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 0)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: add file 'a' (timestamp 3s) -> index 1, latest timestamp 3.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 1)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 3)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: add file 'b' (timestamp 5s) -> index 2, latest timestamp 5.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: re-adding 'b' and 'a' with dedup=True must leave the index
    # at 2 and the latest timestamp at 5.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5)),
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: same sequence of checks with its own files and timestamps.
    # NOTE(review): the timestamp queries sent to client_f below are built
    # from data_source_l's meta; both metas were merged from the same
    # DataSourceMeta above.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 0)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: add file 'a' (timestamp 1s) -> index 1, latest timestamp 1.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 1)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 1)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: add file 'b' (timestamp 2s) -> index 2, latest timestamp 2.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Follower: re-adding 'a' and 'b' with dedup=True must leave the index
    # at 2 and the latest timestamp at 2.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1)),
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)

    # Leader: finish raw data; the manifest must then report finished.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)

    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertTrue(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)

    # Leader: adding raw data after finishing is expected to raise.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=4))
            ]))
    try:
        rsp = client_l.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Follower: finishing the join (rdreq2) is still expected to raise.
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Finishing the sync (rdreq1) now succeeds on both masters.
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)

    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)

    # Follower: finishing the join is still expected to raise here;
    # presumably because the follower's raw data is not finished yet --
    # TODO confirm.
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Follower: finish raw data; the manifest must then report finished.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)

    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertTrue(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)

    # Follower: adding raw data after finishing is expected to raise.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    try:
        rsp = client_f.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)

    # Finishing the join (rdreq2) now succeeds on both masters.
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)

    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)

    # Wait until both masters report the Finished state.
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Finished and \
                dss_f.state == common_pb.DataSourceState.Finished:
            break
        else:
            time.sleep(2)

    master_l.stop()
    master_f.stop()