Exemplo n.º 1
0
    def WorkerRegister(self, request, context):
        """Handle a worker registration request.

        The response status code is STATUS_DATA_FINISHED when the master
        status is WORKER_COMPLETED or COMPLETED,
        STATUS_WAIT_FOR_SYNCING_CHECKPOINT when the master status is anything
        other than RUNNING, and STATUS_SUCCESS once the worker rank has been
        added to the running set (and dropped from the completed set).
        """
        rank = request.worker_rank
        with self._lock:
            # Kept for compatibility (see trainer_master_service.proto):
            # remember the first cluster_def reported by worker 0.
            if self._worker0_cluster_def is None and rank == 0:
                self._worker0_cluster_def = request.cluster_def

            finished_states = (tm_pb.MasterStatus.WORKER_COMPLETED,
                               tm_pb.MasterStatus.COMPLETED)
            if self._status in finished_states:
                code = common_pb.StatusCode.STATUS_DATA_FINISHED
            elif self._status != tm_pb.MasterStatus.RUNNING:
                code = common_pb.StatusCode.STATUS_WAIT_FOR_SYNCING_CHECKPOINT
            else:
                if rank in self._running_workers:
                    fl_logging.warning("worker_%d:%s repeat registration",
                                       rank, request.hostname)
                else:
                    fl_logging.info("worker_%d:%s registration",
                                    rank, request.hostname)

                self._running_workers.add(rank)
                # A re-registered worker is no longer considered completed.
                if rank in self._completed_workers:
                    self._completed_workers.remove(rank)
                code = common_pb.StatusCode.STATUS_SUCCESS

            return tm_pb.WorkerRegisterResponse(
                status=common_pb.Status(code=code))
Exemplo n.º 2
0
 def _check_data_source_meta(self, remote_meta, raise_exp=False):
     """Compare the peer's data source meta with the local one.

     Each differing field among name, partition_num, start_time, end_time
     and negative_sampling_rate is logged as an error.

     Args:
         remote_meta: meta to compare against self._data_source_meta.
         raise_exp: when True, a mismatch raises instead of returning a
             failure status.

     Returns:
         common_pb.Status with code 0 when both metas compare equal,
         otherwise code -1 and error_message "data source meta mismatch".

     Raises:
         RuntimeError: on mismatch when raise_exp is True.
     """
     if self._data_source_meta != remote_meta:
         local_meta = self._data_source_meta
         if local_meta.name != remote_meta.name:
             # Spelling fixed ("mismtach" -> "mismatch") so this line can be
             # found with the same search term as the other mismatch logs.
             logging.error("data_source_meta mismatch since name "
                           "%s != %s", local_meta.name, remote_meta.name)
         if local_meta.partition_num != remote_meta.partition_num:
             logging.error("data_source_meta mismatch since partition "
                           "num %d != %d", local_meta.partition_num,
                           remote_meta.partition_num)
         if local_meta.start_time != remote_meta.start_time:
             logging.error("data_source_meta mismatch since start_time "
                           "%d != %d",
                           local_meta.start_time, remote_meta.start_time)
         if local_meta.end_time != remote_meta.end_time:
             logging.error("data_source_meta mismatch since end_time "
                           "%d != %d",
                           local_meta.end_time, remote_meta.end_time)
         if local_meta.negative_sampling_rate != \
                 remote_meta.negative_sampling_rate:
             logging.error("data_source_meta mismatch since negative_"
                           "sampling_rate %f != %f",
                           local_meta.negative_sampling_rate,
                           remote_meta.negative_sampling_rate)
         if raise_exp:
             raise RuntimeError("data source meta mismatch")
         return common_pb.Status(code=-1,
                                 error_message="data source meta mismatch")
     return common_pb.Status(code=0)
Exemplo n.º 3
0
 def _data_block_handler(self, request):
     """Pass a data block message to the registered handler.

     Returns a common_pb.Status whose code is STATUS_SUCCESS when the handler
     returns a truthy value and STATUS_INVALID_DATA_BLOCK otherwise.

     Raises:
         RuntimeError: when no handler has been registered.
     """
     assert self._connected, "Cannot load data before connect"
     handler = self._data_block_handler_fn
     if not handler:
         raise RuntimeError(
             "Received DataBlockMessage but no handler registered")
     accepted = handler(request)
     code = (common_pb.STATUS_SUCCESS if accepted
             else common_pb.STATUS_INVALID_DATA_BLOCK)
     return common_pb.Status(code=code)
Exemplo n.º 4
0
 def _psi_sign_fn(self, request):
     """Sign request.ids with the private key components (d, n).

     The signing runs through self._process_pool_executor when one is
     configured and inline otherwise. Returns a dj_pb.SignIdsResponse with
     status code 0 and the signed ids.
     """
     d, n = self._prv_key.d, self._prv_key.n
     executor = self._process_pool_executor
     if executor is None:
         signed = RsaPsiSigner._psi_sign_impl(request.ids, d, n)
     else:
         # Hand the executor a plain list copy of the ids.
         future = executor.submit(RsaPsiSigner._psi_sign_impl,
                                  list(request.ids), d, n)
         signed = future.result()
     return dj_pb.SignIdsResponse(status=common_pb.Status(code=0),
                                  signed_ids=signed)
Exemplo n.º 5
0
 def _data_block_handler(self, request):
     """Run the registered data block handler and report the outcome.

     Logs whether the block identified by request.block_id was loaded and
     returns a tws2_pb.LoadDataBlockResponse whose status code is
     STATUS_SUCCESS for a truthy handler result, STATUS_INVALID_DATA_BLOCK
     otherwise.
     """
     assert self._data_block_handler_fn is not None, \
         "[Bridge] receive DataBlockMessage but no handler registered."
     loaded = self._data_block_handler_fn(request)
     if loaded:
         log_format = '[Bridge] succeeded to load data block %s'
         code = common_pb.STATUS_SUCCESS
     else:
         log_format = '[Bridge] failed to load data block %s'
         code = common_pb.STATUS_INVALID_DATA_BLOCK
     logging.info(log_format, request.block_id)
     return tws2_pb.LoadDataBlockResponse(
         status=common_pb.Status(code=code))
Exemplo n.º 6
0
    def RequestDataBlock(self, request, context):
        """Serve a data block request from a registered, unfinished worker.

        A worker rank that is not in the running set, or that is already in
        the completed set, gets a STATUS_INVALID_REQUEST response; any other
        request is delegated to self._request_data_block.
        """
        rank = request.worker_rank
        if rank not in self._running_workers:
            rejection = "unregistered worker"
        elif rank in self._completed_workers:
            rejection = "worker has completed"
        else:
            return self._request_data_block(request)

        return tm_pb.DataBlockResponse(status=common_pb.Status(
            code=common_pb.StatusCode.STATUS_INVALID_REQUEST,
            error_message=rejection))
Exemplo n.º 7
0
 def _data_block_handler(self, request):
     """Pass a data block message to the registered handler, with metrics.

     Emits 'load_data_block_counter' for every call and
     'load_data_block_fail_counter' when the handler returns a falsy value.
     Returns a common_pb.Status with STATUS_SUCCESS or
     STATUS_INVALID_DATA_BLOCK accordingly.

     Raises:
         RuntimeError: when no handler has been registered.
     """
     assert self._connected, "Cannot load data before connect"
     handler = self._data_block_handler_fn
     if not handler:
         raise RuntimeError(
             "Received DataBlockMessage but no handler registered")
     metrics.emit_counter('load_data_block_counter', 1)
     if not handler(request):
         metrics.emit_counter('load_data_block_fail_counter', 1)
         logging.info('Failed to load data block %s', request.block_id)
         return common_pb.Status(code=common_pb.STATUS_INVALID_DATA_BLOCK)
     logging.info('Succeeded to load data block %s', request.block_id)
     return common_pb.Status(code=common_pb.STATUS_SUCCESS)
Exemplo n.º 8
0
    def _transmit_handler(self, request):
        """Process one incoming transmit message.

        A message whose seq_num is below the expected one is ignored. A
        message at or above it must match it exactly (asserted); the expected
        sequence number is then advanced and the payload dispatched:

        * 'start' creates an empty entry in self._received_data for the
          iteration id;
        * 'commit' does nothing;
        * 'data' stores the tensor, converted with tf.make_ndarray, under its
          iteration id and name, then wakes every waiter on self._condition;
        * 'prefetch' is passed to each registered prefetch handler.

        Returns:
            tws_pb.TrainerWorkerResponse carrying the next expected seq_num.
            Its status code is STATUS_INVALID_REQUEST when the message has
            none of the payload fields above.
        """
        assert self._connected, "Cannot transmit before connect"
        if request.seq_num >= self._next_receive_seq_num:
            assert request.seq_num == self._next_receive_seq_num, \
                "Invalid request"
            self._next_receive_seq_num += 1

            if request.HasField('start'):
                with self._condition:
                    self._received_data[request.start.iter_id] = {}
            elif request.HasField('commit'):
                pass
            elif request.HasField('data'):
                with self._condition:
                    assert request.data.iter_id in self._received_data
                    self._received_data[
                        request.data.iter_id][
                            request.data.name] = \
                                tf.make_ndarray(request.data.tensor)
                    # notify_all() is the supported spelling; notifyAll()
                    # is a deprecated alias.
                    self._condition.notify_all()
            elif request.HasField('prefetch'):
                for func in self._prefetch_handlers:
                    func(request.prefetch)
            else:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_INVALID_REQUEST),
                    next_seq_num=self._next_receive_seq_num)

        return tws_pb.TrainerWorkerResponse(
            next_seq_num=self._next_receive_seq_num)
Exemplo n.º 9
0
    def _transmit_handler(self, request):
        """Process one incoming transmit message under the receive lock.

        Keepalive messages are answered immediately with next_seq_num=-1.
        Otherwise the message's seq_num is compared with the expected one:
        a larger value yields STATUS_MESSAGE_MISSING, a smaller value yields
        STATUS_MESSAGE_DUPLICATED, and an equal value advances the expected
        sequence number and dispatches the payload:

        * 'start' creates an empty entry in self._received_data for the
          iteration id;
        * 'commit' does nothing;
        * 'data' stores the data message under its iteration id and name,
          then wakes every waiter on self._condition;
        * 'prefetch' is passed to each registered prefetch handler.

        Returns:
            tws_pb.TrainerWorkerResponse carrying the next expected seq_num
            (or -1 for keepalive). Its status code is STATUS_INVALID_REQUEST
            when the message has none of the payload fields above.
        """
        assert self._connected, "Cannot transmit before connect"
        with self._transmit_receive_lock:
            if request.HasField('keepalive'):
                # keep alive message, do nothing
                return tws_pb.TrainerWorkerResponse(next_seq_num=-1)

            logging.debug("Received message seq_num=%d."
                          " Wanted seq_num=%d.",
                          request.seq_num, self._next_receive_seq_num)
            if request.seq_num > self._next_receive_seq_num:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_MESSAGE_MISSING),
                    next_seq_num=self._next_receive_seq_num)
            if request.seq_num < self._next_receive_seq_num:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_MESSAGE_DUPLICATED),
                    next_seq_num=self._next_receive_seq_num)

            # request.seq_num == self._next_receive_seq_num
            self._next_receive_seq_num += 1

            if request.HasField('start'):
                with self._condition:
                    self._received_data[request.start.iter_id] = {}
            elif request.HasField('commit'):
                pass
            elif request.HasField('data'):
                with self._condition:
                    assert request.data.iter_id in self._received_data
                    self._received_data[
                        request.data.iter_id][
                            request.data.name] = request.data
                    # notify_all() is the supported spelling; notifyAll()
                    # is a deprecated alias.
                    self._condition.notify_all()
            elif request.HasField('prefetch'):
                for func in self._prefetch_handlers:
                    func(request.prefetch)
            else:
                return tws_pb.TrainerWorkerResponse(
                    status=common_pb.Status(
                        code=common_pb.STATUS_INVALID_REQUEST),
                    next_seq_num=self._next_receive_seq_num)

            return tws_pb.TrainerWorkerResponse(
                next_seq_num=self._next_receive_seq_num)
Exemplo n.º 10
0
 def _request_data_block(self, request):
     """Look up the data block named by request.block_id.

     Returns a tm_pb.DataBlockResponse: STATUS_SUCCESS with the block id and
     data path when the visitor returns a truthy block, otherwise
     STATUS_INVALID_DATA_BLOCK.
     """
     block_id = request.block_id
     data_block = self._data_visitor.get_datablock_by_id(block_id)
     if not data_block:
         fl_logging.error("invalid data block id: %s", block_id)
         return tm_pb.DataBlockResponse(status=common_pb.Status(
             code=common_pb.StatusCode.STATUS_INVALID_DATA_BLOCK,
             error_message="invalid data block"))

     fl_logging.info("allocated worker_%d with block: %s",
                     request.worker_rank, data_block.id)
     ok_status = common_pb.Status(code=common_pb.StatusCode.STATUS_SUCCESS)
     return tm_pb.DataBlockResponse(
         status=ok_status,
         block_id=data_block.id,
         data_path=data_block.data_path,
     )
Exemplo n.º 11
0
 def _psi_sign_fn(self, request):
     """Sign request.ids with the RSA private key and record the duration.

     The signing runs through self._process_pool_executor when one is
     configured and inline otherwise. The elapsed wall-clock time is passed
     to self._record_sign_duration together with request.begin_index and the
     number of ids. Returns a dj_pb.SignIdsResponse with status code 0.
     """
     private_key = self._rsa_private_key
     d, n = private_key.d, private_key.n
     start_tm = time.time()
     pool = self._process_pool_executor
     if pool is None:
         signed_ids = RsaPsiSigner._psi_sign_impl(request.ids, d, n)
     else:
         # Hand the executor a plain list copy of the ids.
         signed_ids = pool.submit(RsaPsiSigner._psi_sign_impl,
                                  list(request.ids), d, n).result()
     response = dj_pb.SignIdsResponse(status=common_pb.Status(code=0),
                                      signed_ids=signed_ids)
     elapsed = time.time() - start_tm
     self._record_sign_duration(request.begin_index, len(request.ids),
                                elapsed)
     return response
Exemplo n.º 12
0
    def WorkerComplete(self, request, context):
        """Record that a worker has completed.

        A worker rank that is not in the running set is rejected with
        STATUS_INVALID_REQUEST. Otherwise the rank is added to the completed
        set; for worker 0 the request timestamp is stored as its termination
        time. When the running and completed sets have the same size and
        worker 0 is among the running workers, the master status is moved
        from RUNNING to WORKER_COMPLETED.

        Returns:
            tm_pb.WorkerCompleteResponse in every case.
        """
        with self._lock:
            if request.worker_rank not in self._running_workers:
                # This RPC must answer with WorkerCompleteResponse; the
                # rejection previously used WorkerRegisterResponse.
                return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
                    code=common_pb.StatusCode.STATUS_INVALID_REQUEST,
                    error_message="unregistered worker"))
            fl_logging.info("worker_%d completed", request.worker_rank)
            self._completed_workers.add(request.worker_rank)
            if request.worker_rank == 0:
                self._worker0_terminated_at = request.timestamp

            if len(self._running_workers) == len(self._completed_workers) \
                and 0 in self._running_workers:
                # worker 0 completed and all datablock has finished
                self._transfer_status(tm_pb.MasterStatus.RUNNING,
                                      tm_pb.MasterStatus.WORKER_COMPLETED)

            return tm_pb.WorkerCompleteResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS))
Exemplo n.º 13
0
 def AbortDataSource(self, request, context):
     """Move the data source state machine to the failed state.

     Returns a common_pb.Status: code -1 when the request fails meta
     validation, code -2 when self._fsm.set_failed() returns a falsy value,
     code 0 otherwise.
     """
     data_source = self._fsm.get_data_source()
     meta_ok = self._validate_data_source_meta(
         request, data_source.data_source_meta)
     if not meta_ok:
         return common_pb.Status(
             code=-1, error_message='data source meta mismtach')
     if not self._fsm.set_failed():
         return common_pb.Status(
             code=-2, error_message="failed to set failed state to fsm")
     return common_pb.Status(code=0)
Exemplo n.º 14
0
    def _request_data_block(self, request):
        """Hand the next data block from self._data_visitor to a worker.

        Returns a tm_pb.DataBlockResponse: STATUS_SUCCESS with the block id
        and data path when the visitor yields a truthy block, otherwise
        STATUS_DATA_FINISHED (visitor exhausted or falsy block).
        """
        # next() with a default replaces the try/except StopIteration; the
        # placeholder response that was always overwritten is gone.
        data_block = next(self._data_visitor, None)
        if not data_block:
            return tm_pb.DataBlockResponse(status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_DATA_FINISHED,
                error_message="data block finished"))

        fl_logging.info("allocated worker_%d with block: %s",
                        request.worker_rank, data_block.id)
        return tm_pb.DataBlockResponse(
            status=common_pb.Status(
                code=common_pb.StatusCode.STATUS_SUCCESS),
            block_id=data_block.id,
            data_path=data_block.data_path,
        )
Exemplo n.º 15
0
def receive_job(request):
    """Create a 'Follower' job record for an incoming job request.

    The model (by request.model_uri), the model version (by
    request.model_commit for that model) and the data source (by
    request.data_meta.data_source_name) are looked up first; a failed
    lookup yields a STATUS_UNKNOWN_ERROR response naming what was not
    authorized.

    Args:
        request: job request carrying name, description, application_id,
            model_uri, model_commit, serving_version, pair_num and
            data_meta.data_source_name.

    Returns:
        common_pb.Status with STATUS_SUCCESS when the job was created,
        STATUS_UNKNOWN_ERROR with an error_message otherwise.
    """
    logging.debug("In Platform Scheduler::_receive_job  application_id = %s",
                  request.application_id)
    response = common_pb.Status()

    def _reject(message):
        # Fill the shared response with an unknown-error status.
        response.code = common_pb.StatusCode.STATUS_UNKNOWN_ERROR
        response.error_message = message
        return response

    try:
        model_meta = ModelMeta.get(ModelMeta.name == request.model_uri)
    except Exception:  #pylint: disable=W0703
        return _reject('model_uri [%s] was not authorized.'
                       % request.model_uri)

    try:
        # Query expressions must be combined with `&`. Python's `and`
        # evaluates to its second operand here, so the commit condition
        # was silently dropped from the query.
        model_version = ModelVersion.get(
            (ModelVersion.commit == request.model_commit)
            & (ModelVersion.model_meta_id == model_meta.id))
    except Exception:  #pylint: disable=W0703
        return _reject('model_uri [%s] model_version [%s] '
                       'was not authorized.'
                       % (request.model_uri, request.model_commit))

    try:
        data_source = DataSourceMeta.get(
            DataSourceMeta.name == request.data_meta.data_source_name)
    except Exception:  #pylint: disable=W0703
        return _reject('data_source [%s] was not authorized.'
                       % (request.data_meta.data_source_name))

    job = Job.create(name=request.name,
                     description=request.description,
                     role='Follower',
                     application_id=request.application_id,
                     status=JOBSTATUS.SUBMMITTED.value,
                     model_version_id=model_version.id,
                     serving_version=request.serving_version,
                     data_source_id=data_source.id,
                     cluster_spec=json.dumps(
                         {'worker_replicas': request.pair_num}),
                     group_list=json.dumps([]),
                     create_time=datetime.datetime.now())
    if not job:
        return _reject('job [%s] create failed.' % request.name)
    response.code = common_pb.StatusCode.STATUS_SUCCESS
    return response
Exemplo n.º 16
0
 def SyncExamples(self, request, context):
     """Feed a synced-example request to the example id sync follower.

     Returns a common_pb.Status: code -1 when the request's data source meta
     fails validation or when the follower reports the request as not
     filled (the message then carries the index it requires), code 0
     otherwise.
     """
     meta_ok = self._validate_data_source_meta(
         request.data_source_meta, self._data_source.data_source_meta)
     if not meta_ok:
         return common_pb.Status(
             code=-1, error_message="data source meta mismtach")
     filled, next_index = \
         self._example_id_sync_follower.add_synced_example_req(request)
     if filled:
         return common_pb.Status(code=0)
     return common_pb.Status(
         code=-1,
         error_message="the follower required {}".format(next_index))
Exemplo n.º 17
0
 def CreateDataBlock(self, request, context):
     """Feed a synced data block meta to the example join follower.

     Returns a common_pb.Status: code -1 when the request's data source meta
     fails validation or when the follower reports the meta as not filled
     (the message then carries the index it requires), code 0 otherwise.
     """
     meta_ok = self._validate_data_source_meta(
         request.data_source_meta, self._data_source.data_source_meta)
     if not meta_ok:
         return common_pb.Status(
             code=-1, error_message="data source meta mismtach")
     follower = self._example_join_follower
     filled, next_index = follower.add_synced_data_block_meta(
         request.data_block_meta)
     if filled:
         return common_pb.Status(code=0)
     return common_pb.Status(
         code=-1,
         error_message="the follower required {}".format(next_index))
Exemplo n.º 18
0
 def add_raw_data(self, partition_id, fpaths, dedup, timestamps=None):
     """Register raw data files for a partition with the master.

     Args:
         partition_id: partition to add to; checked with
             self._check_partition_id.
         fpaths: raw data file paths; every path must exist per
             gfile.Exists.
         dedup: forwarded as the dedup flag of AddedRawDataMetas.
         timestamps: optional sequence, one entry per path, merged into the
             matching raw data meta's timestamp.

     Returns:
         common_pb.Status(code=0) when fpaths is empty, otherwise the result
         of self._master_client.AddRawData.

     Raises:
         RuntimeError: when timestamps is given with a different length
             than fpaths.
         ValueError: when a path does not exist.
     """
     self._check_partition_id(partition_id)
     if not fpaths:
         logging.warning("no raw data will be added")
         return common_pb.Status(code=0)
     if timestamps is not None and len(fpaths) != len(timestamps):
         raise RuntimeError("the number of raw data file "
                            "and timestamp mismatch")
     rdreq = dj_pb.RawDataRequest(
         data_source_meta=self._data_source.data_source_meta,
         partition_id=partition_id,
         added_raw_data_metas=dj_pb.AddedRawDataMetas(dedup=dedup))
     for index, fpath in enumerate(fpaths):
         if not gfile.Exists(fpath):
             # str.format, not `%`: the old `'{}' % format(fpath)` raised
             # TypeError before the ValueError could be built.
             raise ValueError('{} is not existed'.format(fpath))
         raw_data_meta = dj_pb.RawDataMeta(file_path=fpath, start_index=-1)
         if timestamps is not None:
             raw_data_meta.timestamp.MergeFrom(timestamps[index])
         rdreq.added_raw_data_metas.raw_data_metas.append(raw_data_meta)
     return self._master_client.AddRawData(rdreq)
Exemplo n.º 19
0
 def FinishJoinPartition(self, request, context):
     """Mark a partition as finished for the sync or join stage.

     Returns a common_pb.Status: code -1 when the request's data source meta
     fails validation, code -2 when the data source state is not Processing,
     code 0 after the manifest manager has been told to finish the
     partition.

     Raises:
         RuntimeError: when the request has neither 'sync_example_id' nor
             'join_example' set.
     """
     data_source = self._fsm.get_data_source()
     meta_ok = self._validate_data_source_meta(
         request.data_source_meta, data_source.data_source_meta)
     if not meta_ok:
         return common_pb.Status(
             code=-1, error_message='data source meta mismtach')
     if data_source.state != common_pb.DataSourceState.Processing:
         return common_pb.Status(
             code=-2, error_message="data source is not at processing state")

     rank_id = request.rank_id
     manifest_manager = self._fsm.get_mainifest_manager()
     if request.HasField('sync_example_id'):
         manifest_manager.finish_sync_partition(
             rank_id, request.sync_example_id.partition_id)
     elif request.HasField('join_example'):
         manifest_manager.finish_join_partition(
             rank_id, request.join_example.partition_id)
     else:
         raise RuntimeError("unknown request type")
     return common_pb.Status(code=0)
Exemplo n.º 20
0
 def _bye_fn(self):
     """Flag that the peer said bye and wake every waiter on self._cond.

     Returns a common_pb.Status with code 0.
     """
     cond = self._cond
     with cond:
         self._peer_bye = True
         cond.notify_all()
     ok_status = common_pb.Status(code=0)
     return ok_status
Exemplo n.º 21
0
 def _data_block_handler(self, request):
     """Pass a data block message to the registered handler.

     The handler's return value is ignored; a common_pb.Status with
     STATUS_SUCCESS is always returned.

     Raises:
         RuntimeError: when no handler has been registered.
     """
     handler = self._data_block_handler_fn
     if not handler:
         raise RuntimeError(
             "Received DataBlockMessage but no handler registered")
     handler(request)
     success = common_pb.Status(code=common_pb.STATUS_SUCCESS)
     return success