def __init__(self, listen_port, peer_addr, master_addr, rank_id,
             etcd_name, etcd_base_dir, etcd_addrs, options):
    master_channel = make_insecure_channel(
        master_addr, ChannelType.INTERNAL,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    self._master_client = dj_grpc.DataJoinMasterServiceStub(master_channel)
    self._rank_id = rank_id
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                      options.use_mock_etcd)
    data_source = self._sync_data_source()
    self._data_source_name = data_source.data_source_meta.name
    self._listen_port = listen_port
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    peer_channel = make_insecure_channel(
        peer_addr, ChannelType.REMOTE,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    peer_client = dj_grpc.DataJoinWorkerServiceStub(peer_channel)
    self._data_join_worker = DataJoinWorker(
        peer_client, self._master_client,
        rank_id, etcd, data_source, options)
    dj_grpc.add_DataJoinWorkerServiceServicer_to_server(
        self._data_join_worker, self._server)
    self._role_repr = "leader" if data_source.role == \
            common_pb.FLRole.Leader else "follower"
    self._server.add_insecure_port('[::]:%d' % listen_port)
    self._server_started = False
def __init__(self, listen_port, peer_addr, master_addr, rank_id,
             etcd_name, etcd_base_dir, etcd_addrs, options):
    master_channel = make_insecure_channel(master_addr, ChannelType.INTERNAL)
    master_client = dj_grpc.DataJoinMasterServiceStub(master_channel)
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir)
    data_source = self.sync_data_source(master_client)
    self._data_source_name = data_source.data_source_meta.name
    self._listen_port = listen_port
    self._rank_id = rank_id
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    peer_channel = make_insecure_channel(peer_addr, ChannelType.REMOTE)
    if data_source.role == common_pb.FLRole.Leader:
        self._role_repr = "leader"
        peer_client = dj_grpc.DataJoinFollowerServiceStub(peer_channel)
        self._diw = data_join_leader.DataJoinLeader(
            peer_client, master_client, rank_id, etcd, data_source, options)
        dj_grpc.add_DataJoinLeaderServiceServicer_to_server(
            self._diw, self._server)
    else:
        assert data_source.role == common_pb.FLRole.Follower
        self._role_repr = "follower"
        peer_client = dj_grpc.DataJoinLeaderServiceStub(peer_channel)
        self._diw = data_join_follower.DataJoinFollower(
            peer_client, master_client, rank_id, etcd, data_source, options)
        dj_grpc.add_DataJoinFollowerServiceServicer_to_server(
            self._diw, self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
    self._server_started = False
def __init__(self, options, master_addr, rank_id, db_database,
             db_base_dir, db_addr, db_username, db_password,
             use_mock_etcd=False):
    master_channel = make_insecure_channel(
        master_addr, ChannelType.INTERNAL,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    self._db_database = db_database
    self._db_base_dir = db_base_dir
    self._db_addr = db_addr
    self._db_password = db_password
    self._db_username = db_username
    self._rank_id = rank_id
    self._options = options
    self._use_mock_etcd = use_mock_etcd
    self._master_client = dp_grpc.DataPortalMasterServiceStub(master_channel)
def __init__(self, master_addr):
    self._master_addr = master_addr
    channel = make_insecure_channel(master_addr, ChannelType.INTERNAL)
    self._master_cli = dj_grpc.DataJoinMasterServiceStub(channel)
    self._data_source = None
    self._raw_data_ctl = None
    self._raw_data_updated_datetime = {}
def __init__(self, addr):
    self._lock = threading.Lock()
    self._channel = make_insecure_channel(addr, ChannelType.REMOTE)
    self._stub = dj_grpc.RsaPsiSignServiceStub(self._channel)
    self._serial_fail_cnt = 0
    self._rpc_ref_cnt = 0
    self._mark_error = False
def __init__(self, role, listen_port, remote_address, app_id=None,
             rank=0, streaming_mode=True,
             compression=grpc.Compression.NoCompression):
    self._role = role
    self._listen_port = listen_port
    self._remote_address = remote_address
    if app_id is None:
        app_id = 'test_trainer'
    self._app_id = app_id
    self._rank = rank
    self._streaming_mode = streaming_mode
    self._compression = compression
    self._prefetch_handlers = []
    self._data_block_handler_fn = None

    # connection related
    self._connected = False
    self._terminated = False
    self._peer_terminated = False
    self._identifier = '%s-%s-%d-%d' % (
        app_id, role, rank, int(time.time()))  # ensure unique per run
    self._peer_identifier = ''

    # data transmit
    self._condition = threading.Condition()
    self._current_iter_id = None
    self._next_iter_id = 0
    self._received_data = {}

    # grpc client
    self._transmit_send_lock = threading.Lock()
    self._grpc_options = [
        ('grpc.max_send_message_length', 2**31 - 1),
        ('grpc.max_receive_message_length', 2**31 - 1)
    ]
    self._channel = make_insecure_channel(
        remote_address, ChannelType.REMOTE,
        options=self._grpc_options, compression=self._compression)
    self._client = tws_grpc.TrainerWorkerServiceStub(self._channel)
    self._next_send_seq_num = 0
    self._transmit_queue = queue.Queue()
    self._client_daemon = None
    self._client_daemon_shutdown_fn = None

    # server
    self._transmit_receive_lock = threading.Lock()
    self._next_receive_seq_num = 0
    self._server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=self._grpc_options,
        compression=self._compression)
    tws_grpc.add_TrainerWorkerServiceServicer_to_server(
        Bridge.TrainerWorkerServicer(self), self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
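# A minimal local-loopback sketch for the Bridge above (ports and app_id are
# placeholders; the connect()/terminate() pair is an assumed entry point and
# is not shown in this snippet).
leader = Bridge('leader', 50051, 'localhost:50052', app_id='demo')
follower = Bridge('follower', 50052, 'localhost:50051', app_id='demo')
leader.connect()      # assumed: starts the server and handshakes the peer
follower.connect()
# ... exchange data; messages are sequenced by _next_send_seq_num ...
leader.terminate()    # assumed: drains the transmit queue before closing
follower.terminate()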
def _transmit(self, msg):
    assert self._connected, "Cannot transmit before connect"
    with self._transmit_send_lock:
        msg.seq_num = self._next_send_seq_num
        self._next_send_seq_num += 1
        if self._streaming_mode:
            self._transmit_queue.put(msg)
            return
        while True:
            try:
                rsp = self._client.Transmit(msg)
                assert rsp.status.code == common_pb.STATUS_SUCCESS, \
                    "Transmit error with code %d." % rsp.status.code
                break
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("Bridge transmit failed: %s. "
                                "Retry in 1 second...", repr(e))
                self._channel.close()
                time.sleep(1)
                self._channel = make_insecure_channel(
                    self._remote_address, ChannelType.REMOTE,
                    options=self._grpc_options,
                    compression=self._compression)
                self._client = make_ready_client(self._channel)
                self._check_remote_heartbeat()
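# The close-recreate-retry pattern above, extracted into a standalone helper
# as a sketch. `make_channel` and `call` are illustrative parameter names,
# not names from the snippet: the former builds a fresh channel, the latter
# issues one RPC against it.
import logging
import time

def call_with_reconnect(make_channel, call, retry_interval=1):
    channel = make_channel()
    while True:
        try:
            return call(channel)
        except Exception as e:  # pylint: disable=broad-except
            logging.warning("RPC failed: %s. Retrying...", repr(e))
            channel.close()           # drop the broken channel entirely
            time.sleep(retry_interval)
            channel = make_channel()  # reconnect from scratch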
def __init__(self, id_batch_fetcher, max_flying_item,
             process_pool_executor, public_key, leader_signer_addr):
    super(FollowerPsiRsaSigner, self).__init__(
        id_batch_fetcher, max_flying_item, process_pool_executor)
    self._public_key = public_key
    channel = make_insecure_channel(leader_signer_addr, ChannelType.REMOTE)
    self._leader_signer_stub = dj_grpc.RsaPsiSignServiceStub(channel)
def __init__(self, addr, role, task_id):
    self._addr = addr
    self._role = role
    self._task_id = task_id
    channel = make_insecure_channel(self._addr, ChannelType.INTERNAL)
    self._stub = tm_grpc.TrainerMasterServiceStub(channel)
    self._request = tm_pb.DataBlockRequest()
    if self._role == 'leader':
        self._request.worker_rank = self._task_id
def __init__(self, address, worker_rank):
    channel = make_insecure_channel(
        address, mode=ChannelType.INTERNAL,
        options=(
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
            ('grpc.max_reconnect_backoff_ms', 1000),
        ))
    client = tm_grpc.TrainerMasterServiceStub(channel)
    super(TrainerMasterClient, self).__init__(client, worker_rank)
def __init__(self, addr):
    self._lock = threading.Lock()
    self._channel = make_insecure_channel(
        addr, ChannelType.REMOTE,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    self._stub = dj_grpc.RsaPsiSignServiceStub(self._channel)
    self._serial_fail_cnt = 0
    self._rpc_ref_cnt = 0
    self._mark_error = False
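# Both option spellings seen in these snippets lift gRPC's default 4 MB
# message cap: -1 means "unlimited", while 2**31 - 1 pins the limit at the
# maximum 32-bit value. A sketch of the shared option list against the plain
# grpc API (address is a placeholder), independent of make_insecure_channel:
import grpc

UNLIMITED_MESSAGE_OPTIONS = [
    ('grpc.max_send_message_length', -1),
    ('grpc.max_receive_message_length', -1),
]
channel = grpc.insecure_channel('localhost:50051',
                                options=UNLIMITED_MESSAGE_OPTIONS)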
def __init__(self, options, master_addr, rank_id, etcd_name,
             etcd_base_dir, etcd_addrs, use_mock_etcd=False):
    master_channel = make_insecure_channel(master_addr, ChannelType.INTERNAL)
    self._etcd_name = etcd_name
    self._etcd_base_dir = etcd_base_dir
    self._etcd_addrs = etcd_addrs
    self._rank_id = rank_id
    self._options = options
    self._use_mock_etcd = use_mock_etcd
    self._master_client = dp_grpc.DataPortalMasterServiceStub(master_channel)
def __init__(self, listen_port, peer_addr, data_source_name,
             etcd_name, etcd_base_dir, etcd_addrs):
    channel = make_insecure_channel(peer_addr, ChannelType.REMOTE)
    peer_client = dj_grpc.DataJoinMasterServiceStub(channel)
    self._data_source_name = data_source_name
    self._listen_port = listen_port
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    self._dim = DataJoinMaster(peer_client, data_source_name,
                               etcd_name, etcd_addrs, etcd_base_dir)
    dj_grpc.add_DataJoinMasterServiceServicer_to_server(
        self._dim, self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
    self._server_started = False
def _launch_masters(self):
    self._master_addr_l = 'localhost:4061'
    self._master_addr_f = 'localhost:4062'
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    self._master_l = data_join_master.DataJoinMasterService(
        int(self._master_addr_l.split(':')[1]), self._master_addr_f,
        self._data_source_name, self._db_database, self._db_base_dir_l,
        self._db_addr, self._db_username_l, self._db_password_l,
        master_options)
    self._master_f = data_join_master.DataJoinMasterService(
        int(self._master_addr_f.split(':')[1]), self._master_addr_l,
        self._data_source_name, self._db_database, self._db_base_dir_f,
        self._db_addr, self._db_username_f, self._db_password_f,
        master_options)
    self._master_f.start()
    self._master_l.start()
    channel_l = make_insecure_channel(self._master_addr_l,
                                      ChannelType.INTERNAL)
    self._master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(self._master_addr_f,
                                      ChannelType.INTERNAL)
    self._master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=self._data_source_f.data_source_meta)
        dss_l = self._master_client_l.GetDataSourceStatus(req_l)
        dss_f = self._master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)
    logging.info("masters have entered the Processing state")
def __init__(self, role, listen_port, remote_address,
             app_id='test_trainer', rank=0, streaming_mode=True):
    self._role = role
    self._listen_port = listen_port
    self._remote_address = remote_address
    self._app_id = app_id
    self._rank = rank
    self._streaming_mode = streaming_mode
    self._prefetch_handlers = []
    self._data_block_handler_fn = None
    self._connected = False

    # data transmit
    self._condition = threading.Condition()
    self._current_iter_id = None
    self._next_iter_id = 0
    self._received_data = {}

    # grpc client
    self._grpc_options = [
        ('grpc.max_send_message_length', 2**31 - 1),
        ('grpc.max_receive_message_length', 2**31 - 1)
    ]
    channel = make_insecure_channel(
        remote_address, ChannelType.REMOTE, options=self._grpc_options)
    self._transmit_send_lock = threading.Lock()
    self._client = tws_grpc.TrainerWorkerServiceStub(channel)
    self._next_send_seq_num = 0
    self._transmit_queue = queue.Queue()
    self._client_daemon = None
    self._client_daemon_shutdown_fn = None
    self._keepalive_daemon = None
    self._keepalive_daemon_shutdown_fn = None

    # server
    self._transmit_receive_lock = threading.Lock()
    self._next_receive_seq_num = 0
    self._server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=self._grpc_options)
    tws_grpc.add_TrainerWorkerServiceServicer_to_server(
        Bridge.TrainerWorkerServicer(self), self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
def _rpc_with_retry(self, sender, err_log):
    while True:
        with self._client_lock:
            try:
                return sender()
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("%s: %s. Retry in 1s...", err_log, repr(e))
                self._channel.close()
                time.sleep(1)
                self._channel = make_insecure_channel(
                    self._remote_address, ChannelType.REMOTE,
                    options=self._grpc_options,
                    compression=self._compression)
                self._client = make_ready_client(self._channel)
                self._check_remote_heartbeat(self._client)
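# Hypothetical call site for _rpc_with_retry (the request object and the
# LoadDataBlock method name are illustrative, not taken from this snippet).
# The sender is a zero-argument closure so that, after a reconnect rebuilds
# self._client, the retried call resolves to the fresh stub.
def load_data_block(self, request):
    return self._rpc_with_retry(
        lambda: self._client.LoadDataBlock(request),
        "failed to load data block")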
def __init__(self, options, master_addr, rank_id, kvstore_type,
             use_mock_etcd=False):
    master_channel = make_insecure_channel(
        master_addr, ChannelType.INTERNAL,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    self._kvstore_type = kvstore_type
    self._rank_id = rank_id
    self._options = options
    self._use_mock_etcd = use_mock_etcd
    self._master_client = dp_grpc.DataPortalMasterServiceStub(master_channel)
def __init__(self, listen_port, peer_addr, data_source_name,
             kvstore_type, options):
    channel = make_insecure_channel(
        peer_addr, ChannelType.REMOTE,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    peer_client = dj_grpc.DataJoinMasterServiceStub(channel)
    self._data_source_name = data_source_name
    self._listen_port = listen_port
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    self._data_join_master = DataJoinMaster(
        peer_client, data_source_name, kvstore_type, options)
    dj_grpc.add_DataJoinMasterServiceServicer_to_server(
        self._data_join_master, self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
    self._server_started = False
def test_api(self):
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=-1,
        join_example=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        join_example=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1, partition_id=-1,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=1, partition_id=0,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq1 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1, partition_id=0,
        sync_example_id=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq1)
    except Exception:  # pylint: disable=broad-except
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rdreq2 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        join_example=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq2)
    except Exception:  # pylint: disable=broad-except
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0)
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 1)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0)
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 3)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0)
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5)),
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0)
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 1)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0)
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 1)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0)
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1)),
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0)
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    # check idempotent
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertTrue(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=4))
            ]))
    try:
        rsp = client_l.AddRawData(rdreq)
    except Exception:  # pylint: disable=broad-except
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception:  # pylint: disable=broad-except
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    # check idempotent
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    # check idempotent
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception:  # pylint: disable=broad-except
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    # check idempotent
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertTrue(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    try:
        rsp = client_f.AddRawData(rdreq)
    except Exception:  # pylint: disable=broad-except
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    # check idempotent
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    # check idempotent
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Finished and \
                dss_f.state == common_pb.DataSourceState.Finished:
            break
        else:
            time.sleep(2)
    master_l.stop()
    master_f.stop()
def _client_daemon_fn(self):
    stop_event = threading.Event()
    generator = None
    channel = make_insecure_channel(self._remote_address,
                                    ChannelType.REMOTE,
                                    options=self._grpc_options,
                                    compression=self._compression)
    client = make_ready_client(channel, stop_event)
    lock = threading.Lock()
    resend_list = collections.deque()

    def shutdown_fn():
        while True:
            with lock:
                if len(resend_list) > 0:
                    logging.debug(
                        "Waiting for resend queue's being cleaned. "
                        "Resend queue size: %d", len(resend_list))
                    time.sleep(1)
                else:
                    logging.debug('Resend queue is empty and we can shut '
                                  'down client daemon safely.')
                    break
        stop_event.set()
        if generator is not None:
            generator.cancel()
            return generator.result()

    self._client_daemon_shutdown_fn = shutdown_fn

    while not stop_event.is_set():
        try:
            def iterator():
                with lock:
                    resend_msgs = list(resend_list)
                for item in resend_msgs:
                    logging.warning("Streaming resend message seq_num=%d",
                                    item.seq_num)
                    yield item
                while True:
                    item = self._transmit_queue.get()
                    with lock:
                        resend_list.append(item)
                    logging.debug("Streaming send message seq_num=%d",
                                  item.seq_num)
                    yield item

            generator = client.StreamTransmit(iterator())
            for response in generator:
                if response.status.code == common_pb.STATUS_SUCCESS:
                    logging.debug("Message with seq_num=%d is confirmed",
                                  response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_DUPLICATED:
                    logging.debug(
                        "Resent message with seq_num=%d is confirmed",
                        response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_MISSING:
                    raise RuntimeError("Message with seq_num=%d is missing!"
                                       % (response.next_seq_num - 1))
                else:
                    raise RuntimeError("Transmit failed with %d"
                                       % response.status.code)
                with lock:
                    while resend_list and \
                            resend_list[0].seq_num < response.next_seq_num:
                        resend_list.popleft()
                    min_seq_num_to_resend = resend_list[0].seq_num \
                        if resend_list else "NaN"
                    logging.debug(
                        "Resend queue size: %d, starting from seq_num=%s",
                        len(resend_list), min_seq_num_to_resend)
        except Exception as e:  # pylint: disable=broad-except
            if not stop_event.is_set():
                logging.warning("Bridge streaming broken: %s.", repr(e))
        finally:
            generator.cancel()
            channel.close()
            logging.warning(
                "Restarting streaming: resend queue size: %d, "
                "starting from seq_num=%s",
                len(resend_list),
                resend_list[0].seq_num if resend_list else "NaN")
            channel = make_insecure_channel(self._remote_address,
                                            ChannelType.REMOTE,
                                            options=self._grpc_options,
                                            compression=self._compression)
            client = make_ready_client(channel, stop_event)
            self._check_remote_heartbeat()
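# The at-least-once resend discipline above, reduced to its core as a sketch
# (not fedlearner API): every sent message stays in a deque until the peer's
# cumulative ack (next_seq_num) covers it; after a reconnect, the deque is
# replayed before any new messages.
import collections

class ResendBuffer:
    def __init__(self):
        self._buf = collections.deque()

    def on_send(self, msg):
        # called for every outgoing message, before it leaves the process
        self._buf.append(msg)

    def on_ack(self, next_seq_num):
        # peer confirms everything strictly below next_seq_num
        while self._buf and self._buf[0].seq_num < next_seq_num:
            self._buf.popleft()

    def replay(self):
        # messages to retransmit first after a stream restart
        return list(self._buf)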
def test_api(self):
    logging.getLogger().setLevel(logging.DEBUG)
    kvstore_type = 'etcd'
    db_base_dir = 'dp_test'
    os.environ['ETCD_BASE_DIR'] = db_base_dir
    data_portal_name = 'test_data_source'
    kvstore = DBClient(kvstore_type, True)
    kvstore.delete_prefix(db_base_dir)
    portal_input_base_dir = './portal_upload_dir'
    portal_output_base_dir = './portal_output_dir'
    raw_data_publish_dir = 'raw_data_publish_dir'
    portal_manifest = dp_pb.DataPortalManifest(
        name=data_portal_name,
        data_portal_type=dp_pb.DataPortalType.Streaming,
        output_partition_num=4,
        input_file_wildcard="*.done",
        input_base_dir=portal_input_base_dir,
        output_base_dir=portal_output_base_dir,
        raw_data_publish_dir=raw_data_publish_dir,
        processing_job_id=-1,
        next_job_id=0)
    kvstore.set_data(common.portal_kvstore_base_dir(data_portal_name),
                     text_format.MessageToString(portal_manifest))
    if gfile.Exists(portal_input_base_dir):
        gfile.DeleteRecursively(portal_input_base_dir)
    gfile.MakeDirs(portal_input_base_dir)
    all_fnames = ['1001/{}.done'.format(i) for i in range(100)]
    all_fnames.append('{}.xx'.format(100))
    all_fnames.append('1001/_SUCCESS')
    for fname in all_fnames:
        fpath = os.path.join(portal_input_base_dir, fname)
        gfile.MakeDirs(os.path.dirname(fpath))
        with gfile.Open(fpath, "w") as f:
            f.write('xxx')
    portal_master_addr = 'localhost:4061'
    # the misspelled message name below matches the proto definition
    portal_options = dp_pb.DataPotraMasterlOptions(
        use_mock_etcd=True,
        long_running=False,
        check_success_tag=True)
    data_portal_master = DataPortalMasterService(
        int(portal_master_addr.split(':')[1]),
        data_portal_name, kvstore_type, portal_options)
    data_portal_master.start()
    channel = make_insecure_channel(portal_master_addr, ChannelType.INTERNAL)
    portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)
    recv_manifest = portal_master_cli.GetDataPortalManifest(
        empty_pb2.Empty())
    self.assertEqual(recv_manifest.name, portal_manifest.name)
    self.assertEqual(recv_manifest.data_portal_type,
                     portal_manifest.data_portal_type)
    self.assertEqual(recv_manifest.output_partition_num,
                     portal_manifest.output_partition_num)
    self.assertEqual(recv_manifest.input_file_wildcard,
                     portal_manifest.input_file_wildcard)
    self.assertEqual(recv_manifest.input_base_dir,
                     portal_manifest.input_base_dir)
    self.assertEqual(recv_manifest.output_base_dir,
                     portal_manifest.output_base_dir)
    self.assertEqual(recv_manifest.raw_data_publish_dir,
                     portal_manifest.raw_data_publish_dir)
    self.assertEqual(recv_manifest.next_job_id, 1)
    self.assertEqual(recv_manifest.processing_job_id, 0)
    self._check_portal_job(kvstore, all_fnames, portal_manifest, 0)
    mapped_partition = set()
    task_0 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    task_0_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    self.assertEqual(task_0, task_0_1)
    self.assertTrue(task_0.HasField('map_task'))
    mapped_partition.add(task_0.map_task.partition_id)
    self._check_map_task(task_0.map_task, all_fnames,
                         task_0.map_task.partition_id, portal_manifest)
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=0,
                                partition_id=task_0.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))
    task_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    self.assertTrue(task_1.HasField('map_task'))
    mapped_partition.add(task_1.map_task.partition_id)
    self._check_map_task(task_1.map_task, all_fnames,
                         task_1.map_task.partition_id, portal_manifest)
    task_2 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=1))
    self.assertTrue(task_2.HasField('map_task'))
    mapped_partition.add(task_2.map_task.partition_id)
    self._check_map_task(task_2.map_task, all_fnames,
                         task_2.map_task.partition_id, portal_manifest)
    task_3 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=2))
    self.assertTrue(task_3.HasField('map_task'))
    mapped_partition.add(task_3.map_task.partition_id)
    self._check_map_task(task_3.map_task, all_fnames,
                         task_3.map_task.partition_id, portal_manifest)
    self.assertEqual(len(mapped_partition),
                     portal_manifest.output_partition_num)
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=0,
                                partition_id=task_1.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))
    pending_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=4))
    self.assertTrue(pending_1.HasField('pending'))
    pending_2 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=3))
    self.assertTrue(pending_2.HasField('pending'))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=1,
                                partition_id=task_2.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=2,
                                partition_id=task_3.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))
    reduce_partition = set()
    task_4 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    task_4_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    self.assertEqual(task_4, task_4_1)
    self.assertTrue(task_4.HasField('reduce_task'))
    reduce_partition.add(task_4.reduce_task.partition_id)
    self._check_reduce_task(task_4.reduce_task,
                            task_4.reduce_task.partition_id,
                            portal_manifest)
    task_5 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=1))
    self.assertTrue(task_5.HasField('reduce_task'))
    reduce_partition.add(task_5.reduce_task.partition_id)
    self._check_reduce_task(task_5.reduce_task,
                            task_5.reduce_task.partition_id,
                            portal_manifest)
    task_6 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=2))
    self.assertTrue(task_6.HasField('reduce_task'))
    reduce_partition.add(task_6.reduce_task.partition_id)
    self._check_reduce_task(task_6.reduce_task,
                            task_6.reduce_task.partition_id,
                            portal_manifest)
    task_7 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=3))
    self.assertTrue(task_7.HasField('reduce_task'))
    reduce_partition.add(task_7.reduce_task.partition_id)
    self.assertEqual(len(reduce_partition), 4)
    self._check_reduce_task(task_7.reduce_task,
                            task_7.reduce_task.partition_id,
                            portal_manifest)
    task_8 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=5))
    self.assertTrue(task_8.HasField('pending'))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=0, partition_id=task_4.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=1, partition_id=task_5.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=2, partition_id=task_6.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=3, partition_id=task_7.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    task_9 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=5))
    self.assertTrue(task_9.HasField('finished'))
    data_portal_master.stop()
    gfile.DeleteRecursively(portal_input_base_dir)
def __init__(self, listen_address, remote_address,
             token=None, max_workers=16,
             compression=grpc.Compression.Gzip,
             heartbeat_timeout=120, retry_interval=2,
             stats_client=None):
    # identifier
    self._identifier = uuid.uuid4().hex[:16]
    self._peer_identifier = ""
    self._token = token if token else ""

    # lock & condition
    self._lock = threading.RLock()
    self._condition = threading.Condition(self._lock)

    # heartbeat
    if heartbeat_timeout <= 0:
        raise ValueError("[Channel] heartbeat_timeout must be positive")
    self._heartbeat_timeout = heartbeat_timeout
    self._heartbeat_interval = self._heartbeat_timeout / 6
    self._next_heartbeat_at = 0
    self._heartbeat_timeout_at = 0
    self._peer_heartbeat_timeout_at = 0
    self._connected_at = 0
    self._closed_at = 0
    self._peer_connected_at = 0
    self._peer_closed_at = 0
    if retry_interval <= 0:
        raise ValueError("[Channel] retry_interval must be positive")
    self._retry_interval = retry_interval
    self._next_retry_at = 0
    self._ready_event = threading.Event()
    self._closed_event = threading.Event()
    self._error_event = threading.Event()

    # channel state
    self._state = Channel.State.IDLE
    self._state_thread = None
    self._error = None
    self._event_callbacks = {}

    # channel
    self._remote_address = remote_address
    self._channel = make_insecure_channel(
        self._remote_address,
        mode=ChannelType.REMOTE,
        options=(
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
            ('grpc.max_reconnect_backoff_ms', 1000),
        ),
        compression=compression)
    self._channel_interceptor = ClientInterceptor(
        identifier=self._identifier,
        retry_interval=self._retry_interval,
        wait_fn=self.wait_for_ready,
        check_fn=self._channel_response_check_fn,
        stats_client=stats_client)
    self._channel = grpc.intercept_channel(self._channel,
                                           self._channel_interceptor)

    # server
    self._listen_address = listen_address
    self._server_thread_pool = futures.ThreadPoolExecutor(
        max_workers=max_workers)
    self._server_interceptor = ServerInterceptor()
    self._server = grpc.server(
        self._server_thread_pool,
        options=(
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        ),
        interceptors=(self._server_interceptor,),
        compression=compression)
    self._server.add_insecure_port(self._listen_address)

    # channel client & server
    self._channel_call = channel_pb2_grpc.ChannelStub(self._channel)
    channel_pb2_grpc.add_ChannelServicer_to_server(
        Channel._Servicer(self), self._server)

    # stats
    self._stats_client = stats_client or stats.NoneClient()
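# A minimal end-to-end sketch for the Channel above. Addresses and the token
# are placeholders, and the connect()/close() pair is an assumption inferred
# from the ready/closed events; it is not shown in this snippet.
ch = Channel('[::]:50051', 'peer.example.com:50051', token='demo-token')
ch.connect()   # assumed: drives the state machine until _ready_event is set
try:
    pass       # register servicers on ch / issue calls via the stub here
finally:
    ch.close() # assumed: sets _closed_event and stops the embedded server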
def setUp(self):
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, master_options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, master_options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = master_client_l.GetDataSourceStatus(req_l)
        dss_f = master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)
    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f, self.raw_data_pub_dir_f)
    if gfile.Exists(data_source_l.data_block_dir):
        gfile.DeleteRecursively(data_source_l.data_block_dir)
    if gfile.Exists(data_source_l.example_dumped_dir):
        gfile.DeleteRecursively(data_source_l.example_dumped_dir)
    if gfile.Exists(data_source_l.raw_data_dir):
        gfile.DeleteRecursively(data_source_l.raw_data_dir)
    if gfile.Exists(data_source_f.data_block_dir):
        gfile.DeleteRecursively(data_source_f.data_block_dir)
    if gfile.Exists(data_source_f.example_dumped_dir):
        gfile.DeleteRecursively(data_source_f.example_dumped_dir)
    if gfile.Exists(data_source_f.raw_data_dir):
        gfile.DeleteRecursively(data_source_f.raw_data_dir)
    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                              compressed_type=''),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048),
        data_block_builder_options=dj_pb.WriterOptions(
            output_writer='TF_RECORD'))
    self.total_index = 1 << 13
def test_api(self):
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 32
    data_source_meta.max_matching_window = 1024
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    while True:
        rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
        rsp_f = client_f.GetDataSourceState(data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Processing and
                rsp_f.state == common_pb.DataSourceState.Processing):
            break
        else:
            time.sleep(2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertFalse(rdrsp.HasField('manifest'))
    self.assertFalse(rdrsp.HasField('finished'))
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    # check idempotent
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    # check idempotent
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    while True:
        rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
        rsp_f = client_f.GetDataSourceState(data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Finished and
                rsp_f.state == common_pb.DataSourceState.Finished):
            break
        else:
            time.sleep(2)
    master_l.stop()
    master_f.stop()
def __init__(self, addr):
    self._addr = addr
    channel = make_insecure_channel(addr, mode=ChannelType.REMOTE)
    self._stub = ss_grpc.SchedulerStub(channel)
                        'only support add/finish')
    parser.add_argument('master_addr', type=str,
                        help='the addr(uuid) of local data_join_master')
    parser.add_argument('partition_id', type=int,
                        help='the partition to control')
    parser.add_argument('--files', type=str, nargs='+',
                        help='the raw data fnames to add')
    parser.add_argument('--src_dir', type=str,
                        help='the directory of input raw data; files are '
                             'sorted by name and appended after those '
                             'supplied by --files')
    parser.add_argument('--dedup', action='store_true',
                        help='dedup the input files; otherwise, duplicated '
                             'input files are an error')
    args = parser.parse_args()
    master_channel = make_insecure_channel(args.master_addr,
                                           ChannelType.INTERNAL)
    master_cli = dj_grpc.DataJoinMasterServiceStub(master_channel)
    data_src = master_cli.GetDataSource(empty_pb2.Empty())
    rdc = RawDataController(data_src, master_cli)
    if args.cmd == 'add':
        all_fpaths = []
        if args.files is not None:
            for fp in args.files:
                all_fpaths.append(fp)
        if args.src_dir is not None:
            dir_fpaths = [
                path.join(args.src_dir, f)
                for f in gfile.ListDirectory(args.src_dir)
                if not gfile.IsDirectory(path.join(args.src_dir, f))
            ]
            dir_fpaths.sort()
            all_fpaths += dir_fpaths
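        # Hypothetical continuation of the 'add' branch: hand the collected
        # paths to the controller. The add_raw_data signature and the status
        # fields below are assumptions; this fragment ends before the
        # controller is actually invoked.
        if len(all_fpaths) == 0:
            raise RuntimeError("no raw data files supplied by "
                               "--files/--src_dir")
        status = rdc.add_raw_data(args.partition_id, all_fpaths, args.dedup)
        if status.code != 0:
            logging.error("failed to add raw data: %s", status.error_message)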
def _inner_test_round(self, start_index):
    for i in range(self.data_source_l.data_source_meta.partition_num):
        self.generate_raw_data(
            start_index, self.etcd_l, self.raw_data_publisher_l,
            self.data_source_l, self.raw_data_dir_l, i, 2048, 64,
            'leader_key_partition_{}'.format(i) + ':{}',
            'leader_value_partition_{}'.format(i) + ':{}')
        self.generate_raw_data(
            start_index, self.etcd_f, self.raw_data_publisher_f,
            self.data_source_f, self.raw_data_dir_f, i, 4096, 128,
            'follower_key_partition_{}'.format(i) + ':{}',
            'follower_value_partition_{}'.format(i) + ':{}')
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True,
                                                 batch_mode=True)
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        self.data_source_name, self.etcd_name, self.etcd_base_dir_l,
        self.etcd_addrs, master_options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        self.data_source_name, self.etcd_name, self.etcd_base_dir_f,
        self.etcd_addrs, master_options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    while True:
        try:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
        except Exception:  # pylint: disable=broad-except
            pass
        time.sleep(2)
    worker_addr_l = 'localhost:4161'
    worker_addr_f = 'localhost:4162'
    worker_l = data_join_worker.DataJoinWorkerService(
        int(worker_addr_l.split(':')[1]), worker_addr_f, master_addr_l, 0,
        self.etcd_name, self.etcd_base_dir_l, self.etcd_addrs,
        self.worker_options)
    worker_f = data_join_worker.DataJoinWorkerService(
        int(worker_addr_f.split(':')[1]), worker_addr_l, master_addr_f, 0,
        self.etcd_name, self.etcd_base_dir_f, self.etcd_addrs,
        self.worker_options)
    th_l = threading.Thread(target=worker_l.run, name='worker_l')
    th_f = threading.Thread(target=worker_f.run, name='worker_f')
    th_l.start()
    th_f.start()
    while True:
        try:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Ready and \
                    dss_f.state == common_pb.DataSourceState.Ready:
                break
        except Exception:  # pylint: disable=broad-except
            pass
        time.sleep(2)
    th_l.join()
    th_f.join()
    master_l.stop()
    master_f.stop()
def _client_daemon_fn(self):
    stop_event = threading.Event()
    generator = None
    channel = make_insecure_channel(self._remote_address,
                                    ChannelType.REMOTE,
                                    options=self._grpc_options,
                                    compression=self._compression)
    client = make_ready_client(channel, stop_event)

    def shutdown_fn():
        while self._transmit_queue.size():
            logging.debug(
                "Waiting for message queue's being cleaned. "
                "Queue size: %d", self._transmit_queue.size())
            time.sleep(1)
        stop_event.set()
        if generator is not None:
            generator.cancel()

    self._client_daemon_shutdown_fn = shutdown_fn

    while not stop_event.is_set():
        try:
            def iterator():
                while True:
                    item = self._transmit_queue.get()
                    logging.debug("Streaming send message seq_num=%d",
                                  item.seq_num)
                    yield item

            generator = client.StreamTransmit(iterator())
            for response in generator:
                if response.status.code == common_pb.STATUS_SUCCESS:
                    self._transmit_queue.confirm(response.next_seq_num)
                    logging.debug("Message with seq_num=%d is confirmed",
                                  response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_DUPLICATED:
                    self._transmit_queue.confirm(response.next_seq_num)
                    logging.debug(
                        "Resent message with seq_num=%d is confirmed",
                        response.next_seq_num - 1)
                elif response.status.code == \
                        common_pb.STATUS_MESSAGE_MISSING:
                    self._transmit_queue.resend(response.next_seq_num)
                else:
                    raise RuntimeError("Transmit failed with %d"
                                       % response.status.code)
        except Exception as e:  # pylint: disable=broad-except
            if not stop_event.is_set():
                logging.warning("Bridge streaming broken: %s.", repr(e))
                metrics.emit_counter('reconnect_counter', 1)
        finally:
            generator.cancel()
            channel.close()
            time.sleep(1)
            self._transmit_queue.resend(-1)
            channel = make_insecure_channel(self._remote_address,
                                            ChannelType.REMOTE,
                                            options=self._grpc_options,
                                            compression=self._compression)
            client = make_ready_client(channel, stop_event)
            self._check_remote_heartbeat(client)
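# The _transmit_queue in this variant is clearly not a plain queue.Queue: it
# must support get(), size(), confirm(next_seq_num), and resend(next_seq_num),
# where resend(-1) means "requeue everything unconfirmed". A sketch of one
# way to satisfy that contract (an assumption, not the class used above):
import collections
import threading

class AckedQueue:
    def __init__(self):
        self._cond = threading.Condition()
        self._sent = collections.deque()     # delivered but unconfirmed
        self._pending = collections.deque()  # not yet delivered

    def put(self, item):
        with self._cond:
            self._pending.append(item)
            self._cond.notify()

    def get(self):
        with self._cond:
            while not self._pending:
                self._cond.wait()
            item = self._pending.popleft()
            self._sent.append(item)
            return item

    def size(self):
        with self._cond:
            return len(self._sent) + len(self._pending)

    def confirm(self, next_seq_num):
        # drop everything the peer has cumulatively acknowledged
        with self._cond:
            while self._sent and self._sent[0].seq_num < next_seq_num:
                self._sent.popleft()

    def resend(self, next_seq_num):
        # requeue unconfirmed items from next_seq_num onward (-1 means all)
        with self._cond:
            while self._sent:
                item = self._sent.pop()
                if next_seq_num != -1 and item.seq_num < next_seq_num:
                    self._sent.append(item)
                    break
                self._pending.appendleft(item)
            self._cond.notify()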
def setUp(self):
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 64
    data_source_meta.max_matching_window = 128
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))
    customized_options.set_use_mock_etcd()
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    while True:
        rsp_l = master_client_l.GetDataSourceState(
            data_source_l.data_source_meta)
        rsp_f = master_client_f.GetDataSourceState(
            data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Processing and
                rsp_f.state == common_pb.DataSourceState.Processing):
            break
        else:
            time.sleep(2)
    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    if gfile.Exists(data_source_l.data_block_dir):
        gfile.DeleteRecursively(data_source_l.data_block_dir)
    if gfile.Exists(data_source_l.example_dumped_dir):
        gfile.DeleteRecursively(data_source_l.example_dumped_dir)
    if gfile.Exists(data_source_l.raw_data_dir):
        gfile.DeleteRecursively(data_source_l.raw_data_dir)
    if gfile.Exists(data_source_f.data_block_dir):
        gfile.DeleteRecursively(data_source_f.data_block_dir)
    if gfile.Exists(data_source_f.example_dumped_dir):
        gfile.DeleteRecursively(data_source_f.example_dumped_dir)
    if gfile.Exists(data_source_f.raw_data_dir):
        gfile.DeleteRecursively(data_source_f.raw_data_dir)
    self.total_index = 1 << 13