def _allocate_sync_partition_fn(self):
    """Request a partition from the master for example-id syncing.

    Sends a ``RawDataRequest`` carrying a ``SyncExampleIdRequest`` with
    ``partition_id=-1`` to the master (presumably "any partition" --
    confirm against the master implementation) and reacts to the reply:

    * non-zero status code: raise ``RuntimeError``;
    * ``finished`` field set: clear ``self._state`` under the lock;
    * ``manifest`` field set: build a ``RawDataVisitor`` for the returned
      partition, record manifest and visitor under the lock, then check
      the manifest and wake up the follower example-id syncer;
    * neither field set: log a warning and return so the caller can retry.

    Raises:
        RuntimeError: the master answered with a non-zero status code.
    """
    # Only one manifest may be in flight at a time.
    assert self._processing_manifest is None

    request = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1),
    )
    response = self._master_client.RequestJoinPartition(request)

    if response.status.code != 0:
        raise RuntimeError(
            "Failed to Request partition for sync "
            "id to follower, error msg {}".format(
                response.status.error_message
            )
        )

    if response.HasField('finished'):
        with self._lock:
            self._state = None
        return

    if response.HasField('manifest'):
        visitor = RawDataVisitor(
            self._etcd,
            self._data_source,
            response.manifest.partition_id,
            self._options,
        )
        # NOTE(review): the flattened original does not show whether the
        # two trailing calls were inside the lock; they are kept inside
        # here -- confirm against the original layout.
        with self._lock:
            self._processing_manifest = response.manifest
            self._raw_data_visitor = visitor
            self._check_manifest()
            self._wakeup_follower_example_id_syncer()
        return

    logging.warning(
        "no manifest is at state %d, wait and retry",
        dj_pb.RawDataState.UnAllocated,
    )
def _finish_sync_partition(self, follower_finished):
    """Try to finish syncing the partition currently being processed.

    If the follower is not yet known to be finished, it is asked via the
    peer client's ``FinishPartition``. Once the follower reports finished,
    the master is told via ``FinishJoinPartition``.

    Args:
        follower_finished: truthy when the caller already knows the
            follower finished this partition; otherwise the follower is
            queried.

    Returns:
        True when the master accepted the finish request, False when the
        follower has not finished yet.

    Raises:
        RuntimeError: the follower or the master answered with a non-zero
            status code.
    """
    done = follower_finished
    if not done:
        peer_request = dj_pb.FollowerFinishPartitionRequest(
            data_source_meta=self._data_source.data_source_meta,
            partition_id=self._processing_manifest.partition_id,
        )
        peer_response = self._peer_client.FinishPartition(peer_request)
        if peer_response.status.code != 0:
            raise RuntimeError(
                "Failed to call Follower finish partition "
                "reason: {}".format(peer_response.status.error_message)
            )
        done = peer_response.finished

    # Guard clause: nothing to report to the master yet.
    if not done:
        logging.warning("Follower is dumping example id, waitiing")
        return False

    master_request = dj_pb.FinishRawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        sync_example_id=dj_pb.SyncExampleIdRequest(
            partition_id=self._processing_manifest.partition_id,
        ),
    )
    # The master replies with a bare status (code / error_message).
    master_status = self._master_client.FinishJoinPartition(master_request)
    if master_status.code != 0:
        raise RuntimeError(
            "Failed to finish raw data from syncing. "
            "reason: %s" % master_status.error_message
        )
    return True
def StartPartition(self, request, context):
    """RPC handler: start dumping example ids for a partition.

    Args:
        request: carries ``data_source_meta`` and ``partition_id``.
        context: RPC context; not used by this handler.

    Returns:
        A ``FollowerStartPartitionResponse``:

        * status code -1 when the data source meta does not match;
        * status code -2 when ``partition_id`` is negative;
        * ``finished=True`` when the local manifest state is already past
          ``Syncing``;
        * the master's status when its ``RequestJoinPartition`` fails;
        * otherwise ``finished=False`` plus the ``next_index`` returned by
          the sync follower's ``start_dump_partition``.

    Raises:
        RuntimeError: the master reply is successful but has no manifest.
    """
    reply = dj_pb.FollowerStartPartitionResponse()
    local_meta = self._data_source.data_source_meta

    if not self._validate_data_source_meta(request.data_source_meta,
                                           local_meta):
        reply.status.code = -1
        reply.status.error_message = "data source meta mismtach"
        return reply

    partition_id = request.partition_id
    if partition_id < 0:
        reply.status.code = -2
        reply.status.error_message = (
            "partition id {} illegal".format(partition_id)
        )
        return reply

    # A partition whose state is already past Syncing needs no more work.
    local_manifest = self._query_raw_data_manifest(partition_id)
    if local_manifest.state > dj_pb.RawDataState.Syncing:
        reply.finished = True
        return reply

    master_request = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        rank_id=self._rank_id,
        sync_example_id=dj_pb.SyncExampleIdRequest(
            partition_id=partition_id
        ),
    )
    master_reply = self._master_client.RequestJoinPartition(master_request)
    if master_reply.status.code != 0:
        # Forward the master's failure to the caller unchanged.
        reply.status.MergeFrom(master_reply.status)
        return reply
    if not master_reply.HasField("manifest"):
        raise RuntimeError(
            "unknow field for master raw data request response"
        )
    assert master_reply.manifest.state == dj_pb.RawDataState.Syncing

    follower = self._example_id_sync_follower
    reply.finished = False
    reply.next_index = follower.start_dump_partition(partition_id)
    return reply
def FinishPartition(self, request, context):
    """RPC handler: try to finish example-id syncing for a partition.

    Args:
        request: carries ``data_source_meta`` and ``partition_id``.
        context: RPC context; not used by this handler.

    Returns:
        A ``FollowerFinishPartitionResponse``:

        * status code -1 when the data source meta does not match;
        * when the requested partition is the one being processed:
          ``finished`` mirrors the sync follower's
          ``finish_sync_partition_example`` result, and on a finished
          partition ``status`` carries the master's ``FinishJoinPartition``
          reply (the follower's dump state is reset only when the master
          accepted it);
        * when it is not being processed: ``finished=True`` if the local
          manifest state is already past ``Syncing``, otherwise status
          code -2 with an explanatory message.
    """
    response = dj_pb.FollowerFinishPartitionResponse()
    response.status.code = 0
    response.finished = False

    if not self._validate_data_source_meta(
            request.data_source_meta,
            self._data_source.data_source_meta):
        response.status.code = -1
        response.status.error_message = "data source meta mismtach"
        return response

    sync_follower = self._example_id_sync_follower
    if sync_follower.get_processing_partition_id() == request.partition_id:
        finished = sync_follower.finish_sync_partition_example(
            request.partition_id
        )
        if finished:
            req = dj_pb.FinishRawDataRequest(
                data_source_meta=self._data_source.data_source_meta,
                rank_id=self._rank_id,
                sync_example_id=dj_pb.SyncExampleIdRequest(
                    partition_id=request.partition_id
                ),
            )
            rsp = self._master_client.FinishJoinPartition(req)
            # Surface the master's verdict to the leader.
            response.status.MergeFrom(rsp)
            if rsp.code == 0:
                sync_follower.reset_dump_partition()
        # NOTE(review): the flattened original does not show this line's
        # nesting; it is placed at the level that makes `= finished`
        # meaningful for both outcomes -- confirm.
        response.finished = finished
        return response

    manifest = self._query_raw_data_manifest(request.partition_id)
    if manifest.state > dj_pb.RawDataState.Syncing:
        response.finished = True
    else:
        response.status.code = -2
        # `finished` is a field of the response, not of its status
        # sub-message; the original `response.status.finished` would
        # raise AttributeError and abort the handler on this path.
        response.finished = False
        response.status.error_message = (
            "partition {} at state {} but it is not "
            "processing".format(request.partition_id, manifest.state)
        )
    return response
def test_api(self):
    """Exercise the data-join master RPC API on a leader/follower pair.

    Flow:
      1. Reset etcd state and register one single-partition, sequential
         data source for each role.
      2. Start both master services and wait until both report
         ``Processing``.
      3. Check that a join request made before syncing yields neither a
         manifest nor ``finished``.
      4. Allocate partition 0 for example-id syncing on both sides (the
         leader twice, expecting the same answer), then finish it (the
         leader twice, expecting success both times).
      5. Repeat allocation and finish for the join phase.
      6. Wait until both masters report ``Finished``.

    Both masters are stopped through ``addCleanup`` so they are released
    even when an assertion fails midway. Cleanups run in reverse
    registration order, so the follower master stops first.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'

    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)

    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"

    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 32
    data_source_meta.max_matching_window = 1024
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000

    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))

    meta_l = data_source_l.data_source_meta
    meta_f = data_source_f.data_source_meta

    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    # Register the stop right after a successful start so a failing
    # assertion below cannot leak the service.
    self.addCleanup(master_l.stop)
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()
    self.addCleanup(master_f.stop)

    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    def wait_for_state(expected_state):
        # Poll both masters until each reports `expected_state`,
        # re-checking role and data source type on every round.
        while True:
            rsp_l = client_l.GetDataSourceState(meta_l)
            rsp_f = client_f.GetDataSourceState(meta_f)
            self.assertEqual(rsp_l.status.code, 0)
            self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
            self.assertEqual(rsp_l.data_source_type,
                             common_pb.DataSourceType.Sequential)
            self.assertEqual(rsp_f.status.code, 0)
            self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
            self.assertEqual(rsp_f.data_source_type,
                             common_pb.DataSourceType.Sequential)
            if (rsp_l.state == expected_state and
                    rsp_f.state == expected_state):
                return
            time.sleep(2)

    def assert_manifest(rsp, expected_state):
        # The reply must carry partition 0, allocated to rank 0, in
        # `expected_state`.
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('manifest'))
        self.assertEqual(rsp.manifest.state, expected_state)
        self.assertEqual(rsp.manifest.allocated_rank_id, 0)
        self.assertEqual(rsp.manifest.partition_id, 0)

    wait_for_state(common_pb.DataSourceState.Processing)

    # A join request before any partition has been synced gets neither
    # a manifest nor `finished`.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=meta_l, rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertFalse(rdrsp.HasField('manifest'))
    self.assertFalse(rdrsp.HasField('finished'))

    # Sync phase: the leader allocates; asking twice gives the same answer.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=meta_l, rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
    assert_manifest(client_l.RequestJoinPartition(rdreq), dj_pb.Syncing)
    assert_manifest(client_l.RequestJoinPartition(rdreq), dj_pb.Syncing)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=meta_f, rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    assert_manifest(client_f.RequestJoinPartition(rdreq), dj_pb.Syncing)

    # Finishing twice on the leader must succeed both times. The
    # original sent the stale RawDataRequest on the second call and
    # re-asserted the first response.
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=meta_l, rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=meta_f, rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    # Join phase.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=meta_l, rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    assert_manifest(client_l.RequestJoinPartition(rdreq), dj_pb.Joining)

    rdreq = dj_pb.RawDataRequest(
        data_source_meta=meta_f, rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    assert_manifest(client_f.RequestJoinPartition(rdreq), dj_pb.Joining)

    # Send the FinishRawDataRequest built for each side. The original
    # built `frreq` but sent the leftover RawDataRequest instead.
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=meta_l, rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=meta_f, rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)

    wait_for_state(common_pb.DataSourceState.Finished)