def make_insecure_channel(uuid, mode=ChannelType.INTERNAL, options=None, compression=None):
    """Create an insecure gRPC channel to the service identified by *uuid*.

    Args:
        uuid: either a literal "host:port" address (used directly) or a
            service uuid to be resolved via the internal proxy or etcd.
        mode: ChannelType.REMOTE routes through INTERNAL_PROXY, tagging
            requests with a 'uuid' metadata header; ChannelType.INTERNAL
            resolves the target address from etcd.
        options: optional gRPC channel options.
        compression: optional gRPC compression setting.

    Returns:
        A gRPC channel (wrapped with a header interceptor in REMOTE mode).

    Raises:
        Exception: if required environment configuration is missing, the
            uuid cannot be resolved, or *mode* is unknown.
    """
    # A syntactically valid address bypasses discovery entirely.
    if check_address_valid(uuid):
        return grpc.insecure_channel(uuid, options, compression)
    if mode == ChannelType.REMOTE:
        # The proxy demultiplexes on the 'uuid' header added here.
        header_adder = header_adder_interceptor('uuid', uuid)
        if not INTERNAL_PROXY:
            # Fix: the original adjacent string literals concatenated
            # without a separating space ("invalid,not found").
            raise Exception("INTERNAL_PROXY is invalid, "
                            "not found in environment variable.")
        logging.debug("INTERNAL_PROXY is [%s]", INTERNAL_PROXY)
        channel = grpc.insecure_channel(INTERNAL_PROXY, options, compression)
        return grpc.intercept_channel(channel, header_adder)
    if mode == ChannelType.INTERNAL:
        if not ETCD_CLUSTER or not ETCD_ADDRESS:
            raise Exception(
                "ETCD_CLUSTER or ETCD_ADDRESS is invalid, not found in"
                " environment variable.")
        etcd_client = EtcdClient(ETCD_CLUSTER, ETCD_ADDRESS, ETCD_PATH)
        target_addr = etcd_client.get_data(uuid)
        if not target_addr:
            raise Exception(
                "Target service address cant discover by uuid [{}]".format(
                    uuid))
        return grpc.insecure_channel(target_addr, options, compression)
    raise Exception("UNKNOWN Channel by uuid %s" % uuid)
def _setUpEtcd(self):
    """Prepare the leader/follower etcd clients used by the tests."""
    name, addrs = 'test_etcd', 'localhost:2379'
    base_dir_l, base_dir_f = 'byefl_l', 'byefl_f'
    self._etcd_name = name
    self._etcd_addrs = addrs
    self._etcd_base_dir_l = base_dir_l
    self._etcd_base_dir_f = base_dir_f
    # Trailing True matches the use_mock_etcd flag position seen in
    # other EtcdClient call sites.
    self._etcd_l = EtcdClient(name, addrs, base_dir_l, True)
    self._etcd_f = EtcdClient(name, addrs, base_dir_f, True)
def __init__(self, data_source_name, etcd_name, etcd_base_dir, etcd_addrs,
             use_mock_etcd=False):
    """Connect to etcd and load the named data source definition."""
    client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    self._etcd = client
    self._data_source = retrieve_data_source(client, data_source_name)
def __init__(self, database, addr, username, password, base_dir,
             use_mock_etcd=False):
    """Select the backing key-value client.

    Uses a MySQL-backed client when a username is supplied and mock etcd
    is not requested; otherwise an etcd client.

    Fix: the original always constructed an EtcdClient first and then
    overwrote it with a MySQLClient, wasting whatever setup the discarded
    client performed. Only the selected client is constructed now.
    """
    if username is not None and not use_mock_etcd:
        self._client = MySQLClient(database, addr, username, password,
                                   base_dir)
    else:
        self._client = EtcdClient(database, addr, base_dir, use_mock_etcd)
def __init__(self, listen_port, peer_addr, master_addr, rank_id,
             etcd_name, etcd_base_dir, etcd_addrs, options):
    """Wire up a data-join worker gRPC service.

    Connects to the local master and the remote peer, syncs the data
    source from the master, and registers the worker servicer on a
    local server bound to listen_port. Note the statement order matters:
    _master_client and _rank_id must be set before _sync_data_source().
    """
    # Master channel: message size caps raised to 2**31-1 bytes so large
    # payloads are not rejected by gRPC defaults.
    master_channel = make_insecure_channel(
        master_addr, ChannelType.INTERNAL,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    self._master_client = dj_grpc.DataJoinMasterServiceStub(master_channel)
    self._rank_id = rank_id
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                      options.use_mock_etcd)
    # NOTE(review): _sync_data_source presumably fetches the data source
    # via self._master_client set above -- confirm in its definition.
    data_source = self._sync_data_source()
    self._data_source_name = data_source.data_source_meta.name
    self._listen_port = listen_port
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    # Peer channel crosses the federation boundary, hence REMOTE mode.
    peer_channel = make_insecure_channel(
        peer_addr, ChannelType.REMOTE,
        options=[('grpc.max_send_message_length', 2**31 - 1),
                 ('grpc.max_receive_message_length', 2**31 - 1)])
    peer_client = dj_grpc.DataJoinWorkerServiceStub(peer_channel)
    self._data_join_worker = DataJoinWorker(
        peer_client, self._master_client, rank_id, etcd, data_source,
        options)
    dj_grpc.add_DataJoinWorkerServiceServicer_to_server(
        self._data_join_worker, self._server)
    self._role_repr = "leader" if data_source.role == \
            common_pb.FLRole.Leader else "follower"
    self._server.add_insecure_port('[::]:%d' % listen_port)
    self._server_started = False
def __init__(self, listen_port, peer_addr, master_addr, rank_id,
             etcd_name, etcd_base_dir, etcd_addrs, options):
    """Wire up a leader- or follower-specific data-join worker service.

    The role comes from the data source synced from the master; the
    peer stub and the locally registered servicer are complementary
    (a leader serves the leader API and calls the follower API).
    """
    master_channel = make_insecure_channel(master_addr, ChannelType.INTERNAL)
    master_client = dj_grpc.DataJoinMasterServiceStub(master_channel)
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir)
    data_source = self.sync_data_source(master_client)
    self._data_source_name = data_source.data_source_meta.name
    self._listen_port = listen_port
    self._rank_id = rank_id
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    # REMOTE mode: the peer lives across the federation boundary.
    peer_channel = make_insecure_channel(peer_addr, ChannelType.REMOTE)
    if data_source.role == common_pb.FLRole.Leader:
        self._role_repr = "leader"
        # Leader calls the follower's API and serves the leader API.
        peer_client = dj_grpc.DataJoinFollowerServiceStub(peer_channel)
        self._diw = data_join_leader.DataJoinLeader(
            peer_client, master_client, rank_id, etcd, data_source,
            options)
        dj_grpc.add_DataJoinLeaderServiceServicer_to_server(
            self._diw, self._server)
    else:
        assert data_source.role == common_pb.FLRole.Follower
        self._role_repr = "follower"
        peer_client = dj_grpc.DataJoinLeaderServiceStub(peer_channel)
        self._diw = data_join_follower.DataJoinFollower(
            peer_client, master_client, rank_id, etcd, data_source,
            options)
        dj_grpc.add_DataJoinFollowerServiceServicer_to_server(
            self._diw, self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
    self._server_started = False
def __init__(self, peer_client, data_source_name, etcd_name, etcd_addrs,
             etcd_base_dir):
    """Build the master finite-state machine on top of an etcd client."""
    super(DataJoinMaster, self).__init__()
    self._data_source_name = data_source_name
    use_mock = customized_options.get_use_mock_etcd()
    kv_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock)
    self._fsm = MasterFSM(peer_client, data_source_name, kv_client)
def __init__(self, options, etcd_name, etcd_addrs, etcd_base_dir,
             use_mock_etcd=False):
    """Set up the RSA-PSI preprocessor pipeline.

    Builds the raw-data publisher, id-batch fetcher, role-specific RSA
    signer (leader signs with the private key, follower blinds with the
    public key), the sort-run dumper and the sort-run merger.
    """
    self._lock = threading.Condition()
    self._options = options
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(etcd, pub_dir)
    # Signing work is offloaded to worker processes.
    self._process_pool_executor = \
            concur_futures.ProcessPoolExecutor(
                options.offload_processor_number
            )
    self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
    max_flying_item = options.batch_processor_options.max_flying_item
    if self._options.role == common_pb.FLRole.Leader:
        # Leader holds the RSA private key and signs batches itself.
        private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
            self._id_batch_fetcher, max_flying_item,
            self._options.max_flying_sign_batch,
            self._options.slow_sign_threshold,
            self._process_pool_executor, private_key,
        )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        # Follower only has the public key and delegates signing to the
        # leader's signer service at leader_rsa_psi_signer_addr.
        public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
            self._id_batch_fetcher, max_flying_item,
            self._options.max_flying_sign_batch,
            self._options.max_flying_sign_rpc,
            self._options.sign_rpc_timeout_ms,
            self._options.slow_sign_threshold,
            self._options.stub_fanout,
            self._process_pool_executor, public_key,
            self._options.leader_rsa_psi_signer_addr)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    # Merger reads back what the dumper wrote, so its reader options are
    # derived from the dumper's writer options; runs merge on 'example_id'.
    self._sort_run_merger = SortRunMerger(
        dj_pb.SortRunMergerOptions(
            merger_name='sort_run_merger_'+\
                partition_repr(options.partition_id),
            reader_options=dj_pb.RawDataOptions(
                raw_data_iter=options.writer_options.output_writer,
                compressed_type=options.writer_options.compressed_type,
                read_ahead_size=\
                    options.sort_run_merger_read_ahead_buffer
            ),
            writer_options=options.writer_options,
            output_file_dir=options.output_file_dir,
            partition_id=options.partition_id
        ),
        'example_id'
    )
    self._started = False
def __init__(self, peer_client, data_source_name, etcd_name, etcd_addrs,
             etcd_base_dir, options):
    """Build the master FSM and cache the data source meta it manages."""
    super(DataJoinMaster, self).__init__()
    self._data_source_name = data_source_name
    kv_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                           options.use_mock_etcd)
    fsm = MasterFSM(peer_client, data_source_name, kv_client)
    self._fsm = fsm
    self._data_source_meta = fsm.get_data_source().data_source_meta
def _cheif_barriar(self, is_chief=False, sync_times=300):
    """Barrier-sync workers through etcd before proceeding.

    Each worker writes a flag under APPLICATION_ID/WORKER_RANK; the
    chief then polls until every replica's flag is present, or until
    sync_times polls (6s apart) have elapsed.

    NOTE: the method name misspells "chief barrier" but is kept
    unchanged for backward compatibility with existing callers.
    """
    # Fix: os.environ values are strings, so without int() the
    # len(sync_list) < worker_replicas comparison below raises
    # TypeError on Python 3 whenever REPLICA_NUM is set.
    worker_replicas = int(os.environ.get('REPLICA_NUM', 0))
    etcd_client = EtcdClient(os.environ['ETCD_CLUSTER'],
                             os.environ['ETCD_ADDRESS'], SYNC_PATH)
    sync_path = '%s/%s' % (os.environ['APPLICATION_ID'],
                           os.environ['WORKER_RANK'])
    logging.info('Creating a sync flag at %s', sync_path)
    etcd_client.set_data(sync_path, 1)
    if is_chief:
        for _ in range(sync_times):
            sync_list = etcd_client.get_prefix_kvs(
                os.environ['APPLICATION_ID'])
            logging.info('Sync file pattern is: %s', sync_list)
            if len(sync_list) < worker_replicas:
                logging.info('Count of ready workers is %d',
                             len(sync_list))
                time.sleep(6)
            else:
                break
def __init__(self, portal_name, etcd_name, etcd_addrs, etcd_base_dir,
             portal_options):
    """Load the portal manifest from etcd and build the repartitioner."""
    super(DataJoinPortal, self).__init__()
    self._portal_name = portal_name
    self._portal_options = portal_options
    kv_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                           portal_options.use_mock_etcd)
    self._etcd = kv_client
    self._portal_manifest = retrieve_portal_manifest(
        kv_client, self._portal_name)
    self._portal_repartitioner = PortalRepartitioner(
        self._etcd, self._portal_manifest, self._portal_options)
def __init__(self, kvstore_type, use_mock_etcd=False):
    """Select the metadata key-value backend for *kvstore_type*.

    'dfs' stores metadata under STORAGE_ROOT_PATH (default /fedlearner);
    otherwise the configured etcd or MySQL backend is used.

    Fix: the original always constructed an EtcdClient and then, when a
    username was configured, overwrote it with a MySQLClient, wasting
    the discarded client's setup. Only the selected client is built now.
    """
    if kvstore_type == 'dfs':
        base_dir = os.path.join(
            os.environ.get('STORAGE_ROOT_PATH', '/fedlearner'), 'metadata')
        self._client = DFSClient(base_dir)
    else:
        database, addr, username, password, base_dir = \
            get_kvstore_config(kvstore_type)
        if username is not None and not use_mock_etcd:
            self._client = MySQLClient(database, addr, username,
                                       password, base_dir)
        else:
            self._client = EtcdClient(database, addr, base_dir,
                                      use_mock_etcd)
def __init__(self, listen_port, portal_name, etcd_name, etcd_base_dir,
             etcd_addrs, portal_options):
    """Register the data-portal master servicer on a local gRPC server."""
    self._portal_name = portal_name
    self._listen_port = listen_port
    kv_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                           portal_options.use_mock_etcd)
    self._data_portal_master = DataPortalMaster(portal_name, kv_client,
                                                portal_options)
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    dp_grpc.add_DataPortalMasterServiceServicer_to_server(
        self._data_portal_master, server)
    server.add_insecure_port('[::]:%d' % listen_port)
    self._server = server
    self._server_started = False
def __init__(self, options, etcd_name, etcd_addrs, etcd_base_dir,
             use_mock_etcd=False):
    """Initialize the dump pipeline state and its raw-data fetcher."""
    self._options = options
    kv_client = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                           use_mock_etcd)
    self._raw_data_batch_fetcher = RawDataBatchFetcher(kv_client, options)
    # Progress trackers; None/empty until the first partition starts.
    self._next_part_index = None
    self._dumped_process_index = None
    self._flying_writers = []
    self._dumped_file_metas = {}
    self._worker_map = {}
    self._started = False
    self._part_finished = False
    self._cond = threading.Condition()
def __init__(self, options, etcd_name, etcd_addrs, etcd_base_dir,
             use_mock_etcd=False):
    """Set up the RSA-PSI preprocessor pipeline (sync-RPC capable).

    Like the basic preprocessor, but optionally creates a thread pool
    for synchronous sign RPCs when options.rpc_sync_mode is set.
    """
    self._lock = threading.Condition()
    self._options = options
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(etcd, pub_dir)
    # Signing work is offloaded to worker processes.
    self._process_pool_executor = \
            concur_futures.ProcessPoolExecutor(
                options.offload_processor_number
            )
    # Only used in sync RPC mode; stays None otherwise (the follower
    # signer accepts None for this argument).
    self._thread_pool_rpc_executor = None
    if options.rpc_sync_mode:
        assert options.rpc_thread_pool_size > 0
        self._thread_pool_rpc_executor = concur_futures.ThreadPoolExecutor(
            options.rpc_thread_pool_size)
    self._id_batch_fetcher = IdBatchFetcher(etcd, self._options)
    max_flying_item = options.batch_processor_options.max_flying_item
    if self._options.role == common_pb.FLRole.Leader:
        # Leader holds the RSA private key and signs batches itself.
        private_key = rsa.PrivateKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
            self._id_batch_fetcher, max_flying_item,
            self._options.max_flying_sign_batch,
            self._options.slow_sign_threshold,
            self._process_pool_executor, private_key,
        )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        # Follower delegates signing to the leader's signer service.
        public_key = rsa.PublicKey.load_pkcs1(options.rsa_key_pem)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
            self._id_batch_fetcher, max_flying_item,
            self._options.max_flying_sign_batch,
            self._options.max_flying_sign_rpc,
            self._options.sign_rpc_timeout_ms,
            self._options.slow_sign_threshold,
            self._options.stub_fanout,
            self._process_pool_executor, public_key,
            self._options.leader_rsa_psi_signer_addr,
            self._thread_pool_rpc_executor)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    self._sort_run_merger = SortRunMerger(
        self._sort_run_dumper.sort_run_dump_dir(), self._options)
    self._started = False
def __init__(self, options, etcd_name, etcd_addrs, etcd_base_dir,
             use_mock_etcd=False):
    """Set up the RSA-PSI preprocessor pipeline (key-file variant).

    Reads the RSA key material from options.rsa_key_file_path instead of
    an inline PEM option.
    """
    self._lock = threading.Condition()
    self._options = options
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, use_mock_etcd)
    pub_dir = self._options.raw_data_publish_dir
    self._publisher = RawDataPublisher(etcd, pub_dir)
    # Signing work is offloaded to worker processes.
    self._process_pool_executor = \
            concur_futures.ProcessPoolExecutor(
                options.offload_processor_number
            )
    # NOTE(review): other variants of this constructor call
    # IdBatchFetcher(etcd, self._options); here the etcd argument is
    # omitted -- confirm IdBatchFetcher's signature in this version.
    self._id_batch_fetcher = IdBatchFetcher(self._options)
    max_flying_item = options.batch_processor_options.max_flying_item
    if self._options.role == common_pb.FLRole.Leader:
        # Leader loads the RSA private key and signs batches itself.
        private_key = None
        with gfile.GFile(options.rsa_key_file_path, 'rb') as f:
            file_content = f.read()
            private_key = rsa.PrivateKey.load_pkcs1(file_content)
        self._psi_rsa_signer = LeaderPsiRsaSigner(
            self._id_batch_fetcher, max_flying_item,
            self._process_pool_executor, private_key,
        )
        self._repr = 'leader-' + 'rsa_psi_preprocessor'
    else:
        # Follower loads only the public key and delegates signing to
        # the leader's signer service.
        public_key = None
        with gfile.GFile(options.rsa_key_file_path, 'rb') as f:
            file_content = f.read()
            public_key = rsa.PublicKey.load_pkcs1(file_content)
        self._psi_rsa_signer = FollowerPsiRsaSigner(
            self._id_batch_fetcher, max_flying_item,
            self._process_pool_executor, public_key,
            self._options.leader_rsa_psi_signer_addr)
        self._repr = 'follower-' + 'rsa_psi_preprocessor'
    self._sort_run_dumper = SortRunDumper(options)
    # NOTE(review): sort_run_dump_dir is passed without call parentheses
    # here, while a sibling variant calls sort_run_dump_dir(). If it is
    # a plain method (not a property), this passes the bound method
    # object rather than a path -- verify against SortRunDumper.
    self._sort_run_merger = SortRunMerger(
        self._sort_run_dumper.sort_run_dump_dir, self._options)
    self._worker_map = {}
    self._started = False
def __init__(self, portal_name, etcd_name, etcd_addrs, etcd_base_dir,
             portal_options):
    """Set up the portal; attach a raw-data notifier when downstream
    data source masters are configured."""
    super(DataJoinPortal, self).__init__()
    self._portal_name = portal_name
    self._portal_options = portal_options
    self._etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                            portal_options.use_mock_etcd)
    self._portal_manifest = retrieve_portal_manifest(
        self._etcd, self._portal_name)
    self._portal_repartitioner = PortalRepartitioner(
        self._etcd, self._portal_manifest, self._portal_options)
    downstream_masters = \
        self._portal_options.downstream_data_source_masters
    if len(downstream_masters) > 0:
        self._portal_raw_data_notifier = PortalRawDataNotifier(
            self._etcd, self._portal_name, downstream_masters)
        logging.info("launch data join portal with raw data notifier "
                     "for following downstream data source masters:")
        for addr in downstream_masters:
            logging.info(addr)
    else:
        self._portal_raw_data_notifier = None
        logging.info("launch data join portal without raw data notifier")
def test_api(self):
    """End-to-end exercise of the DataJoinMaster RPC API.

    Spins up a leader and a follower master backed by a shared etcd,
    then walks the partition lifecycle: Processing -> sync example ids
    -> finish sync -> join examples -> finish join -> Finished.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
    # Start from a clean slate on both sides.
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"
    # Shared meta: a single partition so all manifests use partition 0.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 32
    data_source_meta.max_matching_window = 1024
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    # Poll until both masters negotiated into Processing state.
    while True:
        rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
        rsp_f = client_f.GetDataSourceState(data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Processing and
                rsp_f.state == common_pb.DataSourceState.Processing):
            break
        else:
            time.sleep(2)
    # Asking the leader to join before sync finished must allocate
    # nothing (no manifest, no finished flag).
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertFalse(rdrsp.HasField('manifest'))
    self.assertFalse(rdrsp.HasField('finished'))
    # partition_id=-1 lets the master pick a partition to sync.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # Re-issuing the same request must be idempotent.
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_l.FinishJoinPartition(frreq)
    self.assertEqual(frrsp.code, 0)
    # NOTE(review): the result of this call is bound to rdrsp but the
    # assertion below re-checks the stale frrsp; also the argument is
    # rdreq (a RawDataRequest) where frreq looks intended -- verify.
    rdrsp = client_l.FinishJoinPartition(rdreq)
    self.assertEqual(frrsp.code, 0)
    rdreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(rdreq)
    self.assertEqual(frrsp.code, 0)
    # With sync finished, join allocation should now succeed.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=-1))
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
    self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    # NOTE(review): frreq is built just above but rdreq (the follower's
    # join request from the previous block) is passed instead, here and
    # in the two FinishJoinPartition calls below -- verify intent.
    frrsp = client_l.FinishJoinPartition(rdreq)
    self.assertEqual(frrsp.code, 0)
    # Idempotency check of finish-join.
    frrsp = client_l.FinishJoinPartition(rdreq)
    self.assertEqual(frrsp.code, 0)
    frreq = dj_pb.FinishRawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        join_example=dj_pb.JoinExampleRequest(partition_id=0))
    frrsp = client_f.FinishJoinPartition(rdreq)
    self.assertEqual(frrsp.code, 0)
    # Poll until both masters report Finished.
    while True:
        rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
        # NOTE(review): the follower is queried with the leader's meta
        # (data_source_l) here, unlike the first polling loop -- verify.
        rsp_f = client_f.GetDataSourceState(data_source_l.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Finished and
                rsp_f.state == common_pb.DataSourceState.Finished):
            break
        else:
            time.sleep(2)
    master_l.stop()
    master_f.stop()
help='the namespace of etcd key') parser.add_argument('--raw_data_sub_dir', type=str, required=True, help='the etcd base dir to subscribe new raw data') args = parser.parse_args() data_source = common_pb.DataSource() data_source.data_source_meta.name = args.data_source_name data_source.data_source_meta.partition_num = args.partition_num data_source.data_source_meta.start_time = args.start_time data_source.data_source_meta.end_time = args.end_time data_source.data_source_meta.negative_sampling_rate = \ args.negative_sampling_rate if args.role == 'leader': data_source.role = common_pb.FLRole.Leader else: assert args.role == 'follower' data_source.role = common_pb.FLRole.Follower data_source.example_dumped_dir = args.example_dump_dir data_source.data_block_dir = args.data_block_dir data_source.raw_data_sub_dir = args.raw_data_sub_dir data_source.state = common_pb.DataSourceState.Init etcd = EtcdClient(args.etcd_name, args.etcd_addrs, args.etcd_base_dir) master_etcd_key = os.path.join(data_source.data_source_meta.name, 'master') raw_data = etcd.get_data(master_etcd_key) if raw_data is None: logging.info("data source %s is not existed", args.data_source_name) etcd.set_data(master_etcd_key, text_format.MessageToString(data_source)) logging.info("apply new data source %s", args.data_source_name) else: logging.info("data source %s has been existed", args.data_source_name) etcd.destory_client_pool()
required=True, help='the etcd base dir to subscribe new raw data') args = parser.parse_args() data_source = common_pb.DataSource() data_source.data_source_meta.name = args.data_source_name data_source.data_source_meta.partition_num = args.partition_num data_source.data_source_meta.start_time = args.start_time data_source.data_source_meta.end_time = args.end_time data_source.data_source_meta.negative_sampling_rate = \ args.negative_sampling_rate if args.role == 'leader': data_source.role = common_pb.FLRole.Leader else: assert args.role == 'follower' data_source.role = common_pb.FLRole.Follower data_source.example_dumped_dir = args.example_dump_dir data_source.data_block_dir = args.data_block_dir data_source.raw_data_sub_dir = args.raw_data_sub_dir data_source.state = common_pb.DataSourceState.Init etcd = EtcdClient(args.etcd_name, args.etcd_addrs, args.etcd_base_dir) master_etcd_key = common.data_source_etcd_base_dir( data_source.data_source_meta.name) raw_data = etcd.get_data(master_etcd_key) if raw_data is None: logging.info("data source %s is not existed", args.data_source_name) common.commit_data_source(etcd, data_source) logging.info("apply new data source %s", args.data_source_name) else: logging.info("data source %s has been existed", args.data_source_name) etcd.destroy_client_pool()
def setUp(self):
    """Create leader/follower data sources in mock etcd and clean any
    output left over from previous runs."""
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    # Trailing True selects the mocked etcd backend.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    self.raw_data_dir_l = "./raw_data_l"
    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    self.raw_data_dir_f = "./raw_data_f"
    # Shared meta applied to both sides; two partitions.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    # Expose fixtures to the test methods.
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l,
        self.raw_data_pub_dir_l
    )
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f,
        self.raw_data_pub_dir_f
    )
    # Remove output/raw-data directories left by earlier runs.
    if gfile.Exists(data_source_l.output_base_dir):
        gfile.DeleteRecursively(data_source_l.output_base_dir)
    if gfile.Exists(self.raw_data_dir_l):
        gfile.DeleteRecursively(self.raw_data_dir_l)
    if gfile.Exists(data_source_f.output_base_dir):
        gfile.DeleteRecursively(data_source_f.output_base_dir)
    if gfile.Exists(self.raw_data_dir_f):
        gfile.DeleteRecursively(self.raw_data_dir_f)
    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            read_ahead_size=1<<20,
            read_batch_size=128
        ),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024
        ),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000
        ),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048
        ),
        data_block_builder_options=dj_pb.WriterOptions(
            output_writer='TF_RECORD'
        )
    )
    # Number of example items generated by the tests.
    self.total_index = 1 << 12
type=str, required=True, help='the base dir of output directory') parser.add_argument('--raw_data_publish_dir', type=str, required=True, help='the raw data publish dir in etcd') parser.add_argument('--use_mock_etcd', action='store_true', help='use to mock etcd for test') parser.add_argument('--long_running', action='store_true', help='make the data portal long running') args = parser.parse_args() etcd = EtcdClient(args.etcd_name, args.etcd_addrs, args.etcd_base_dir, args.use_mock_etcd) etcd_key = common.portal_etcd_base_dir(args.data_portal_name) if etcd.get_data(etcd_key) is None: portal_manifest = dp_pb.DataPortalManifest( name=args.data_portal_name, data_portal_type=(dp_pb.DataPortalType.PSI if args.data_portal_type == 'PSI' else dp_pb.DataPortalType.Streaming), output_partition_num=args.output_partition_num, input_file_wildcard=args.input_file_wildcard, input_base_dir=args.input_base_dir, output_base_dir=args.output_base_dir, raw_data_publish_dir=args.raw_data_publish_dir, processing_job_id=-1) etcd.set_data(etcd_key, text_format.MessageToString(portal_manifest)) options = dp_pb.DataPotraMasterlOptions(use_mock_etcd=args.use_mock_etcd,
def test_api(self):
    """End-to-end API test for a leader/follower DataJoinMaster pair.

    Spins up two DataJoinMasterService instances backed by mock etcd,
    then exercises the master RPC surface in order:
    RequestJoinPartition (allocation + idempotency), FinishJoinPartition
    (rejected while syncing/joining, accepted after raw data finished),
    QueryRawDataManifest, GetRawDataLatestTimeStamp, AddRawData
    (dedup and non-dedup), FinishRawData, and finally waits for both
    data sources to reach the Finished state.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    # use_mock_etcd=True: both clients talk to the in-memory mock store.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    # Wipe any state left behind by a previous run.
    etcd_l.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
        common.data_source_etcd_base_dir(data_source_name))
    # Commit a single-partition leader/follower data source pair.
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.output_base_dir = "./ds_output_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.output_base_dir = "./ds_output_f"
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    # Each master is told the peer's address so they can coordinate.
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    # Block until both masters have negotiated into Processing state.
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)
    # Follower asks for a join-example partition; partition_id=-1 means
    # "allocate any partition" — with partition_num=1 it must get 0.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=-1,
        join_example=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # Leader requests the same partition explicitly (partition_id=0).
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        join_example=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertTrue(rdrsp.status.code == 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
    self.assertEqual(rdrsp.manifest.join_example_rep.state,
                     dj_pb.JoinExampleState.Joining)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # Same dance for the sync-example-id role, leader side first.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1, partition_id=-1,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_l.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=1, partition_id=0,
        sync_example_id=empty_pb2.Empty())
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    #check idempotent
    rdrsp = client_f.RequestJoinPartition(rdreq)
    self.assertEqual(rdrsp.status.code, 0)
    self.assertTrue(rdrsp.HasField('manifest'))
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
    self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                     dj_pb.SyncExampleIdState.Syncing)
    self.assertEqual(rdrsp.manifest.partition_id, 0)
    # FinishJoinPartition must be rejected while the partition is still
    # syncing (rdreq1) or joining (rdreq2): the stub surfaces the server
    # error as a raised exception.
    rdreq1 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=1, partition_id=0,
        sync_example_id=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq1)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rdreq2 = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        join_example=empty_pb2.Empty())
    try:
        rsp = client_l.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # Fresh manifest: nothing processed yet, timestamp at epoch.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # Add raw data file 'a' (ts=3): index and latest timestamp advance.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 1)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 3)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # Add raw data file 'b' (ts=5).
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # Re-adding 'b' and 'a' with dedup=True must be a no-op:
    # next_process_index and the latest timestamp stay unchanged.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=5)),
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    rsp = client_l.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertFalse(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 5)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # Follower side: same sequence with its own files 'a'(1) and 'b'(2).
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 0)
    # NOTE(review): this request carries the *leader* meta but is sent to
    # client_f — presumably the server keys on its own data source; confirm.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 0)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 1)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 1)
    self.assertEqual(rsp.timestamp.nanos, 0)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # Dedup re-add on the follower: no state change expected.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='a',
                    timestamp=timestamp_pb2.Timestamp(seconds=1)),
                dj_pb.RawDataMeta(
                    file_path='b',
                    timestamp=timestamp_pb2.Timestamp(seconds=2))
            ]))
    rsp = client_f.AddRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertFalse(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0,
        partition_id=0,
    )
    rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
    self.assertEqual(rsp.status.code, 0)
    self.assertTrue(rsp.HasField('timestamp'))
    self.assertEqual(rsp.timestamp.seconds, 2)
    self.assertEqual(rsp.timestamp.nanos, 0)
    # Leader finishes its raw data stream (idempotent); after that,
    # AddRawData on the finished partition must fail.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_l = client_l.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_l is not None)
    self.assertTrue(manifest_l.finished)
    self.assertEqual(manifest_l.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_l.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=False,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=4))
            ]))
    try:
        rsp = client_l.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # Join partitions must be finished in order: rdreq2 on the follower
    # still fails here, while rdreq1 now succeeds on both sides.
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq1)
    self.assertEqual(rsp.code, 0)
    try:
        rsp = client_f.FinishJoinPartition(rdreq2)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # Follower finishes raw data too; then its AddRawData must fail.
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        finish_raw_data=empty_pb2.Empty())
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishRawData(rdreq)
    self.assertEqual(rsp.code, 0)
    manifest_f = client_f.QueryRawDataManifest(rdreq)
    self.assertTrue(manifest_f is not None)
    self.assertTrue(manifest_f.finished)
    self.assertEqual(manifest_f.next_process_index, 2)
    rdreq = dj_pb.RawDataRequest(
        data_source_meta=data_source_f.data_source_meta,
        rank_id=0, partition_id=0,
        added_raw_data_metas=dj_pb.AddedRawDataMetas(
            dedup=True,
            raw_data_metas=[
                dj_pb.RawDataMeta(
                    file_path='x',
                    timestamp=timestamp_pb2.Timestamp(seconds=3))
            ]))
    try:
        rsp = client_f.AddRawData(rdreq)
    except Exception as e:
        self.assertTrue(True)
    else:
        self.assertTrue(False)
    # With raw data finished everywhere, the join partition can be closed.
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_f.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    #check idempotent
    rsp = client_l.FinishJoinPartition(rdreq2)
    self.assertEqual(rsp.code, 0)
    # Wait for both masters to transition to Finished, then shut down.
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = client_l.GetDataSourceStatus(req_l)
        dss_f = client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Finished and \
                dss_f.state == common_pb.DataSourceState.Finished:
            break
        else:
            time.sleep(2)
    master_l.stop()
    master_f.stop()
class DataJoinPortal(unittest.TestCase): def _setUpEtcd(self): self._etcd_name = 'test_etcd' self._etcd_addrs = 'localhost:2379' self._etcd_base_dir_l = 'byefl_l' self._etcd_base_dir_f = 'byefl_f' self._etcd_l = EtcdClient(self._etcd_name, self._etcd_addrs, self._etcd_base_dir_l, True) self._etcd_f = EtcdClient(self._etcd_name, self._etcd_addrs, self._etcd_base_dir_f, True) def _setUpDataSource(self): self._data_source_name = 'test_data_source' self._etcd_l.delete_prefix(self._data_source_name) self._etcd_f.delete_prefix(self._data_source_name) self._raw_data_pub_dir_l = './raw_data_pub_dir_l' self._raw_data_pub_dir_f = './raw_data_pub_dir_f' self._data_source_l = common_pb.DataSource() self._data_source_l.role = common_pb.FLRole.Leader self._data_source_l.state = common_pb.DataSourceState.Init self._data_source_l.data_block_dir = "./data_block_l" self._data_source_l.raw_data_dir = "./raw_data_l" self._data_source_l.example_dumped_dir = "./example_dumped_l" self._data_source_l.raw_data_sub_dir = self._raw_data_pub_dir_l self._data_source_f = common_pb.DataSource() self._data_source_f.role = common_pb.FLRole.Follower self._data_source_f.state = common_pb.DataSourceState.Init self._data_source_f.data_block_dir = "./data_block_f" self._data_source_f.raw_data_dir = "./raw_data_f" self._data_source_f.example_dumped_dir = "./example_dumped_f" self._data_source_f.raw_data_sub_dir = self._raw_data_pub_dir_f data_source_meta = common_pb.DataSourceMeta() data_source_meta.name = self._data_source_name data_source_meta.partition_num = 2 data_source_meta.start_time = 0 data_source_meta.end_time = 100000000 self._data_source_l.data_source_meta.MergeFrom(data_source_meta) self._data_source_f.data_source_meta.MergeFrom(data_source_meta) common.commit_data_source(self._etcd_l, self._data_source_l) common.commit_data_source(self._etcd_f, self._data_source_f) def _setUpPortalManifest(self): self._portal_name = 'test_portal' self._etcd_l.delete_prefix(self._portal_name) 
self._etcd_f.delete_prefix(self._portal_name) self._portal_manifest_l = common_pb.DataJoinPortalManifest( name=self._portal_name, input_partition_num=4, output_partition_num=2, input_data_base_dir='./portal_input_l', output_data_base_dir='./portal_output_l', raw_data_publish_dir=self._raw_data_pub_dir_l, begin_timestamp=common.trim_timestamp_by_hourly( common.convert_datetime_to_timestamp(datetime.now()))) self._portal_manifest_f = common_pb.DataJoinPortalManifest( name=self._portal_name, input_partition_num=2, output_partition_num=2, input_data_base_dir='./portal_input_f', output_data_base_dir='./portal_output_f', raw_data_publish_dir=self._raw_data_pub_dir_f, begin_timestamp=common.trim_timestamp_by_hourly( common.convert_datetime_to_timestamp(datetime.now()))) common.commit_portal_manifest(self._etcd_l, self._portal_manifest_l) common.commit_portal_manifest(self._etcd_f, self._portal_manifest_f) def _launch_masters(self): self._master_addr_l = 'localhost:4061' self._master_addr_f = 'localhost:4062' master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True) self._master_l = data_join_master.DataJoinMasterService( int(self._master_addr_l.split(':')[1]), self._master_addr_f, self._data_source_name, self._etcd_name, self._etcd_base_dir_l, self._etcd_addrs, master_options) self._master_f = data_join_master.DataJoinMasterService( int(self._master_addr_f.split(':')[1]), self._master_addr_l, self._data_source_name, self._etcd_name, self._etcd_base_dir_f, self._etcd_addrs, master_options) self._master_f.start() self._master_l.start() channel_l = make_insecure_channel(self._master_addr_l, ChannelType.INTERNAL) self._master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l) channel_f = make_insecure_channel(self._master_addr_f, ChannelType.INTERNAL) self._master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f) while True: req_l = dj_pb.DataSourceRequest( data_source_meta=self._data_source_l.data_source_meta) req_f = dj_pb.DataSourceRequest( 
data_source_meta=self._data_source_f.data_source_meta) dss_l = self._master_client_l.GetDataSourceStatus(req_l) dss_f = self._master_client_f.GetDataSourceStatus(req_f) self.assertEqual(dss_l.role, common_pb.FLRole.Leader) self.assertEqual(dss_f.role, common_pb.FLRole.Follower) if dss_l.state == common_pb.DataSourceState.Processing and \ dss_f.state == common_pb.DataSourceState.Processing: break else: time.sleep(2) logging.info("masters turn into Processing state") def _launch_workers(self): worker_options = dj_pb.DataJoinWorkerOptions( use_mock_etcd=True, raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD', compressed_type=''), example_id_dump_options=dj_pb.ExampleIdDumpOptions( example_id_dump_interval=1, example_id_dump_threshold=1024), example_joiner_options=dj_pb.ExampleJoinerOptions( example_joiner='STREAM_JOINER', min_matching_window=64, max_matching_window=256, data_block_dump_interval=30, data_block_dump_threshold=1000), batch_processor_options=dj_pb.BatchProcessorOptions( batch_size=1024, max_flying_item=4096), data_block_builder_options=dj_pb.DataBlockBuilderOptions( data_block_builder='TF_RECORD_DATABLOCK_BUILDER')) self._worker_addrs_l = ['localhost:4161', 'localhost:4162'] self._worker_addrs_f = ['localhost:5161', 'localhost:5162'] self._workers_l = [] self._workers_f = [] for rank_id in range(2): worker_addr_l = self._worker_addrs_l[rank_id] worker_addr_f = self._worker_addrs_f[rank_id] self._workers_l.append( data_join_worker.DataJoinWorkerService( int(worker_addr_l.split(':')[1]), worker_addr_f, self._master_addr_l, rank_id, self._etcd_name, self._etcd_base_dir_l, self._etcd_addrs, worker_options)) self._workers_f.append( data_join_worker.DataJoinWorkerService( int(worker_addr_f.split(':')[1]), worker_addr_l, self._master_addr_f, rank_id, self._etcd_name, self._etcd_base_dir_f, self._etcd_addrs, worker_options)) for w in self._workers_l: w.start() for w in self._workers_f: w.start() def _launch_portals(self): portal_options = 
dj_pb.DataJoinPotralOptions( example_validator=dj_pb.ExampleValidatorOptions( example_validator='EXAMPLE_VALIDATOR', validate_event_time=True), reducer_buffer_size=1024, raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD', compressed_type=''), use_mock_etcd=True) self._portal_l = data_join_portal.DataJoinPortal( self._portal_name, self._etcd_name, self._etcd_addrs, self._etcd_base_dir_l, portal_options) self._portal_f = data_join_portal.DataJoinPortal( self._portal_name, self._etcd_name, self._etcd_addrs, self._etcd_base_dir_f, portal_options) self._portal_l.start() self._portal_f.start() def _remove_existed_dir(self): if gfile.Exists(self._portal_manifest_l.input_data_base_dir): gfile.DeleteRecursively( self._portal_manifest_l.input_data_base_dir) if gfile.Exists(self._portal_manifest_l.output_data_base_dir): gfile.DeleteRecursively( self._portal_manifest_l.output_data_base_dir) if gfile.Exists(self._portal_manifest_f.input_data_base_dir): gfile.DeleteRecursively( self._portal_manifest_f.input_data_base_dir) if gfile.Exists(self._portal_manifest_f.output_data_base_dir): gfile.DeleteRecursively( self._portal_manifest_f.output_data_base_dir) if gfile.Exists(self._data_source_l.data_block_dir): gfile.DeleteRecursively(self._data_source_l.data_block_dir) if gfile.Exists(self._data_source_l.raw_data_dir): gfile.DeleteRecursively(self._data_source_l.raw_data_dir) if gfile.Exists(self._data_source_l.example_dumped_dir): gfile.DeleteRecursively(self._data_source_l.example_dumped_dir) if gfile.Exists(self._data_source_f.data_block_dir): gfile.DeleteRecursively(self._data_source_f.data_block_dir) if gfile.Exists(self._data_source_f.raw_data_dir): gfile.DeleteRecursively(self._data_source_f.raw_data_dir) if gfile.Exists(self._data_source_f.example_dumped_dir): gfile.DeleteRecursively(self._data_source_f.example_dumped_dir) def _generate_portal_input_data(self, date_time, event_time_filter, start_index, total_item_num, portal_manifest): 
self.assertEqual(total_item_num % portal_manifest.input_partition_num, 0) item_step = portal_manifest.input_partition_num for partition_id in range(portal_manifest.input_partition_num): cands = list(range(partition_id, total_item_num, item_step)) for i in range(len(cands)): if random.randint(1, 4) > 1: continue a = random.randint(i - 16, i + 16) b = random.randint(i - 16, i + 16) if a < 0: a = 0 if a >= len(cands): a = len(cands) - 1 if b < 0: b = 0 if b >= len(cands): b = len(cands) - 1 if abs(cands[a] // item_step - b) <= 16 and abs(cands[b] // item_step - a) <= 16: cands[a], cands[b] = cands[b], cands[a] fpath = common.encode_portal_hourly_fpath( portal_manifest.input_data_base_dir, date_time, partition_id) if not gfile.Exists(os.path.dirname(fpath)): gfile.MakeDirs(os.path.dirname(fpath)) with tf.io.TFRecordWriter(fpath) as writer: for lid in cands: real_id = lid + start_index feat = {} example_id = '{}'.format(real_id).encode() feat['example_id'] = tf.train.Feature( bytes_list=tf.train.BytesList(value=[example_id])) # if test the basic example_validator for invalid event time if real_id == 0 or not event_time_filter(real_id): event_time = 150000000 + real_id feat['event_time'] = tf.train.Feature( int64_list=tf.train.Int64List(value=[event_time])) example = tf.train.Example(features=tf.train.Features( feature=feat)) writer.write(example.SerializeToString()) succ_tag_fpath = common.encode_portal_hourly_finish_tag( portal_manifest.input_data_base_dir, date_time) with gfile.GFile(succ_tag_fpath, 'w') as fh: fh.write('') def setUp(self): self._setUpEtcd() self._setUpDataSource() self._setUpPortalManifest() self._remove_existed_dir() self._item_num_l = 0 self._event_time_filter_l = lambda x: x % 877 == 0 self._dt_l = common.convert_timestamp_to_datetime( self._portal_manifest_l.begin_timestamp) for i in range(4): if i == 1: self._missing_datetime_l = self._dt_l self._missing_start_index_l = self._item_num_l self._missing_item_cnt_l = 1 << 13 self._item_num_l += 
self._missing_item_cnt_l else: self._generate_portal_input_data(self._dt_l, self._event_time_filter_l, self._item_num_l, 1 << 13, self._portal_manifest_l) self._item_num_l += 1 << 13 self._dt_l += timedelta(hours=1) self._item_num_f = 0 self._event_time_filter_f = lambda x: x % 907 == 0 self._dt_f = common.convert_timestamp_to_datetime( self._portal_manifest_f.begin_timestamp) for i in range(5): if i == 2: self._missing_datetime_f = self._dt_f self._missing_start_index_f = self._item_num_f self._missing_item_cnt_f = 1 << 13 else: self._generate_portal_input_data(self._dt_f, self._event_time_filter_f, self._item_num_f, 1 << 13, self._portal_manifest_f) self._item_num_f += 1 << 13 self._dt_f += timedelta(hours=1) self._launch_masters() self._launch_workers() self._launch_portals() def _stop_workers(self): for w in self._workers_f: w.stop() for w in self._workers_l: w.stop() def _stop_masters(self): self._master_f.stop() self._master_l.stop() def _stop_portals(self): self._portal_f.stop() self._portal_l.stop() def _wait_timestamp(self, target_l, target_f): while True: min_datetime_l = None min_datetime_f = None for pid in range( self._data_source_f.data_source_meta.partition_num): req_l = dj_pb.RawDataRequest( partition_id=pid, data_source_meta=self._data_source_l.data_source_meta) req_f = dj_pb.RawDataRequest( partition_id=pid, data_source_meta=self._data_source_f.data_source_meta) rsp_l = self._master_client_l.GetRawDataLatestTimeStamp(req_l) rsp_f = self._master_client_f.GetRawDataLatestTimeStamp(req_f) datetime_l = common.convert_timestamp_to_datetime( rsp_l.timestamp) datetime_f = common.convert_timestamp_to_datetime( rsp_f.timestamp) if min_datetime_l is None or min_datetime_l > datetime_l: min_datetime_l = datetime_l if min_datetime_f is None or min_datetime_f > datetime_f: min_datetime_f = datetime_f if min_datetime_l >= target_l and min_datetime_f >= target_f: break else: time.sleep(2) def test_all_pipeline(self): self._wait_timestamp(self._missing_datetime_l 
- timedelta(hours=1), self._missing_datetime_f - timedelta(hours=1)) self._generate_portal_input_data(self._missing_datetime_l, self._event_time_filter_l, self._missing_start_index_l, 1 << 13, self._portal_manifest_l) self._generate_portal_input_data(self._missing_datetime_f, self._event_time_filter_f, self._missing_start_index_f, 1 << 13, self._portal_manifest_f) self._wait_timestamp(self._dt_l - timedelta(hours=1), self._dt_f - timedelta(hours=1)) self._generate_portal_input_data(self._dt_l, self._event_time_filter_l, self._item_num_l, 1 << 13, self._portal_manifest_l) self._dt_l += timedelta(hours=1) self.assertEqual(self._dt_f, self._dt_l) self._wait_timestamp(self._dt_l - timedelta(hours=1), self._dt_f - timedelta(hours=1)) data_source_l = self._master_client_l.GetDataSource(empty_pb2.Empty()) data_source_f = self._master_client_f.GetDataSource(empty_pb2.Empty()) rd_puber_l = raw_data_publisher.RawDataPublisher( self._etcd_l, self._raw_data_pub_dir_l) rd_puber_f = raw_data_publisher.RawDataPublisher( self._etcd_f, self._raw_data_pub_dir_f) for partition_id in range( data_source_l.data_source_meta.partition_num): rd_puber_f.finish_raw_data(partition_id) rd_puber_l.finish_raw_data(partition_id) while True: req_l = dj_pb.DataSourceRequest( data_source_meta=self._data_source_l.data_source_meta) req_f = dj_pb.DataSourceRequest( data_source_meta=self._data_source_f.data_source_meta) dss_l = self._master_client_l.GetDataSourceStatus(req_l) dss_f = self._master_client_f.GetDataSourceStatus(req_f) self.assertEqual(dss_l.role, common_pb.FLRole.Leader) self.assertEqual(dss_f.role, common_pb.FLRole.Follower) if dss_l.state == common_pb.DataSourceState.Finished and \ dss_f.state == common_pb.DataSourceState.Finished: break else: time.sleep(2) logging.info("masters turn into Finished state") def tearDown(self): self._stop_portals() self._stop_masters() self._stop_workers() self._remove_existed_dir()
class DataBlockVisitor(object):
    """Read-only accessor for a data source's dumped data blocks.

    Resolves the data source from etcd at construction time and exposes
    lookups by time frame or by (partition_id, data_block_index), hiding
    blocks that are not yet visible to this party according to the raw
    data manifest stored in etcd.
    """

    def __init__(self, data_source_name, etcd_name, etcd_base_dir,
                 etcd_addrs, use_mock_etcd=False):
        # Note the EtcdClient argument order (name, addrs, base_dir)
        # differs from this constructor's parameter order.
        self._etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                                use_mock_etcd)
        self._data_source = retrieve_data_source(self._etcd,
                                                 data_source_name)

    def LoadDataBlockRepByTimeFrame(self, start_time=None, end_time=None):
        """Return {block_id: DataBlockRep} for visible blocks overlapping
        [start_time, end_time]; either bound may be None (unbounded).

        Raises:
            ValueError: if the requested frame lies entirely outside the
                data source's [start_time, end_time] range.
        """
        if (end_time is not None and
                end_time < self._data_source.data_source_meta.start_time) \
                or \
                (start_time is not None and
                 start_time > self._data_source.data_source_meta.end_time):
            raise ValueError("time frame is out of range")
        partition_num = self._data_source.data_source_meta.partition_num
        data_block_fnames = {}
        for partition_id in range(0, partition_num):
            data_block_fnames[partition_id] = \
                self._list_data_block(partition_id)
        data_block_reps = {}
        for partition_id, fnames in data_block_fnames.items():
            manifest = self._sync_raw_data_manifest(partition_id)
            for idx, fname in enumerate(fnames):
                # Only the newest file may still be being written, so
                # only it needs an existence re-check.
                check_existed = (idx == len(fnames) - 1)
                rep = self._make_data_block_rep(partition_id, fname,
                                                check_existed)
                filtered = True
                reason = ''
                if rep is None:
                    reason = 'failed to create data block rep'
                elif end_time is not None and rep.start_time > end_time:
                    reason = 'excess time frame'
                elif start_time is not None and rep.end_time < start_time:
                    reason = 'less time frame'
                elif self._filter_by_visible(rep.data_block_index,
                                             manifest):
                    # NOTE(review): this skip reason reads inverted — the
                    # block is skipped because it is NOT yet visible
                    # (see _filter_by_visible); confirm intended wording.
                    reason = 'data block visible'
                else:
                    data_block_reps[rep.block_id] = rep
                    filtered = False
                if filtered:
                    logging.debug('skip %s since %s', fname, reason)
        return data_block_reps

    def LoadDataBlockReqByIndex(self, partition_id, data_block_index):
        """Load one data block by partition and index, or None if its
        meta is absent or the block is not visible yet.

        NOTE(review): 'Req' in the name looks like a typo for 'Rep'
        (it returns a DataBlockRep), but it is public API — renaming
        would break callers.

        Raises:
            IndexError: if partition_id is outside [0, partition_num).
        """
        partition_num = self._data_source.data_source_meta.partition_num
        if partition_id < 0 or partition_id >= partition_num:
            raise IndexError("partition {} out range".format(partition_id))
        dirpath = self._partition_data_block_dir(partition_id)
        meta_fname = encode_data_block_meta_fname(self._data_source_name(),
                                                  partition_id,
                                                  data_block_index)
        meta_fpath = os.path.join(dirpath, meta_fname)
        meta = load_data_block_meta(meta_fpath)
        manifest = self._sync_raw_data_manifest(partition_id)
        if meta is not None and \
                not self._filter_by_visible(meta.data_block_index,
                                            manifest):
            fname = encode_data_block_fname(self._data_source_name(),
                                            meta)
            return DataBlockRep(self._data_source_name(), fname,
                                partition_id, dirpath)
        return None

    def _list_data_block(self, partition_id):
        """List data-block file names in the partition dir ([] if absent)."""
        dirpath = self._partition_data_block_dir(partition_id)
        if gfile.Exists(dirpath) and gfile.IsDirectory(dirpath):
            return [f for f in gfile.ListDirectory(dirpath)
                    if f.endswith(DataBlockSuffix)]
        return []

    def _partition_data_block_dir(self, partition_id):
        # Data blocks live under <data_block_dir>/<partition_repr>.
        return os.path.join(self._data_source.data_block_dir,
                            partition_repr(partition_id))

    def _make_data_block_rep(self, partition_id, data_block_fname,
                             check_existed):
        """Build a DataBlockRep for one file; returns None (and logs)
        on any construction failure instead of raising."""
        try:
            rep = DataBlockRep(self._data_source.data_source_meta.name,
                               data_block_fname, partition_id,
                               self._partition_data_block_dir(partition_id),
                               check_existed)
        except Exception as e: # pylint: disable=broad-except
            logging.warning("Failed to create data block rep for %s in"\
                            "partition %d reason %s", data_block_fname,
                            partition_id, e)
            return None
        return rep

    def _data_source_name(self):
        return self._data_source.data_source_meta.name

    def _sync_raw_data_manifest(self, partition_id):
        """Fetch and parse the partition's RawDataManifest from etcd."""
        etcd_key = partition_manifest_etcd_key(self._data_source_name(),
                                               partition_id)
        data = self._etcd.get_data(etcd_key)
        # The manifest is written by the master; its absence means the
        # store is inconsistent, hence the hard assert.
        assert data is not None, "raw data manifest of partition "\
                                 "{} must be existed".format(partition_id)
        return text_format.Parse(data, dj_pb.RawDataManifest())

    def _filter_by_visible(self, index, manifest):
        """Return True if the block at `index` should be hidden.

        A follower hides blocks beyond peer_dumped_index until the
        partition's join state reaches Joined; the leader (and a joined
        follower) hides nothing.
        """
        join_state = manifest.join_example_rep.state
        if self._data_source.role == common_pb.FLRole.Follower and \
                join_state != dj_pb.JoinExampleState.Joined:
            return index > manifest.peer_dumped_index
        return False
def test_api(self):
    """End-to-end exercise of the DataPortalMaster RPC API.

    Seeds a portal manifest plus 100 '*.done' input files (and one
    non-matching '.xx' file), starts a DataPortalMasterService backed by a
    mock etcd, then drives the full task lifecycle over gRPC: manifest
    fetch, map-task assignment/finish for all 4 partitions, reduce-task
    assignment/finish, and the final 'finished' response.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir = 'dp_test'
    data_portal_name = 'test_data_source'
    # Mock etcd (last arg True); wipe any state left by previous runs.
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, True)
    etcd.delete_prefix(etcd_base_dir)
    portal_input_base_dir='./portal_upload_dir'
    portal_output_base_dir='./portal_output_dir'
    raw_data_publish_dir = 'raw_data_publish_dir'
    # processing_job_id=-1 / next_job_id=0 marks a portal that has not
    # launched any job yet; the master is expected to bump these on start.
    portal_manifest = dp_pb.DataPortalManifest(
            name=data_portal_name,
            data_portal_type=dp_pb.DataPortalType.Streaming,
            output_partition_num=4,
            input_file_wildcard="*.done",
            input_base_dir=portal_input_base_dir,
            output_base_dir=portal_output_base_dir,
            raw_data_publish_dir=raw_data_publish_dir,
            processing_job_id=-1,
            next_job_id=0
        )
    etcd.set_data(common.portal_etcd_base_dir(data_portal_name),
                  text_format.MessageToString(portal_manifest))
    if gfile.Exists(portal_input_base_dir):
        gfile.DeleteRecursively(portal_input_base_dir)
    gfile.MakeDirs(portal_input_base_dir)
    # 100 files matching the '*.done' wildcard plus one that must be
    # ignored by the master ('100.xx').
    all_fnames = ['{}.done'.format(i) for i in range(100)]
    all_fnames.append('{}.xx'.format(100))
    for fname in all_fnames:
        fpath = os.path.join(portal_input_base_dir, fname)
        with gfile.Open(fpath, "w") as f:
            f.write('xxx')
    portal_master_addr = 'localhost:4061'
    # NOTE: 'DataPotraMasterlOptions' is the (misspelled) proto name.
    portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False
        )
    data_portal_master = DataPortalMasterService(
            int(portal_master_addr.split(':')[1]),
            data_portal_name, etcd_name, etcd_base_dir,
            etcd_addrs, portal_options
        )
    data_portal_master.start()
    channel = make_insecure_channel(portal_master_addr, ChannelType.INTERNAL)
    portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)
    # The fetched manifest must echo the seeded config, with the job ids
    # advanced (next_job_id 0 -> 1, processing_job_id -1 -> 0).
    recv_manifest = portal_master_cli.GetDataPortalManifest(empty_pb2.Empty())
    self.assertEqual(recv_manifest.name, portal_manifest.name)
    self.assertEqual(recv_manifest.data_portal_type,
                     portal_manifest.data_portal_type)
    self.assertEqual(recv_manifest.output_partition_num,
                     portal_manifest.output_partition_num)
    self.assertEqual(recv_manifest.input_file_wildcard,
                     portal_manifest.input_file_wildcard)
    self.assertEqual(recv_manifest.input_base_dir,
                     portal_manifest.input_base_dir)
    self.assertEqual(recv_manifest.output_base_dir,
                     portal_manifest.output_base_dir)
    self.assertEqual(recv_manifest.raw_data_publish_dir,
                     portal_manifest.raw_data_publish_dir)
    self.assertEqual(recv_manifest.next_job_id, 1)
    self.assertEqual(recv_manifest.processing_job_id, 0)
    self._check_portal_job(etcd, all_fnames, portal_manifest, 0)
    # --- map phase: requesting twice with the same rank must return the
    # same task (idempotent assignment).
    mapped_partition = set()
    task_0 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
    task_0_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
    self.assertEqual(task_0, task_0_1)
    self.assertTrue(task_0.HasField('map_task'))
    mapped_partition.add(task_0.map_task.partition_id)
    self._check_map_task(task_0.map_task, all_fnames,
                         task_0.map_task.partition_id,
                         portal_manifest)
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=0, partition_id=task_0.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )
    task_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
    self.assertTrue(task_1.HasField('map_task'))
    mapped_partition.add(task_1.map_task.partition_id)
    self._check_map_task(task_1.map_task, all_fnames,
                         task_1.map_task.partition_id,
                         portal_manifest)
    task_2 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=1))
    self.assertTrue(task_2.HasField('map_task'))
    mapped_partition.add(task_2.map_task.partition_id)
    self._check_map_task(task_2.map_task, all_fnames,
                         task_2.map_task.partition_id,
                         portal_manifest)
    task_3 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=2))
    self.assertTrue(task_3.HasField('map_task'))
    mapped_partition.add(task_3.map_task.partition_id)
    self._check_map_task(task_3.map_task, all_fnames,
                         task_3.map_task.partition_id,
                         portal_manifest)
    # Every output partition must have been handed out exactly once.
    self.assertEqual(len(mapped_partition),
                     portal_manifest.output_partition_num)
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=0, partition_id=task_1.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )
    # With all map tasks assigned but not all finished, new ranks only get
    # 'pending' responses.
    pending_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=4))
    self.assertTrue(pending_1.HasField('pending'))
    pending_2 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=3))
    self.assertTrue(pending_2.HasField('pending'))
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=1, partition_id=task_2.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=2, partition_id=task_3.map_task.partition_id,
            part_state=dp_pb.PartState.kIdMap)
        )
    # --- reduce phase: same idempotency and partition-coverage checks.
    reduce_partition = set()
    task_4 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
    task_4_1 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=0))
    self.assertEqual(task_4, task_4_1)
    self.assertTrue(task_4.HasField('reduce_task'))
    reduce_partition.add(task_4.reduce_task.partition_id)
    self._check_reduce_task(task_4.reduce_task,
                            task_4.reduce_task.partition_id,
                            portal_manifest)
    task_5 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=1))
    self.assertTrue(task_5.HasField('reduce_task'))
    reduce_partition.add(task_5.reduce_task.partition_id)
    self._check_reduce_task(task_5.reduce_task,
                            task_5.reduce_task.partition_id,
                            portal_manifest)
    task_6 = portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=2))
    self.assertTrue(task_6.HasField('reduce_task'))
    reduce_partition.add(task_6.reduce_task.partition_id)
    self._check_reduce_task(task_6.reduce_task,
                            task_6.reduce_task.partition_id,
                            portal_manifest)
    task_7= portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=3))
    self.assertTrue(task_7.HasField('reduce_task'))
    reduce_partition.add(task_7.reduce_task.partition_id)
    self.assertEqual(len(reduce_partition), 4)
    self._check_reduce_task(task_7.reduce_task,
                            task_7.reduce_task.partition_id,
                            portal_manifest)
    task_8= portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=5))
    self.assertTrue(task_8.HasField('pending'))
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=0, partition_id=task_4.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=1, partition_id=task_5.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=2, partition_id=task_6.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )
    portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=3, partition_id=task_7.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce)
        )
    # All work done: any further request must report 'finished'.
    task_9= portal_master_cli.RequestNewTask(dp_pb.NewTaskRequest(rank_id=5))
    self.assertTrue(task_9.HasField('finished'))
    data_portal_master.stop()
    gfile.DeleteRecursively(portal_input_base_dir)
def setUp(self):
    """Build a leader/follower DataJoinMaster pair over a mock etcd.

    Seeds both parties' DataSource protos into their own etcd prefixes,
    starts both master services, waits until both report the Processing
    state, and finally clears any data/example/raw-data directories left
    over from earlier runs.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    # Mock etcd clients (last arg True); wipe stale state first.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)
    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"
    # Shared meta merged into both parties' DataSource protos.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 64
    data_source_meta.max_matching_window = 128
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))
    customized_options.set_use_mock_etcd()
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f,
            data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l,
            data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    # Poll both masters until they leave Init and reach Processing.
    while True:
        rsp_l = master_client_l.GetDataSourceState(
                data_source_l.data_source_meta)
        rsp_f = master_client_f.GetDataSourceState(
                data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Processing and
                rsp_f.state == common_pb.DataSourceState.Processing):
            break
        else:
            time.sleep(2)
    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    # BUGFIX: the original line ended with a stray trailing comma
    # ('= data_source_name,') which bound a 1-tuple instead of the string.
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    # Remove directories left behind by previous test runs.
    if gfile.Exists(data_source_l.data_block_dir):
        gfile.DeleteRecursively(data_source_l.data_block_dir)
    if gfile.Exists(data_source_l.example_dumped_dir):
        gfile.DeleteRecursively(data_source_l.example_dumped_dir)
    if gfile.Exists(data_source_l.raw_data_dir):
        gfile.DeleteRecursively(data_source_l.raw_data_dir)
    if gfile.Exists(data_source_f.data_block_dir):
        gfile.DeleteRecursively(data_source_f.data_block_dir)
    if gfile.Exists(data_source_f.example_dumped_dir):
        gfile.DeleteRecursively(data_source_f.example_dumped_dir)
    if gfile.Exists(data_source_f.raw_data_dir):
        gfile.DeleteRecursively(data_source_f.raw_data_dir)
    self.total_index = 1 << 13
def setUp(self):
    """Build a leader/follower DataJoinMaster pair plus worker options.

    Commits both parties' DataSource protos (with raw-data pub/sub dirs)
    into a mock etcd, starts both master services, waits until both reach
    the Processing state, wires up raw-data publishers, clears leftover
    directories, and prepares the shared DataJoinWorkerOptions.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    # Mock etcd clients (last arg True); wipe stale state first.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
    etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"
    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"
    # Shared meta merged into both parties' DataSource protos.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_l, data_source_l)
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    common.commit_data_source(etcd_f, data_source_f)
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f,
            data_source_name, etcd_name, etcd_base_dir_l,
            etcd_addrs, master_options)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l,
            data_source_name, etcd_name, etcd_base_dir_f,
            etcd_addrs, master_options)
    master_f.start()
    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
    # Poll both masters until they leave Init and reach Processing.
    while True:
        req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
        dss_l = master_client_l.GetDataSourceStatus(req_l)
        dss_f = master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if dss_l.state == common_pb.DataSourceState.Processing and \
                dss_f.state == common_pb.DataSourceState.Processing:
            break
        else:
            time.sleep(2)
    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    # BUGFIX: the original line ended with a stray trailing comma
    # ('= data_source_name,') which bound a 1-tuple instead of the string.
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)
    # Remove directories left behind by previous test runs.
    if gfile.Exists(data_source_l.data_block_dir):
        gfile.DeleteRecursively(data_source_l.data_block_dir)
    if gfile.Exists(data_source_l.example_dumped_dir):
        gfile.DeleteRecursively(data_source_l.example_dumped_dir)
    if gfile.Exists(data_source_l.raw_data_dir):
        gfile.DeleteRecursively(data_source_l.raw_data_dir)
    if gfile.Exists(data_source_f.data_block_dir):
        gfile.DeleteRecursively(data_source_f.data_block_dir)
    if gfile.Exists(data_source_f.example_dumped_dir):
        gfile.DeleteRecursively(data_source_f.example_dumped_dir)
    if gfile.Exists(data_source_f.raw_data_dir):
        gfile.DeleteRecursively(data_source_f.raw_data_dir)
    self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(
                raw_data_iter='TF_RECORD',
                compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1,
                example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512,
                max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))
    self.total_index = 1 << 13
def __init__(self, peer_client, data_source_name, etcd_name, etcd_addrs, etcd_base_dir): super(DataJoinMaster, self).__init__() self._data_source_name = data_source_name etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir) self._fsm = MasterFSM(peer_client, data_source_name, etcd)
class RsaPsi(unittest.TestCase):
    """End-to-end test of the RSA-PSI pipeline.

    Spins up a leader/follower DataJoinMaster pair, 4 worker pairs, and an
    RSA-PSI signer; generates partitioned CSV raw data with a known
    intersection; runs both parties' preprocessors; and waits for the data
    sources to reach the Finished state.
    """

    def _setUpEtcd(self):
        """Create mock etcd clients (last ctor arg True) for both parties."""
        self._etcd_name = 'test_etcd'
        self._etcd_addrs = 'localhost:2379'
        self._etcd_base_dir_l = 'byefl_l'
        self._etcd_base_dir_f = 'byefl_f'
        self._etcd_l = EtcdClient(self._etcd_name, self._etcd_addrs,
                                  self._etcd_base_dir_l, True)
        self._etcd_f = EtcdClient(self._etcd_name, self._etcd_addrs,
                                  self._etcd_base_dir_f, True)

    def _setUpDataSource(self):
        """Commit fresh leader/follower DataSource protos into etcd."""
        self._data_source_name = 'test_data_source'
        self._etcd_l.delete_prefix(self._data_source_name)
        self._etcd_f.delete_prefix(self._data_source_name)
        self._data_source_l = common_pb.DataSource()
        self._data_source_l.role = common_pb.FLRole.Leader
        self._data_source_l.state = common_pb.DataSourceState.Init
        self._data_source_l.data_block_dir = "./data_block_l"
        self._data_source_l.raw_data_dir = "./raw_data_l"
        self._data_source_l.example_dumped_dir = "./example_dumped_l"
        self._data_source_l.raw_data_sub_dir = "./raw_data_sub_dir_l"
        self._data_source_f = common_pb.DataSource()
        self._data_source_f.role = common_pb.FLRole.Follower
        self._data_source_f.state = common_pb.DataSourceState.Init
        self._data_source_f.data_block_dir = "./data_block_f"
        self._data_source_f.raw_data_dir = "./raw_data_f"
        self._data_source_f.example_dumped_dir = "./example_dumped_f"
        self._data_source_f.raw_data_sub_dir = "./raw_data_sub_dir_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = self._data_source_name
        data_source_meta.partition_num = 4
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        self._data_source_l.data_source_meta.MergeFrom(data_source_meta)
        self._data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(self._etcd_l, self._data_source_l)
        common.commit_data_source(self._etcd_f, self._data_source_f)

    def _generate_input_csv(self, cands, base_dir):
        """Shuffle `cands` and write them into one CSV per partition.

        Each item is routed to partition CityHash32(item) % partition_num;
        rows carry a raw_id plus three derived feature columns. Returns
        the list of written file paths. NOTE: shuffles `cands` in place.
        """
        if not gfile.Exists(base_dir):
            gfile.MakeDirs(base_dir)
        fpaths = []
        random.shuffle(cands)
        csv_writers = []
        partition_num = self._data_source_l.data_source_meta.partition_num
        for partition_id in range(partition_num):
            fpath = os.path.join(base_dir, str(partition_id) + '.rd')
            fpaths.append(fpath)
            csv_writers.append(csv_dict_writer.CsvDictWriter(fpath))
        for item in cands:
            partition_id = CityHash32(item) % partition_num
            raw = OrderedDict()
            raw['raw_id'] = item
            raw['feat_0'] = str((partition_id << 30) + 0) + item
            raw['feat_1'] = str((partition_id << 30) + 1) + item
            raw['feat_2'] = str((partition_id << 30) + 2) + item
            csv_writers[partition_id].write(raw)
        for csv_writer in csv_writers:
            csv_writer.close()
        return fpaths

    def _setUpRsaPsiConf(self):
        """Resolve input/output dirs and the RSA key pair file paths."""
        self._input_dir_l = './rsa_psi_raw_input_l'
        self._input_dir_f = './rsa_psi_raw_input_f'
        self._pre_processor_ouput_dir_l = './pre_processor_output_dir_l'
        self._pre_processor_ouput_dir_f = './pre_processor_output_dir_f'
        # Keys are checked into ../rsa_key relative to this test file.
        key_dir = path.join(path.dirname(path.abspath(__file__)),
                            '../rsa_key')
        self._rsa_public_key_path = path.join(key_dir, 'rsa_psi.pub')
        self._rsa_private_key_path = path.join(key_dir, 'rsa_psi')
        self._raw_data_pub_dir_l = self._data_source_l.raw_data_sub_dir
        self._raw_data_pub_dir_f = self._data_source_f.raw_data_sub_dir

    def _gen_psi_input_raw_data(self):
        """Generate both parties' id sets with a known shared intersection.

        Intersection = multiples of 3 in [0, 2^16); leader also holds the
        even ids, follower the odd ids, so the expected join result is
        exactly the multiples of 3.
        """
        self._intersection_ids = set(
                ['{:09}'.format(i) for i in range(0, 1 << 16)
                 if i % 3 == 0])
        self._rsa_raw_id_l = set(['{:09}'.format(i)
                                  for i in range(0, 1 << 16)
                                  if i % 2 == 0]) | self._intersection_ids
        self._rsa_raw_id_f = set(['{:09}'.format(i)
                                  for i in range(0, 1 << 16)
                                  if i % 2 == 1]) | self._intersection_ids
        self._input_dir_l = './rsa_psi_raw_input_l'
        self._input_dir_f = './rsa_psi_raw_input_f'
        self._psi_raw_data_fpaths_l = self._generate_input_csv(
                list(self._rsa_raw_id_l), self._input_dir_l)
        self._psi_raw_data_fpaths_f = self._generate_input_csv(
                list(self._rsa_raw_id_f), self._input_dir_f)

    def _remove_existed_dir(self):
        """Delete every input/output directory this test may have created."""
        if gfile.Exists(self._input_dir_l):
            gfile.DeleteRecursively(self._input_dir_l)
        if gfile.Exists(self._input_dir_f):
            gfile.DeleteRecursively(self._input_dir_f)
        if gfile.Exists(self._pre_processor_ouput_dir_l):
            gfile.DeleteRecursively(self._pre_processor_ouput_dir_l)
        if gfile.Exists(self._pre_processor_ouput_dir_f):
            gfile.DeleteRecursively(self._pre_processor_ouput_dir_f)
        if gfile.Exists(self._data_source_l.data_block_dir):
            gfile.DeleteRecursively(self._data_source_l.data_block_dir)
        if gfile.Exists(self._data_source_l.raw_data_dir):
            gfile.DeleteRecursively(self._data_source_l.raw_data_dir)
        if gfile.Exists(self._data_source_l.example_dumped_dir):
            gfile.DeleteRecursively(self._data_source_l.example_dumped_dir)
        if gfile.Exists(self._data_source_f.data_block_dir):
            gfile.DeleteRecursively(self._data_source_f.data_block_dir)
        if gfile.Exists(self._data_source_f.raw_data_dir):
            gfile.DeleteRecursively(self._data_source_f.raw_data_dir)
        if gfile.Exists(self._data_source_f.example_dumped_dir):
            gfile.DeleteRecursively(self._data_source_f.example_dumped_dir)

    def _launch_masters(self):
        """Start both master services and wait for the Processing state."""
        self._master_addr_l = 'localhost:4061'
        self._master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        self._master_l = data_join_master.DataJoinMasterService(
                int(self._master_addr_l.split(':')[1]),
                self._master_addr_f, self._data_source_name,
                self._etcd_name, self._etcd_base_dir_l,
                self._etcd_addrs, master_options)
        self._master_f = data_join_master.DataJoinMasterService(
                int(self._master_addr_f.split(':')[1]),
                self._master_addr_l, self._data_source_name,
                self._etcd_name, self._etcd_base_dir_f,
                self._etcd_addrs, master_options)
        self._master_f.start()
        self._master_l.start()
        channel_l = make_insecure_channel(self._master_addr_l,
                                          ChannelType.INTERNAL)
        self._master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(self._master_addr_f,
                                          ChannelType.INTERNAL)
        self._master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
        # Poll until both parties report Processing.
        while True:
            req_l = dj_pb.DataSourceRequest(
                    data_source_meta=self._data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                    data_source_meta=self._data_source_f.data_source_meta)
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)
        logging.info("masters turn into Processing state")

    def _launch_workers(self):
        """Start 4 leader/follower worker pairs sharing one options proto."""
        worker_options = dj_pb.DataJoinWorkerOptions(
                use_mock_etcd=True,
                raw_data_options=dj_pb.RawDataOptions(
                    raw_data_iter='CSV_DICT',
                    compressed_type=''),
                example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                    example_id_dump_interval=1,
                    example_id_dump_threshold=1024),
                example_joiner_options=dj_pb.ExampleJoinerOptions(
                    example_joiner='SORT_RUN_JOINER',
                    min_matching_window=64,
                    max_matching_window=256,
                    data_block_dump_interval=30,
                    data_block_dump_threshold=1000),
                batch_processor_options=dj_pb.BatchProcessorOptions(
                    batch_size=1024,
                    max_flying_item=4096),
                data_block_builder_options=dj_pb.DataBlockBuilderOptions(
                    data_block_builder='CSV_DICT_DATABLOCK_BUILDER'))
        self._worker_addrs_l = ['localhost:4161', 'localhost:4162',
                                'localhost:4163', 'localhost:4164']
        self._worker_addrs_f = ['localhost:5161', 'localhost:5162',
                                'localhost:5163', 'localhost:5164']
        self._workers_l = []
        self._workers_f = []
        for rank_id in range(4):
            worker_addr_l = self._worker_addrs_l[rank_id]
            worker_addr_f = self._worker_addrs_f[rank_id]
            self._workers_l.append(data_join_worker.DataJoinWorkerService(
                    int(worker_addr_l.split(':')[1]),
                    worker_addr_f, self._master_addr_l, rank_id,
                    self._etcd_name, self._etcd_base_dir_l,
                    self._etcd_addrs, worker_options))
            self._workers_f.append(data_join_worker.DataJoinWorkerService(
                    int(worker_addr_f.split(':')[1]),
                    worker_addr_l, self._master_addr_f, rank_id,
                    self._etcd_name, self._etcd_base_dir_f,
                    self._etcd_addrs, worker_options))
        for w in self._workers_l:
            w.start()
        for w in self._workers_f:
            w.start()

    def _launch_rsa_psi_signer(self):
        """Start the leader-side RSA signer loaded with the private key."""
        self._rsa_psi_signer_addr = 'localhost:6171'
        rsa_private_key_pem = None
        with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
            rsa_private_key_pem = f.read()
        rsa_private_key = rsa.PrivateKey.load_pkcs1(rsa_private_key_pem)
        self._rsa_psi_signer = rsa_psi_signer.RsaPsiSigner(
                rsa_private_key, 1, 500)
        self._rsa_psi_signer.start(
                int(self._rsa_psi_signer_addr.split(':')[1]), 512)

    def _stop_workers(self):
        for w in self._workers_f:
            w.stop()
        for w in self._workers_l:
            w.stop()

    def _stop_masters(self):
        self._master_f.stop()
        self._master_l.stop()

    def _stop_rsa_psi_signer(self):
        self._rsa_psi_signer.stop()

    def setUp(self):
        # Ordering matters: config before data generation, data before
        # services, services before the signer.
        self._setUpEtcd()
        self._setUpDataSource()
        self._setUpRsaPsiConf()
        self._remove_existed_dir()
        self._gen_psi_input_raw_data()
        self._launch_masters()
        self._launch_workers()
        self._launch_rsa_psi_signer()

    def _preprocess_rsa_psi_leader(self):
        """Run the leader preprocessor over every partition and wait."""
        processors = []
        rsa_key_pem = None
        # Leader signs with its private key.
        with gfile.GFile(self._rsa_private_key_path, 'rb') as f:
            rsa_key_pem = f.read()
        for partition_id in range(
                self._data_source_l.data_source_meta.partition_num):
            options = dj_pb.RsaPsiPreProcessorOptions(
                    preprocessor_name='leader-rsa-psi-processor',
                    role=common_pb.FLRole.Leader,
                    rsa_key_pem=rsa_key_pem,
                    input_file_paths=[
                        self._psi_raw_data_fpaths_l[partition_id]],
                    output_file_dir=self._pre_processor_ouput_dir_l,
                    raw_data_publish_dir=self._raw_data_pub_dir_l,
                    partition_id=partition_id,
                    offload_processor_number=1,
                    max_flying_sign_batch=128,
                    stub_fanout=2,
                    slow_sign_threshold=8,
                    sort_run_merger_read_ahead_buffer=1 << 20,
                    batch_processor_options=dj_pb.BatchProcessorOptions(
                        batch_size=1024,
                        max_flying_item=1 << 14))
            processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
                    options, self._etcd_name, self._etcd_addrs,
                    self._etcd_base_dir_l, True)
            processor.start_process()
            processors.append(processor)
        for processor in processors:
            processor.wait_for_finished()

    def _preprocess_rsa_psi_follower(self):
        """Run the follower preprocessor over every partition and wait."""
        processors = []
        rsa_key_pem = None
        # Follower only holds the public key; blind signing goes through
        # the leader's signer service.
        with gfile.GFile(self._rsa_public_key_path, 'rb') as f:
            rsa_key_pem = f.read()
        for partition_id in range(
                self._data_source_f.data_source_meta.partition_num):
            options = dj_pb.RsaPsiPreProcessorOptions(
                    preprocessor_name='follower-rsa-psi-processor',
                    role=common_pb.FLRole.Follower,
                    rsa_key_pem=rsa_key_pem,
                    input_file_paths=[
                        self._psi_raw_data_fpaths_f[partition_id]],
                    output_file_dir=self._pre_processor_ouput_dir_f,
                    raw_data_publish_dir=self._raw_data_pub_dir_f,
                    partition_id=partition_id,
                    leader_rsa_psi_signer_addr=self._rsa_psi_signer_addr,
                    offload_processor_number=1,
                    max_flying_sign_batch=128,
                    max_flying_sign_rpc=64,
                    sign_rpc_timeout_ms=100000,
                    stub_fanout=2,
                    slow_sign_threshold=8,
                    sort_run_merger_read_ahead_buffer=1 << 20,
                    # Alternate sync/async RPC mode across partitions to
                    # cover both code paths.
                    rpc_sync_mode=True if partition_id % 2 == 0 else False,
                    rpc_thread_pool_size=16,
                    batch_processor_options=dj_pb.BatchProcessorOptions(
                        batch_size=1024,
                        max_flying_item=1 << 14))
            processor = rsa_psi_preprocessor.RsaPsiPreProcessor(
                    options, self._etcd_name, self._etcd_addrs,
                    self._etcd_base_dir_f, True)
            processor.start_process()
            processors.append(processor)
        for processor in processors:
            processor.wait_for_finished()

    def test_all_pipeline(self):
        """Preprocess both sides then wait for both masters to Finish."""
        start_tm = time.time()
        self._preprocess_rsa_psi_follower()
        logging.warning("Follower Preprocess cost %d seconds",
                        time.time() - start_tm)
        start_tm = time.time()
        self._preprocess_rsa_psi_leader()
        logging.warning("Leader Preprocess cost %f seconds",
                        time.time() - start_tm)
        # Poll until the join pipeline drains and both report Finished.
        while True:
            req_l = dj_pb.DataSourceRequest(
                    data_source_meta=self._data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                    data_source_meta=self._data_source_f.data_source_meta)
            dss_l = self._master_client_l.GetDataSourceStatus(req_l)
            dss_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Finished and \
                    dss_f.state == common_pb.DataSourceState.Finished:
                break
            else:
                time.sleep(2)
        logging.info("masters turn into Finished state")

    def tearDown(self):
        self._stop_workers()
        self._stop_masters()
        self._stop_rsa_psi_signer()
        self._remove_existed_dir()