def WorkerRegister(self, request, context):
    """Handle a worker registration RPC.

    Captures worker 0's cluster_def for later use, rejects registration
    when the master has finished or is not yet running, and otherwise
    tracks the worker as running (clearing any stale completion mark).
    """
    with self._lock:
        # for compatibility, more information see:
        # protocal/fedlearner/common/trainer_master_service.proto
        rank = request.worker_rank
        if rank == 0 and self._worker0_cluster_def is None:
            self._worker0_cluster_def = request.cluster_def

        terminal_states = (tm_pb.MasterStatus.WORKER_COMPLETED,
                           tm_pb.MasterStatus.COMPLETED)
        if self._status in terminal_states:
            status = common_pb.Status(
                code=common_pb.StatusCode.STATUS_DATA_FINISHED)
            return tm_pb.WorkerRegisterResponse(status=status)

        if self._status != tm_pb.MasterStatus.RUNNING:
            status = common_pb.Status(
                code=common_pb.StatusCode.STATUS_WAIT_FOR_SYNCING_CHECKPOINT)
            return tm_pb.WorkerRegisterResponse(status=status)

        if rank in self._running_workers:
            fl_logging.warning("worker_%d:%s repeat registration",
                               rank, request.hostname)
        else:
            fl_logging.info("worker_%d:%s registration",
                            rank, request.hostname)
            self._running_workers.add(rank)
        # A re-registering worker starts over, so drop any stale
        # completion record for it.
        if rank in self._completed_workers:
            self._completed_workers.remove(rank)
        status = common_pb.Status(code=common_pb.StatusCode.STATUS_SUCCESS)
        return tm_pb.WorkerRegisterResponse(status=status)
def _check_data_source_meta(self, remote_meta, raise_exp=False):
    """Compare the local data source meta against *remote_meta*.

    Logs an error for every mismatched field. Returns Status(code=0) on
    a full match and Status(code=-1) on any mismatch, unless raise_exp
    is True, in which case a RuntimeError is raised instead.

    Fix: corrected the typo "mismtach" in the name-mismatch log message.
    """
    if self._data_source_meta != remote_meta:
        local_meta = self._data_source_meta
        if local_meta.name != remote_meta.name:
            logging.error("data_source_meta mismatch since name "\
                          "%s != %s", local_meta.name, remote_meta.name)
        if local_meta.partition_num != remote_meta.partition_num:
            logging.error("data_source_meta mismatch since partition "\
                          "num %d != %d", local_meta.partition_num,
                          remote_meta.partition_num)
        if local_meta.start_time != remote_meta.start_time:
            logging.error("data_source_meta mismatch since start_time "\
                          "%d != %d", local_meta.start_time,
                          remote_meta.start_time)
        if local_meta.end_time != remote_meta.end_time:
            logging.error("data_source_meta mismatch since end_time "\
                          "%d != %d", local_meta.end_time,
                          remote_meta.end_time)
        if local_meta.negative_sampling_rate != \
                remote_meta.negative_sampling_rate:
            logging.error("data_source_meta mismatch since negative_"\
                          "sampling_rate %f != %f",
                          local_meta.negative_sampling_rate,
                          remote_meta.negative_sampling_rate)
        if raise_exp:
            raise RuntimeError("data source meta mismatch")
        return common_pb.Status(code=-1,
                                error_message="data source meta mismatch")
    return common_pb.Status(code=0)
def _data_block_handler(self, request):
    """Dispatch a DataBlockMessage to the registered handler.

    Returns STATUS_SUCCESS when the handler accepts the block and
    STATUS_INVALID_DATA_BLOCK otherwise. Raises RuntimeError when no
    handler has been registered.
    """
    assert self._connected, "Cannot load data before connect"
    handler = self._data_block_handler_fn
    if not handler:
        raise RuntimeError("Received DataBlockMessage but"
                           " no handler registered")
    code = (common_pb.STATUS_SUCCESS if handler(request)
            else common_pb.STATUS_INVALID_DATA_BLOCK)
    return common_pb.Status(code=code)
def _psi_sign_fn(self, request):
    """Sign the blinded ids in *request* with the RSA private key.

    Offloads the signing to the process pool when one is configured;
    otherwise signs inline. Returns a SignIdsResponse with code=0 and
    the signatures.

    Fix: replaced the unnecessary identity comprehension over
    request.ids with list() (and dropped the pylint suppression).
    """
    d, n = self._prv_key.d, self._prv_key.n
    if self._process_pool_executor is not None:
        # Materialize the repeated proto field so it can be pickled
        # over to the worker process.
        rids = list(request.ids)
        sign_future = self._process_pool_executor.submit(
            RsaPsiSigner._psi_sign_impl, rids, d, n)
        return dj_pb.SignIdsResponse(status=common_pb.Status(code=0),
                                     signed_ids=sign_future.result())
    return dj_pb.SignIdsResponse(
        status=common_pb.Status(code=0),
        signed_ids=RsaPsiSigner._psi_sign_impl(request.ids, d, n))
def _data_block_handler(self, request):
    """Forward a LoadDataBlock request to the registered handler and
    report the outcome in the response status."""
    assert self._data_block_handler_fn is not None, \
        "[Bridge] receive DataBlockMessage but no handler registered."
    loaded = self._data_block_handler_fn(request)
    if loaded:
        logging.info('[Bridge] succeeded to load data block %s',
                     request.block_id)
        code = common_pb.STATUS_SUCCESS
    else:
        logging.info('[Bridge] failed to load data block %s',
                     request.block_id)
        code = common_pb.STATUS_INVALID_DATA_BLOCK
    return tws2_pb.LoadDataBlockResponse(
        status=common_pb.Status(code=code))
def RequestDataBlock(self, request, context):
    """Serve a data-block request for a registered, still-running worker.

    Rejects unregistered or already-completed workers with
    STATUS_INVALID_REQUEST; otherwise delegates to _request_data_block.
    NOTE(review): membership reads here are not guarded by self._lock —
    presumably acceptable for set membership; confirm against the
    register/complete paths.
    """
    rank = request.worker_rank
    if rank not in self._running_workers:
        return tm_pb.DataBlockResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_INVALID_REQUEST,
            error_message="unregistered worker"))
    if rank in self._completed_workers:
        return tm_pb.DataBlockResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_INVALID_REQUEST,
            error_message="worker has completed"))
    return self._request_data_block(request)
def _data_block_handler(self, request):
    """Handle an incoming data-block message: bump metrics counters,
    invoke the registered handler, and translate the outcome to a
    Status. Raises RuntimeError when no handler is registered."""
    assert self._connected, "Cannot load data before connect"
    if not self._data_block_handler_fn:
        raise RuntimeError("Received DataBlockMessage but"
                           " no handler registered")
    metrics.emit_counter('load_data_block_counter', 1)
    succeeded = self._data_block_handler_fn(request)
    if succeeded:
        logging.info('Succeeded to load data block %s', request.block_id)
        return common_pb.Status(code=common_pb.STATUS_SUCCESS)
    metrics.emit_counter('load_data_block_fail_counter', 1)
    logging.info('Failed to load data block %s', request.block_id)
    return common_pb.Status(code=common_pb.STATUS_INVALID_DATA_BLOCK)
def _transmit_handler(self, request):
    """Process one TrainerWorkerMessage from the peer.

    Only a message carrying exactly the next expected seq_num is
    processed (the assert enforces no gaps); older seq_nums fall through
    and are simply re-acknowledged. Returns a TrainerWorkerResponse with
    the next expected seq_num; unknown payloads get
    STATUS_INVALID_REQUEST.

    Fix: Condition.notifyAll() is a deprecated alias — use notify_all().
    """
    assert self._connected, "Cannot transmit before connect"
    if request.seq_num >= self._next_receive_seq_num:
        assert request.seq_num == self._next_receive_seq_num, \
            "Invalid request"
        self._next_receive_seq_num += 1

        if request.HasField('start'):
            with self._condition:
                self._received_data[request.start.iter_id] = {}
        elif request.HasField('commit'):
            pass
        elif request.HasField('data'):
            with self._condition:
                assert request.data.iter_id in self._received_data
                self._received_data[
                    request.data.iter_id][
                        request.data.name] = \
                    tf.make_ndarray(request.data.tensor)
                self._condition.notify_all()
        elif request.HasField('prefetch'):
            for func in self._prefetch_handlers:
                func(request.prefetch)
        else:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_INVALID_REQUEST),
                next_seq_num=self._next_receive_seq_num)

    return tws_pb.TrainerWorkerResponse(
        next_seq_num=self._next_receive_seq_num)
def _transmit_handler(self, request):
    """Process one TrainerWorkerMessage under the receive lock.

    Keepalive messages are acknowledged with next_seq_num=-1. A seq_num
    ahead of the expected one yields STATUS_MESSAGE_MISSING, one behind
    yields STATUS_MESSAGE_DUPLICATED; only the exact next seq_num is
    processed. Unknown payloads yield STATUS_INVALID_REQUEST.

    Fix: Condition.notifyAll() is a deprecated alias — use notify_all().
    """
    assert self._connected, "Cannot transmit before connect"
    with self._transmit_receive_lock:
        if request.HasField('keepalive'):
            # keep alive message, do nothing
            return tws_pb.TrainerWorkerResponse(next_seq_num=-1)

        logging.debug("Received message seq_num=%d."
                      " Wanted seq_num=%d.",
                      request.seq_num, self._next_receive_seq_num)
        if request.seq_num > self._next_receive_seq_num:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_MESSAGE_MISSING),
                next_seq_num=self._next_receive_seq_num)
        if request.seq_num < self._next_receive_seq_num:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_MESSAGE_DUPLICATED),
                next_seq_num=self._next_receive_seq_num)

        # request.seq_num == self._next_receive_seq_num
        self._next_receive_seq_num += 1

        if request.HasField('start'):
            with self._condition:
                self._received_data[request.start.iter_id] = {}
        elif request.HasField('commit'):
            pass
        elif request.HasField('data'):
            with self._condition:
                assert request.data.iter_id in self._received_data
                self._received_data[
                    request.data.iter_id][
                        request.data.name] = request.data
                self._condition.notify_all()
        elif request.HasField('prefetch'):
            for func in self._prefetch_handlers:
                func(request.prefetch)
        else:
            return tws_pb.TrainerWorkerResponse(
                status=common_pb.Status(
                    code=common_pb.STATUS_INVALID_REQUEST),
                next_seq_num=self._next_receive_seq_num)

        return tws_pb.TrainerWorkerResponse(
            next_seq_num=self._next_receive_seq_num)
def _request_data_block(self, request):
    """Look up a data block by id for the requesting worker.

    Returns STATUS_SUCCESS with the block id and data path when found,
    otherwise STATUS_INVALID_DATA_BLOCK.
    """
    block = self._data_visitor.get_datablock_by_id(request.block_id)
    if not block:
        fl_logging.error("invalid data block id: %s", request.block_id)
        return tm_pb.DataBlockResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK,
            error_message="invalid data block"))
    fl_logging.info("allocated worker_%d with block: %s",
                    request.worker_rank, block.id)
    return tm_pb.DataBlockResponse(
        status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_SUCCESS),
        block_id=block.id,
        data_path=block.data_path,
    )
def _psi_sign_fn(self, request):
    """Sign the blinded ids in *request* and record the signing duration.

    Uses the process pool when one is configured, otherwise signs inline.
    Returns a SignIdsResponse with code=0 and the signatures.

    Fix: replaced the unnecessary identity comprehension over
    request.ids with list() (and dropped the pylint suppression).
    """
    d, n = self._rsa_private_key.d, self._rsa_private_key.n
    start_tm = time.time()
    if self._process_pool_executor is not None:
        # Materialize the repeated proto field so it can be pickled
        # over to the worker process.
        rids = list(request.ids)
        sign_future = self._process_pool_executor.submit(
            RsaPsiSigner._psi_sign_impl, rids, d, n)
        response = dj_pb.SignIdsResponse(status=common_pb.Status(code=0),
                                         signed_ids=sign_future.result())
    else:
        response = dj_pb.SignIdsResponse(
            status=common_pb.Status(code=0),
            signed_ids=RsaPsiSigner._psi_sign_impl(request.ids, d, n))
    self._record_sign_duration(request.begin_index, len(request.ids),
                               time.time() - start_tm)
    return response
def WorkerComplete(self, request, context):
    """Mark a worker as completed; once worker 0 and every running
    worker have completed, transition the master RUNNING ->
    WORKER_COMPLETED.

    Fix: the unregistered-worker branch previously returned a
    WorkerRegisterResponse; it now returns a WorkerCompleteResponse to
    match the RPC's declared response type.
    """
    with self._lock:
        if request.worker_rank not in self._running_workers:
            return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_INVALID_REQUEST,
                error_message="unregistered worker"))
        fl_logging.info("worker_%d completed", request.worker_rank)
        self._completed_workers.add(request.worker_rank)
        if request.worker_rank == 0:
            self._worker0_terminated_at = request.timestamp

        if len(self._running_workers) == len(self._completed_workers) \
                and 0 in self._running_workers:
            # worker 0 completed and all datablock has finished
            self._transfer_status(tm_pb.MasterStatus.RUNNING,
                                  tm_pb.MasterStatus.WORKER_COMPLETED)

        return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_SUCCESS))
def AbortDataSource(self, request, context):
    """Abort the data source by driving the FSM into the failed state.

    Returns Status code 0 on success, -1 on data source meta mismatch,
    -2 when the FSM refuses the transition.

    Fix: corrected the typo "mismtach" in the error message.
    """
    response = common_pb.Status()
    data_source = self._fsm.get_data_source()
    if not self._validate_data_source_meta(request,
                                           data_source.data_source_meta):
        response.code = -1
        response.error_message = 'data source meta mismatch'
    elif not self._fsm.set_failed():
        response.code = -2
        response.error_message = "failed to set failed state to fsm"
    else:
        response.code = 0
    return response
def _request_data_block(self, request):
    """Allocate the next data block from the visitor to the worker.

    Returns STATUS_SUCCESS with the block id and data path, or
    STATUS_DATA_FINISHED once the visitor is exhausted.

    Fix: removed a dead `response = tm_pb.DataBlockResponse()` that was
    unconditionally overwritten by both branches below.
    """
    try:
        data_block = next(self._data_visitor)
    except StopIteration:
        data_block = None

    if data_block:
        fl_logging.info("allocated worker_%d with block: %s",
                        request.worker_rank, data_block.id)
        response = tm_pb.DataBlockResponse(
            status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS),
            block_id=data_block.id,
            data_path=data_block.data_path,
        )
    else:
        response = tm_pb.DataBlockResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_DATA_FINISHED,
            error_message="data block finished"))
    return response
def receive_job(request):
    """Validate a follower job request against the local model and data
    registry, then create the Job row.

    Returns a common_pb.Status: STATUS_SUCCESS on creation, otherwise
    STATUS_UNKNOWN_ERROR with an explanatory error_message.

    Fix: the two ModelVersion conditions were combined with Python
    `and`, which evaluates expression truthiness and effectively dropped
    the commit filter; peewee expressions must be combined with `&`.
    """
    logging.debug("In Platform Scheduler::_receive_job application_id = %s",
                  request.application_id)
    response = common_pb.Status()
    try:
        model_meta = ModelMeta.get(ModelMeta.name == request.model_uri)
    except Exception:  #pylint: disable=W0703
        response.code = common_pb.StatusCode.STATUS_UNKNOWN_ERROR
        response.error_message = 'model_uri [%s] was not authorized.' \
                                 % request.model_uri
        return response

    try:
        model_version = ModelVersion.get(
            (ModelVersion.commit == request.model_commit) &
            (ModelVersion.model_meta_id == model_meta.id))
    except Exception:  #pylint: disable=W0703
        response.code = common_pb.StatusCode.STATUS_UNKNOWN_ERROR
        response.error_message = 'model_uri [%s] model_version [%s] ' \
                                 'was not authorized.' % \
                                 (request.model_uri, request.model_commit)
        return response

    try:
        data_source = DataSourceMeta.get(
            DataSourceMeta.name == request.data_meta.data_source_name)
    except Exception:  #pylint: disable=W0703
        response.code = common_pb.StatusCode.STATUS_UNKNOWN_ERROR
        response.error_message = 'data_source [%s] was not authorized.' \
                                 % (request.data_meta.data_source_name)
        return response

    job = Job.create(name=request.name,
                     description=request.description,
                     role='Follower',
                     application_id=request.application_id,
                     status=JOBSTATUS.SUBMMITTED.value,
                     model_version_id=model_version.id,
                     serving_version=request.serving_version,
                     data_source_id=data_source.id,
                     cluster_spec=json.dumps(
                         {'worker_replicas': request.pair_num}),
                     group_list=json.dumps([]),
                     create_time=datetime.datetime.now())
    if not job:
        response.code = common_pb.StatusCode.STATUS_UNKNOWN_ERROR
        response.error_message = 'job [%s] create failed.' % request.name
        return response
    response.code = common_pb.StatusCode.STATUS_SUCCESS
    return response
def SyncExamples(self, request, context):
    """Accept a batch of synced example ids from the leader.

    Returns Status code 0 on success; -1 on meta mismatch or when the
    follower expects a different index (reported in error_message).

    Fix: corrected the typo "mismtach" in the error message.
    """
    response = common_pb.Status()
    response.code = 0
    if not self._validate_data_source_meta(
            request.data_source_meta,
            self._data_source.data_source_meta):
        response.code = -1
        response.error_message = "data source meta mismatch"
        return response
    sync_follower = self._example_id_sync_follower
    filled, next_index = sync_follower.add_synced_example_req(request)
    if not filled:
        response.code = -1
        response.error_message = (
            "the follower required {}".format(next_index))
    return response
def CreateDataBlock(self, request, context):
    """Accept a data block meta synced from the leader.

    Returns Status code 0 on success; -1 on meta mismatch or when the
    follower expects a different index (reported in error_message).

    Fix: corrected the typo "mismtach" in the error message.
    """
    response = common_pb.Status()
    response.code = 0
    if not self._validate_data_source_meta(
            request.data_source_meta,
            self._data_source.data_source_meta):
        response.code = -1
        response.error_message = "data source meta mismatch"
        return response
    join_follower = self._example_join_follower
    filled, next_index = join_follower.add_synced_data_block_meta(
        request.data_block_meta
    )
    if not filled:
        response.code = -1
        response.error_message = (
            "the follower required {}".format(next_index)
        )
    return response
def add_raw_data(self, partition_id, fpaths, dedup, timestamps=None):
    """Register raw data files for *partition_id* with the master.

    Args:
        partition_id: target partition; validated first.
        fpaths: raw data file paths; every file must exist.
        dedup: whether the master should deduplicate the added metas.
        timestamps: optional list parallel to fpaths.

    Returns the master's AddRawData status, or Status(code=0) when
    fpaths is empty. Raises RuntimeError on a fpaths/timestamps length
    mismatch and ValueError for a missing file.

    Fix: `'{} is not existed' % format(fpath)` raised TypeError (no %
    placeholder in the string); it is now `.format(fpath)` so the
    intended ValueError message is produced.
    """
    self._check_partition_id(partition_id)
    if not fpaths:
        logging.warning("no raw data will be added")
        return common_pb.Status(code=0)
    if timestamps is not None and len(fpaths) != len(timestamps):
        raise RuntimeError("the number of raw data file "\
                           "and timestamp mismatch")
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=self._data_source.data_source_meta,
        partition_id=partition_id,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(dedup=dedup))
    for index, fpath in enumerate(fpaths):
        if not gfile.Exists(fpath):
            raise ValueError('{} is not existed'.format(fpath))
        raw_data_meta = dj_pb.RawDataMeta(file_path=fpath,
                                          start_index=-1)
        if timestamps is not None:
            raw_data_meta.timestamp.MergeFrom(timestamps[index])
        rdreq.added_raw_data_metas.raw_data_metas.append(raw_data_meta)
    return self._master_client.AddRawData(rdreq)
def FinishJoinPartition(self, request, context):
    """Mark a partition's sync or join stage finished for a rank.

    Returns Status code 0 on success, -1 on data source meta mismatch,
    -2 when the data source is not in the Processing state. Raises
    RuntimeError for an unknown request payload.

    Fix: corrected the typo "mismtach" in the error message.
    """
    response = common_pb.Status()
    response.code = 0
    data_source = self._fsm.get_data_source()
    if not self._validate_data_source_meta(request.data_source_meta,
                                           data_source.data_source_meta):
        response.code = -1
        response.error_message = 'data source meta mismatch'
        return response
    if data_source.state != common_pb.DataSourceState.Processing:
        response.code = -2
        response.error_message = "data source is not at processing state"
        return response
    rank_id = request.rank_id
    manifest_manager = self._fsm.get_mainifest_manager()
    if request.HasField('sync_example_id'):
        partition_id = request.sync_example_id.partition_id
        manifest_manager.finish_sync_partition(rank_id, partition_id)
    elif request.HasField('join_example'):
        partition_id = request.join_example.partition_id
        manifest_manager.finish_join_partition(rank_id, partition_id)
    else:
        raise RuntimeError("unknown request type")
    return response
def _bye_fn(self):
    """Handle the peer's 'bye': set the flag under the condition lock,
    wake every waiter, and acknowledge with a zero status."""
    with self._cond:
        self._peer_bye = True
        self._cond.notify_all()
    return common_pb.Status(code=0)
def _data_block_handler(self, request):
    """Invoke the registered data-block handler and always report
    STATUS_SUCCESS (the handler's return value is ignored here).
    Raises RuntimeError when no handler is registered."""
    handler = self._data_block_handler_fn
    if not handler:
        raise RuntimeError("Received DataBlockMessage but"
                           " no handler registered")
    handler(request)
    return common_pb.Status(code=common_pb.STATUS_SUCCESS)