示例#1
0
 def _allocate_join_partition_fn(self):
     """Request a partition from the master and set up example joining.

     Sends a join-example request with ``partition_id=-1`` (presumably
     meaning "any available partition" -- confirm against the master) and
     then, depending on the reply:

     * ``finished`` is set: clear ``self._state`` under the lock and return;
     * no ``manifest`` is set: log a warning and return so the caller can
       retry later;
     * otherwise: build an example joiner for the returned partition,
       record the manifest and joiner, and wake the worker routines.

     Must only be called while no manifest is being processed.

     Raises:
         RuntimeError: if the master replies with a non-zero status code.
     """
     assert self._processing_manifest is None
     request = dj_pb.RawDataRequest(
         data_source_meta=self._data_source.data_source_meta,
         rank_id=self._rank_id,
         join_example=dj_pb.JoinExampleRequest(partition_id=-1))
     response = self._master_client.RequestJoinPartition(request)
     status = response.status
     if status.code != 0:
         raise RuntimeError(
             "Failed to Request partition for example intsesection, "
             "error msg {}".format(status.error_message))

     if response.HasField('finished'):
         with self._lock:
             self._state = None
         return

     if not response.HasField('manifest'):
         logging.warning("no manifest is at state %d, wait and retry",
                         dj_pb.RawDataState.Synced)
         return

     manifest = response.manifest
     # The joiner is created before taking the lock; only the state
     # hand-over and the wake-ups happen while holding it.
     new_joiner = create_example_joiner(self._etcd, self._data_source,
                                        manifest.partition_id)
     with self._lock:
         self._processing_manifest = manifest
         self._joiner = new_joiner
         self._check_manifest()
         self._wakeup_leader_example_joiner()
         self._wakeup_data_block_meta_syncer()
示例#2
0
    def _finish_join_partition(self, follower_finished):
        """Try to mark the partition being processed as finished.

        When the follower is not yet known to be finished, the peer is
        asked via ``FinishPartition`` and its ``finished`` answer is used
        instead. Only once the follower is finished is the master told to
        finish the join for this partition.

        Args:
            follower_finished: truthy if the follower side is already known
                to have finished this partition.

        Returns:
            True if the master accepted the finish request, False if the
            follower has not finished yet.

        Raises:
            RuntimeError: if the peer or the master answers with a non-zero
                status code.
        """
        finished = follower_finished
        if not finished:
            peer_request = dj_pb.LeaderFinishPartitionRequest(
                data_source_meta=self._data_source.data_source_meta,
                partition_id=self._processing_manifest.partition_id,
            )
            peer_response = self._peer_client.FinishPartition(peer_request)
            peer_status = peer_response.status
            if peer_status.code != 0:
                raise RuntimeError(
                    "Failed to call Leader finish partition reason: {}"
                    .format(peer_status.error_message)
                )
            finished = peer_response.finished

        if not finished:
            logging.warning("Leader is creating data block, waitiing")
            return False

        master_request = dj_pb.FinishRawDataRequest(
            data_source_meta=self._data_source.data_source_meta,
            rank_id=self._rank_id,
            join_example=dj_pb.JoinExampleRequest(
                partition_id=self._processing_manifest.partition_id))
        # The master answers with a bare status message (code/error_message).
        master_status = self._master_client.FinishJoinPartition(master_request)
        if master_status.code != 0:
            raise RuntimeError(
                "Failed to finish raw data from joining. reason: %s"
                % master_status.error_message
            )
        return True
示例#3
0
    def StartPartition(self, request, context):
        """Handle a peer request to start creating data blocks for a partition.

        The reply is built as follows:

        * status code -1 if the request's data source meta fails validation;
        * status code -2 if the partition id is negative;
        * ``finished=True`` if the local manifest is already past ``Joining``;
        * the master's status if the master rejects the join request;
        * otherwise ``finished=False`` plus the next data block index
          returned by the example join follower.

        Args:
            request: the start-partition request; ``data_source_meta`` and
                ``partition_id`` are read from it.
            context: gRPC servicer context (unused here).

        Returns:
            A ``dj_pb.LeaderStartPartitionResponse``.

        Raises:
            RuntimeError: if the master's successful reply carries no
                manifest.
        """
        response = dj_pb.LeaderStartPartitionResponse()

        meta_matches = self._validate_data_source_meta(
            request.data_source_meta, self._data_source.data_source_meta)
        if not meta_matches:
            response.status.code = -1
            response.status.error_message = "data source meta mismtach"
            return response

        partition_id = request.partition_id
        if partition_id < 0:
            response.status.code = -2
            response.status.error_message = \
                "partition id {} illegal".format(partition_id)
            return response

        # Nothing left to start when the partition is already past Joining.
        local_manifest = self._query_raw_data_manifest(partition_id)
        if local_manifest.state > dj_pb.RawDataState.Joining:
            response.finished = True
            return response

        master_request = dj_pb.RawDataRequest(
            data_source_meta=self._data_source.data_source_meta,
            rank_id=self._rank_id,
            join_example=dj_pb.JoinExampleRequest(partition_id=partition_id),
        )
        master_response = self._master_client.RequestJoinPartition(
            master_request)
        if master_response.status.code != 0:
            # Forward the master's failure to the peer unchanged.
            response.status.MergeFrom(master_response.status)
            return response
        if not master_response.HasField("manifest"):
            raise RuntimeError(
                "unknow field for master raw data request response")
        assert master_response.manifest.state == dj_pb.RawDataState.Joining

        follower = self._example_join_follower
        next_block_index = follower.start_create_data_block(partition_id)
        response.finished = False
        response.next_data_block_index = next_block_index
        return response
示例#4
0
    def FinishPartition(self, request, context):
        """Handle a peer request to finish a partition.

        If the requested partition is the one the example join follower is
        currently processing, the follower is asked to finish syncing data
        block meta; once it reports completion the master is told to finish
        the join, and on success the follower's dump partition is reset.

        If a different partition is requested, the locally stored manifest
        decides the answer: a state past ``Joining`` is reported as
        finished, anything else is an error (status code -2).

        Args:
            request: the finish-partition request; ``data_source_meta`` and
                ``partition_id`` are read from it.
            context: gRPC servicer context (unused here).

        Returns:
            A ``dj_pb.LeaderFinishPartitionResponse`` whose ``status`` holds
            the outcome (-1 for a data source meta mismatch, -2 for a
            partition that is neither processing nor finished, or the
            master's status) and whose ``finished`` flag tells whether the
            partition is done.
        """
        response = dj_pb.LeaderFinishPartitionResponse()
        response.status.code = 0
        response.finished = False
        if not self._validate_data_source_meta(
                request.data_source_meta, self._data_source.data_source_meta):
            response.status.code = -1
            response.status.error_message = "data source meta mismtach"
            return response

        join_follower = self._example_join_follower
        if (join_follower.get_processing_partition_id() ==
                request.partition_id):
            finished = join_follower.finish_sync_data_block_meta(
                    request.partition_id
                )
            if finished:
                req = dj_pb.FinishRawDataRequest(
                        data_source_meta=self._data_source.data_source_meta,
                        rank_id=self._rank_id,
                        join_example=dj_pb.JoinExampleRequest(
                            partition_id=request.partition_id
                        )
                    )
                # The master replies with a bare status message; surface it
                # to the peer as this response's status.
                rsp = self._master_client.FinishJoinPartition(req)
                response.status.MergeFrom(rsp)
                if rsp.code == 0:
                    join_follower.reset_dump_partition()
            response.finished = finished
        else:
            manifest = self._query_raw_data_manifest(request.partition_id)
            if manifest.state > dj_pb.RawDataState.Joining:
                response.finished = True
            else:
                response.status.code = -2
                # `finished` is a field of the response, not of its status;
                # assigning it on `status` would raise AttributeError on a
                # protobuf message that does not declare such a field.
                response.finished = False
                response.status.error_message = (
                        "partition {} at state {} but it is not "
                        "processing".format(request.partition_id,
                                            manifest.state)
                    )
        return response
示例#5
0
    def test_api(self):
        """Drive a leader/follower master pair through one partition.

        The test registers a leader and a follower data source (one
        partition each) in etcd, starts a ``DataJoinMasterService`` for
        both, and then walks the single partition through the master API:
        request/finish example-id syncing, request/finish example joining,
        and finally waits for both data sources to report ``Finished``.

        NOTE(review): presumably needs a reachable etcd at localhost:2379
        and free local ports 4061/4062 -- confirm for the CI environment.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
        # Start from a clean slate under the data source's key prefix.
        etcd_l.delete_prefix(data_source_name)
        etcd_f.delete_prefix(data_source_name)
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_meta.min_matching_window = 32
        data_source_meta.max_matching_window = 1024
        data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
        data_source_meta.max_example_in_data_block = 1000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        etcd_l.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_l))
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        etcd_f.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_f))
        meta_l = data_source_l.data_source_meta
        meta_f = data_source_f.data_source_meta

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        def wait_for_state(expected_state):
            # Poll both masters every 2 seconds until each reports
            # expected_state; role and type are re-checked on every poll.
            # NOTE: there is no timeout, so a stuck master hangs the test.
            while True:
                rsp_l = client_l.GetDataSourceState(meta_l)
                rsp_f = client_f.GetDataSourceState(meta_f)
                self.assertEqual(rsp_l.status.code, 0)
                self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
                self.assertEqual(rsp_l.data_source_type,
                                 common_pb.DataSourceType.Sequential)
                self.assertEqual(rsp_f.status.code, 0)
                self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
                self.assertEqual(rsp_f.data_source_type,
                                 common_pb.DataSourceType.Sequential)
                if (rsp_l.state == expected_state
                        and rsp_f.state == expected_state):
                    return
                time.sleep(2)

        def check_manifest(rsp, expected_state):
            # A successful allocation hands partition 0 to rank 0 in the
            # expected raw data state.
            self.assertEqual(rsp.status.code, 0)
            self.assertTrue(rsp.HasField('manifest'))
            self.assertEqual(rsp.manifest.state, expected_state)
            self.assertEqual(rsp.manifest.allocated_rank_id, 0)
            self.assertEqual(rsp.manifest.partition_id, 0)

        # Make sure both masters are stopped even if an assertion fails.
        try:
            wait_for_state(common_pb.DataSourceState.Processing)

            # Joining before any partition has been synced yields neither
            # a manifest nor a finished marker.
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=meta_l,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=-1))
            rdrsp = client_l.RequestJoinPartition(rdreq)
            self.assertEqual(rdrsp.status.code, 0)
            self.assertFalse(rdrsp.HasField('manifest'))
            self.assertFalse(rdrsp.HasField('finished'))

            # Leader asks for a partition to sync; asking twice returns the
            # same allocation.
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=meta_l,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
            rdrsp = client_l.RequestJoinPartition(rdreq)
            check_manifest(rdrsp, dj_pb.Syncing)
            rdrsp = client_l.RequestJoinPartition(rdreq)
            check_manifest(rdrsp, dj_pb.Syncing)

            # Follower syncs the explicitly named partition 0.
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=meta_f,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
            rdrsp = client_f.RequestJoinPartition(rdreq)
            check_manifest(rdrsp, dj_pb.Syncing)

            # Finish syncing on the leader; repeating the call must also
            # succeed.
            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=meta_l,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            # Finish syncing on the follower.
            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=meta_f,
                rank_id=0,
                sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
            frrsp = client_f.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            # With syncing finished, the leader can allocate the partition
            # for joining.
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=meta_l,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=-1))
            rdrsp = client_l.RequestJoinPartition(rdreq)
            check_manifest(rdrsp, dj_pb.Joining)

            # Follower joins the explicitly named partition 0.
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=meta_f,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=0))
            rdrsp = client_f.RequestJoinPartition(rdreq)
            check_manifest(rdrsp, dj_pb.Joining)

            # Finish joining on the leader; repeating the call must also
            # succeed.
            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=meta_l,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=0))
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)
            frrsp = client_l.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            # Finish joining on the follower.
            frreq = dj_pb.FinishRawDataRequest(
                data_source_meta=meta_f,
                rank_id=0,
                join_example=dj_pb.JoinExampleRequest(partition_id=0))
            frrsp = client_f.FinishJoinPartition(frreq)
            self.assertEqual(frrsp.code, 0)

            wait_for_state(common_pb.DataSourceState.Finished)
        finally:
            master_l.stop()
            master_f.stop()