def _cheif_barriar(self, is_chief=False, sync_times=300):
    """Publish this worker's ready flag in etcd; the chief waits for peers.

    Every caller writes the key ``<APPLICATION_ID>/<WORKER_RANK>`` through an
    ``EtcdClient`` built from the ``ETCD_CLUSTER`` and ``ETCD_ADDRESS``
    environment variables and ``SYNC_PATH``. When ``is_chief`` is true the
    caller then polls the keys stored under ``APPLICATION_ID`` until at least
    ``REPLICA_NUM`` of them exist, sleeping 6 seconds between polls.

    Args:
        is_chief: whether this caller should wait for the other workers.
        sync_times: maximum number of polls the chief performs.

    Raises:
        KeyError: if a required environment variable is not set.
        ValueError: if ``REPLICA_NUM`` is set but is not an integer.
    """
    # Environment values are strings; convert once so the comparison with
    # len(sync_list) below is numeric instead of int-vs-str.
    worker_replicas = int(os.environ.get('REPLICA_NUM', 0))
    etcd_client = EtcdClient(os.environ['ETCD_CLUSTER'],
                             os.environ['ETCD_ADDRESS'],
                             SYNC_PATH)
    application_id = os.environ['APPLICATION_ID']
    sync_path = '%s/%s' % (application_id, os.environ['WORKER_RANK'])
    logging.info('Creating a sync flag at %s', sync_path)
    etcd_client.set_data(sync_path, 1)
    if not is_chief:
        return

    ready_count = 0
    for _ in range(sync_times):
        sync_list = etcd_client.get_prefix_kvs(application_id)
        logging.info('Sync file pattern is: %s', sync_list)
        ready_count = len(sync_list)
        if ready_count >= worker_replicas:
            break
        logging.info('Count of ready workers is %d', ready_count)
        time.sleep(6)
    else:
        # The poll budget ran out; keep the original best-effort behaviour
        # (return normally) but make the incomplete barrier visible.
        logging.warning(
            'Only %d of %d workers were ready after %d checks',
            ready_count, worker_replicas, sync_times)
def test_api(self):
    """Run a leader/follower DataJoinMaster pair through sync and join.

    The test seeds both etcd namespaces with a one-partition data source,
    starts the two masters on localhost:4061 and localhost:4062, waits for
    both to report ``Processing``, allocates and finishes the example-id
    sync partition and then the join partition on each side, and finally
    waits for both masters to report ``Finished``.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)

    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"

    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 32
    data_source_meta.max_matching_window = 1024
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000

    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))

    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()

    # Stop both masters even when an assertion fails, so the listening
    # ports are released for whatever runs next.
    try:
        channel_l = make_insecure_channel(master_addr_l,
                                          ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f,
                                          ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        def wait_for_state(expected_state):
            # Poll both masters every 2 seconds until each one reports
            # expected_state; role and type are re-checked on every poll.
            while True:
                rsp_l = client_l.GetDataSourceState(
                    data_source_l.data_source_meta)
                rsp_f = client_f.GetDataSourceState(
                    data_source_f.data_source_meta)
                self.assertEqual(rsp_l.status.code, 0)
                self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
                self.assertEqual(rsp_l.data_source_type,
                                 common_pb.DataSourceType.Sequential)
                self.assertEqual(rsp_f.status.code, 0)
                self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
                self.assertEqual(rsp_f.data_source_type,
                                 common_pb.DataSourceType.Sequential)
                if (rsp_l.state == expected_state and
                        rsp_f.state == expected_state):
                    break
                time.sleep(2)

        def check_allocated(response, expected_state):
            # Partition 0 must be handed to rank 0 in expected_state.
            self.assertEqual(response.status.code, 0)
            self.assertTrue(response.HasField('manifest'))
            self.assertEqual(response.manifest.state, expected_state)
            self.assertEqual(response.manifest.allocated_rank_id, 0)
            self.assertEqual(response.manifest.partition_id, 0)

        wait_for_state(common_pb.DataSourceState.Processing)

        # A join request issued before any example-id sync has finished
        # is expected to come back with neither a manifest nor 'finished'.
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=-1))
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertFalse(rdrsp.HasField('manifest'))
        self.assertFalse(rdrsp.HasField('finished'))

        # Allocate the sync partition on the leader; the second identical
        # request must return the same allocation.
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
        check_allocated(client_l.RequestJoinPartition(rdreq), dj_pb.Syncing)
        check_allocated(client_l.RequestJoinPartition(rdreq), dj_pb.Syncing)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
        check_allocated(client_f.RequestJoinPartition(rdreq), dj_pb.Syncing)

        # Finish the sync partition. The repeated leader call previously
        # sent the stale RawDataRequest and asserted on the old response;
        # it now resends frreq and checks the new response.
        frreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
        frrsp = client_l.FinishJoinPartition(frreq)
        self.assertEqual(frrsp.code, 0)
        frrsp = client_l.FinishJoinPartition(frreq)
        self.assertEqual(frrsp.code, 0)

        frreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
        frrsp = client_f.FinishJoinPartition(frreq)
        self.assertEqual(frrsp.code, 0)

        # Allocate the join partition on both sides.
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=-1))
        check_allocated(client_l.RequestJoinPartition(rdreq), dj_pb.Joining)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=0))
        check_allocated(client_f.RequestJoinPartition(rdreq), dj_pb.Joining)

        # Finish the join partition using the FinishRawDataRequest that
        # was built for it (the original passed the RawDataRequest above
        # and left frreq unused).
        frreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=0))
        frrsp = client_l.FinishJoinPartition(frreq)
        self.assertEqual(frrsp.code, 0)
        frrsp = client_l.FinishJoinPartition(frreq)
        self.assertEqual(frrsp.code, 0)

        frreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=0))
        frrsp = client_f.FinishJoinPartition(frreq)
        self.assertEqual(frrsp.code, 0)

        wait_for_state(common_pb.DataSourceState.Finished)
    finally:
        master_l.stop()
        master_f.stop()
help='the namespace of etcd key') parser.add_argument('--raw_data_sub_dir', type=str, required=True, help='the etcd base dir to subscribe new raw data') args = parser.parse_args() data_source = common_pb.DataSource() data_source.data_source_meta.name = args.data_source_name data_source.data_source_meta.partition_num = args.partition_num data_source.data_source_meta.start_time = args.start_time data_source.data_source_meta.end_time = args.end_time data_source.data_source_meta.negative_sampling_rate = \ args.negative_sampling_rate if args.role == 'leader': data_source.role = common_pb.FLRole.Leader else: assert args.role == 'follower' data_source.role = common_pb.FLRole.Follower data_source.example_dumped_dir = args.example_dump_dir data_source.data_block_dir = args.data_block_dir data_source.raw_data_sub_dir = args.raw_data_sub_dir data_source.state = common_pb.DataSourceState.Init etcd = EtcdClient(args.etcd_name, args.etcd_addrs, args.etcd_base_dir) master_etcd_key = os.path.join(data_source.data_source_meta.name, 'master') raw_data = etcd.get_data(master_etcd_key) if raw_data is None: logging.info("data source %s is not existed", args.data_source_name) etcd.set_data(master_etcd_key, text_format.MessageToString(data_source)) logging.info("apply new data source %s", args.data_source_name) else: logging.info("data source %s has been existed", args.data_source_name) etcd.destory_client_pool()
def test_api(self):
    """Walk a DataPortalMaster through one job: map, reduce, finished.

    A manifest with four output partitions is stored in etcd and 100
    ``*.done`` files plus one ``.xx`` file are written to the input
    directory. The test then starts the master, checks the manifest it
    serves, and requests/finishes map and reduce tasks from several ranks
    until a request is answered with ``finished``.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name, etcd_addrs = 'test_etcd', 'localhost:2379'
    etcd_base_dir = 'dp_test'
    data_portal_name = 'test_data_source'
    etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir, True)
    etcd.delete_prefix(etcd_base_dir)

    portal_input_base_dir = './portal_upload_dir'
    portal_output_base_dir = './portal_output_dir'
    raw_data_publish_dir = 'raw_data_publish_dir'
    portal_manifest = dp_pb.DataPortalManifest(
        name=data_portal_name,
        data_portal_type=dp_pb.DataPortalType.Streaming,
        output_partition_num=4,
        input_file_wildcard="*.done",
        input_base_dir=portal_input_base_dir,
        output_base_dir=portal_output_base_dir,
        raw_data_publish_dir=raw_data_publish_dir,
        processing_job_id=-1,
        next_job_id=0
    )
    manifest_text = text_format.MessageToString(portal_manifest)
    etcd.set_data(common.portal_etcd_base_dir(data_portal_name),
                  manifest_text)

    # Start from an empty input directory.
    if gfile.Exists(portal_input_base_dir):
        gfile.DeleteRecursively(portal_input_base_dir)
    gfile.MakeDirs(portal_input_base_dir)
    all_fnames = ['{}.done'.format(idx) for idx in range(100)]
    all_fnames += ['{}.xx'.format(100)]
    for fname in all_fnames:
        with gfile.Open(os.path.join(portal_input_base_dir, fname),
                        "w") as fout:
            fout.write('xxx')

    portal_master_addr = 'localhost:4061'
    listen_port = int(portal_master_addr.split(':')[1])
    portal_options = dp_pb.DataPotraMasterlOptions(
        use_mock_etcd=True,
        long_running=False
    )
    data_portal_master = DataPortalMasterService(
        listen_port, data_portal_name, etcd_name,
        etcd_base_dir, etcd_addrs, portal_options
    )
    data_portal_master.start()

    channel = make_insecure_channel(portal_master_addr,
                                    ChannelType.INTERNAL)
    portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)

    def request_task(rank):
        # Ask the master for the next task on behalf of `rank`.
        return portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=rank))

    def finish_task(rank, partition, state):
        # Report `partition` as done for `rank` in phase `state`.
        portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
            rank_id=rank, partition_id=partition, part_state=state))

    recv_manifest = portal_master_cli.GetDataPortalManifest(
        empty_pb2.Empty())
    for field in ('name', 'data_portal_type', 'output_partition_num',
                  'input_file_wildcard', 'input_base_dir',
                  'output_base_dir', 'raw_data_publish_dir'):
        self.assertEqual(getattr(recv_manifest, field),
                         getattr(portal_manifest, field))
    self.assertEqual(recv_manifest.next_job_id, 1)
    self.assertEqual(recv_manifest.processing_job_id, 0)
    self._check_portal_job(etcd, all_fnames, portal_manifest, 0)

    # ---- map phase ----
    mapped_partition = set()

    def take_map_task(task):
        self.assertTrue(task.HasField('map_task'))
        mapped_partition.add(task.map_task.partition_id)
        self._check_map_task(task.map_task, all_fnames,
                             task.map_task.partition_id, portal_manifest)

    # Asking twice from the same rank must return the same task.
    task_0 = request_task(0)
    task_0_1 = request_task(0)
    self.assertEqual(task_0, task_0_1)
    take_map_task(task_0)
    finish_task(0, task_0.map_task.partition_id, dp_pb.PartState.kIdMap)

    later_map_tasks = []
    for rank in (0, 1, 2):
        task = request_task(rank)
        take_map_task(task)
        later_map_tasks.append(task)
    task_1, task_2, task_3 = later_map_tasks
    self.assertEqual(len(mapped_partition),
                     portal_manifest.output_partition_num)

    finish_task(0, task_1.map_task.partition_id, dp_pb.PartState.kIdMap)
    # Every map partition is handed out, but two are still unfinished.
    for rank in (4, 3):
        self.assertTrue(request_task(rank).HasField('pending'))
    finish_task(1, task_2.map_task.partition_id, dp_pb.PartState.kIdMap)
    finish_task(2, task_3.map_task.partition_id, dp_pb.PartState.kIdMap)

    # ---- reduce phase ----
    reduce_partition = set()

    def check_reduce(task):
        self._check_reduce_task(task.reduce_task,
                                task.reduce_task.partition_id,
                                portal_manifest)

    task_4 = request_task(0)
    task_4_1 = request_task(0)
    self.assertEqual(task_4, task_4_1)
    self.assertTrue(task_4.HasField('reduce_task'))
    reduce_partition.add(task_4.reduce_task.partition_id)
    check_reduce(task_4)

    reduce_tasks = [task_4]
    for rank in (1, 2):
        task = request_task(rank)
        self.assertTrue(task.HasField('reduce_task'))
        reduce_partition.add(task.reduce_task.partition_id)
        check_reduce(task)
        reduce_tasks.append(task)

    task_7 = request_task(3)
    self.assertTrue(task_7.HasField('reduce_task'))
    reduce_partition.add(task_7.reduce_task.partition_id)
    self.assertEqual(len(reduce_partition), 4)
    check_reduce(task_7)
    reduce_tasks.append(task_7)

    task_8 = request_task(5)
    self.assertTrue(task_8.HasField('pending'))

    for rank, task in enumerate(reduce_tasks):
        finish_task(rank, task.reduce_task.partition_id,
                    dp_pb.PartState.kEventTimeReduce)

    task_9 = request_task(5)
    self.assertTrue(task_9.HasField('finished'))

    data_portal_master.stop()
    gfile.DeleteRecursively(portal_input_base_dir)
def test_api(self):
    """Exercise the DataJoinMaster RPC surface on a leader/follower pair.

    Two masters are started against mock etcd for a one-partition data
    source. Once both report ``Processing`` the test:

    * allocates the join-example and sync-example-id partitions on each
      side, repeating every request to check it is idempotent;
    * checks that ``FinishJoinPartition`` is rejected while raw data is
      still open;
    * adds raw data files (with and without ``dedup``) and checks
      ``next_process_index`` after every step;
    * finishes raw data, checks that adding more is then rejected, and
      finishes the partitions;
    * waits for both masters to report ``Finished``.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)

    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"

    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 1
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000

    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))

    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, options
    )
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, options
    )
    master_f.start()

    # Stop both masters even when an assertion fails, so the listening
    # ports are released for whatever runs next.
    try:
        channel_l = make_insecure_channel(master_addr_l,
                                          ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f,
                                          ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)
        meta_l = data_source_l.data_source_meta
        meta_f = data_source_f.data_source_meta

        def wait_for_state(expected_state):
            # Poll both masters every 2 seconds until each one reports
            # expected_state; roles are re-checked on every poll.
            while True:
                req_l = dj_pb.DataSourceRequest(data_source_meta=meta_l)
                req_f = dj_pb.DataSourceRequest(data_source_meta=meta_f)
                dss_l = client_l.GetDataSourceStatus(req_l)
                dss_f = client_f.GetDataSourceStatus(req_f)
                self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
                self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
                if (dss_l.state == expected_state and
                        dss_f.state == expected_state):
                    break
                time.sleep(2)

        def check_join_allocated(response):
            # Partition 0 is being joined by rank 0.
            self.assertEqual(response.status.code, 0)
            self.assertTrue(response.HasField('manifest'))
            self.assertEqual(response.manifest.join_example_rep.rank_id, 0)
            self.assertEqual(response.manifest.join_example_rep.state,
                             dj_pb.JoinExampleState.Joining)
            self.assertEqual(response.manifest.partition_id, 0)

        def check_sync_allocated(response):
            # Partition 0 is being synced by rank 1.
            self.assertEqual(response.status.code, 0)
            self.assertTrue(response.HasField('manifest'))
            self.assertEqual(
                response.manifest.sync_example_id_rep.rank_id, 1)
            self.assertEqual(response.manifest.sync_example_id_rep.state,
                             dj_pb.SyncExampleIdState.Syncing)
            self.assertEqual(response.manifest.partition_id, 0)

        def check_raw_data_manifest(client, request, finished, next_index):
            manifest = client.QueryRawDataManifest(request)
            self.assertIsNotNone(manifest)
            self.assertEqual(manifest.finished, finished)
            self.assertEqual(manifest.next_process_index, next_index)

        def raw_data_paths_request(meta, dedup, file_paths):
            return dj_pb.RawDataRequest(
                data_source_meta=meta,
                rank_id=0,
                partition_id=0,
                raw_data_fpaths=dj_pb.RawDataFilePaths(
                    dedup=dedup, file_paths=file_paths
                )
            )

        def add_raw_data(client, meta, dedup, file_paths, next_index):
            request = raw_data_paths_request(meta, dedup, file_paths)
            rsp = client.AddRawData(request)
            self.assertEqual(rsp.code, 0)
            check_raw_data_manifest(client, request, False, next_index)

        def expect_rejected(rpc, request):
            # The call must fail; any exception type is accepted, as in
            # the original try/except Exception form.
            with self.assertRaises(Exception):
                rpc(request)

        def succeed_twice(rpc, request):
            # The call must succeed and be idempotent.
            for _ in range(2):
                rsp = rpc(request)
                self.assertEqual(rsp.code, 0)

        def finish_raw_data(client, meta):
            # Close raw data for partition 0, then verify the manifest is
            # marked finished and further files are refused.
            request = dj_pb.RawDataRequest(
                data_source_meta=meta,
                rank_id=0,
                partition_id=0,
                finish_raw_data=empty_pb2.Empty()
            )
            succeed_twice(client.FinishRawData, request)
            check_raw_data_manifest(client, request, True, 2)
            expect_rejected(client.AddRawData,
                            raw_data_paths_request(meta, True, ['x']))

        wait_for_state(common_pb.DataSourceState.Processing)

        # Join-example allocation: follower first, then leader. Each
        # request is sent twice to check idempotence.
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=meta_f,
            rank_id=0,
            partition_id=-1,
            join_example=empty_pb2.Empty()
        )
        check_join_allocated(client_f.RequestJoinPartition(rdreq))
        check_join_allocated(client_f.RequestJoinPartition(rdreq))

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=meta_l,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty()
        )
        check_join_allocated(client_l.RequestJoinPartition(rdreq))
        check_join_allocated(client_l.RequestJoinPartition(rdreq))

        # Sync-example-id allocation: leader first, then follower.
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=meta_l,
            rank_id=1,
            partition_id=-1,
            sync_example_id=empty_pb2.Empty()
        )
        check_sync_allocated(client_l.RequestJoinPartition(rdreq))
        check_sync_allocated(client_l.RequestJoinPartition(rdreq))

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=meta_f,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty()
        )
        check_sync_allocated(client_f.RequestJoinPartition(rdreq))
        check_sync_allocated(client_f.RequestJoinPartition(rdreq))

        # Neither partition can be finished while raw data is open.
        rdreq1 = dj_pb.RawDataRequest(
            data_source_meta=meta_l,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty()
        )
        expect_rejected(client_l.FinishJoinPartition, rdreq1)
        rdreq2 = dj_pb.RawDataRequest(
            data_source_meta=meta_l,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty()
        )
        expect_rejected(client_l.FinishJoinPartition, rdreq2)

        # Raw data bookkeeping, leader then follower: two distinct files
        # advance next_process_index to 2; re-adding them with dedup=True
        # leaves it at 2.
        for client, meta in ((client_l, meta_l), (client_f, meta_f)):
            rdreq = dj_pb.RawDataRequest(
                data_source_meta=meta,
                rank_id=0,
                partition_id=0,
            )
            check_raw_data_manifest(client, rdreq, False, 0)
            add_raw_data(client, meta, False, ['a'], 1)
            add_raw_data(client, meta, False, ['b'], 2)
            add_raw_data(client, meta, True, ['a', 'b'], 2)

        finish_raw_data(client_l, meta_l)

        # The join partition is still refused on the follower, while the
        # sync partition can now be finished on both sides.
        expect_rejected(client_f.FinishJoinPartition, rdreq2)
        succeed_twice(client_l.FinishJoinPartition, rdreq1)
        succeed_twice(client_f.FinishJoinPartition, rdreq1)
        expect_rejected(client_f.FinishJoinPartition, rdreq2)

        finish_raw_data(client_f, meta_f)

        succeed_twice(client_f.FinishJoinPartition, rdreq2)
        succeed_twice(client_l.FinishJoinPartition, rdreq2)

        wait_for_state(common_pb.DataSourceState.Finished)
    finally:
        master_l.stop()
        master_f.stop()
def setUp(self):
    """Start a leader and a follower data-join master and wait for both.

    Registers a two-partition, sequential ``test_data_source`` for each
    role under its own etcd base dir, starts one ``DataJoinMasterService``
    per role on localhost ports 4061 (leader) and 4062 (follower), and
    blocks until both masters report ``DataSourceState.Processing``.
    Clients, services and connection settings are then stored on ``self``
    and leftover data-block / example-dump / raw-data directories from a
    previous run are removed.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'

    # NOTE(review): the trailing True is presumably the "use mock etcd"
    # flag of EtcdClient -- confirm against its signature.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    # Drop whatever an earlier run stored for this data source.
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)

    data_source_l = common_pb.DataSource()
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"

    data_source_f = common_pb.DataSource()
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    # Both roles share one meta description of the data source.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000
    data_source_meta.min_matching_window = 64
    data_source_meta.max_matching_window = 128
    data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
    data_source_meta.max_example_in_data_block = 1000

    # Publish each role's DataSource (text-format protobuf) under
    # "<data_source_name>/master" in its own etcd base dir.
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))

    customized_options.set_use_mock_etcd()

    # Each master listens on its own port and is given the peer's address.
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l, etcd_addrs)
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f, etcd_addrs)
    master_f.start()

    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    # Poll every 2 seconds until both sides reach Processing.
    # NOTE: there is no timeout; if either master never gets there,
    # setUp blocks forever.
    while True:
        rsp_l = master_client_l.GetDataSourceState(
            data_source_l.data_source_meta)
        rsp_f = master_client_f.GetDataSourceState(
            data_source_f.data_source_meta)
        self.assertEqual(rsp_l.status.code, 0)
        self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
        self.assertEqual(rsp_l.data_source_type,
                         common_pb.DataSourceType.Sequential)
        self.assertEqual(rsp_f.status.code, 0)
        self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
        self.assertEqual(rsp_f.data_source_type,
                         common_pb.DataSourceType.Sequential)
        if (rsp_l.state == common_pb.DataSourceState.Processing and
                rsp_f.state == common_pb.DataSourceState.Processing):
            break
        time.sleep(2)

    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    # Stored as the plain string; a stray trailing comma used to turn
    # this attribute into the 1-tuple ('test_data_source',).
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f

    # Remove on-disk output left behind by a previous run.
    for data_source in (data_source_l, data_source_f):
        for stale_dir in (data_source.data_block_dir,
                          data_source.example_dumped_dir,
                          data_source.raw_data_dir):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

    self.total_index = 1 << 13
# Final CLI flag, then parse the command line.
parser.add_argument(
    '--long_running',
    action='store_true',
    help='make the data portal long running',
)
args = parser.parse_args()

etcd = EtcdClient(
    args.etcd_name,
    args.etcd_addrs,
    args.etcd_base_dir,
    args.use_mock_etcd,
)
etcd_key = common.portal_etcd_base_dir(args.data_portal_name)

# Write a manifest only when etcd holds nothing under the portal's key;
# an existing entry is left untouched.
if etcd.get_data(etcd_key) is None:
    # Anything other than the literal 'PSI' selects the Streaming type.
    if args.data_portal_type == 'PSI':
        portal_type = dp_pb.DataPortalType.PSI
    else:
        portal_type = dp_pb.DataPortalType.Streaming
    portal_manifest = dp_pb.DataPortalManifest(
        name=args.data_portal_name,
        data_portal_type=portal_type,
        output_partition_num=args.output_partition_num,
        input_file_wildcard=args.input_file_wildcard,
        input_base_dir=args.input_base_dir,
        output_base_dir=args.output_base_dir,
        raw_data_publish_dir=args.raw_data_publish_dir,
        processing_job_id=-1,
    )
    etcd.set_data(etcd_key, text_format.MessageToString(portal_manifest))

# Build the master service from the parsed arguments and run it.
options = dp_pb.DataPotraMasterlOptions(
    use_mock_etcd=args.use_mock_etcd,
    long_running=args.long_running,
)
portal_master_srv = DataPortalMasterService(
    args.listen_port,
    args.data_portal_name,
    args.etcd_name,
    args.etcd_base_dir,
    args.etcd_addrs,
    options,
)
portal_master_srv.run()
def setUp(self):
    """Start leader/follower data-join masters plus raw-data publishers.

    Registers a two-partition ``test_data_source`` for each role under its
    own etcd base dir (each with its own raw-data subscription dir),
    starts one ``DataJoinMasterService`` per role on localhost ports 4061
    (leader) and 4062 (follower), and blocks until both masters report
    ``DataSourceState.Processing``.  Clients, services, connection
    settings, one ``RawDataPublisher`` per role and the worker options are
    then stored on ``self``; leftover data-block / example-dump / raw-data
    directories from a previous run are removed.
    """
    etcd_name = 'test_etcd'
    etcd_addrs = 'localhost:2379'
    etcd_base_dir_l = 'byefl_l'
    etcd_base_dir_f = 'byefl_f'
    data_source_name = 'test_data_source'

    # NOTE(review): the trailing True is presumably the "use mock etcd"
    # flag of EtcdClient -- confirm against its signature.
    etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
    etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
    # Drop whatever an earlier run stored for this data source.
    etcd_l.delete_prefix(data_source_name)
    etcd_f.delete_prefix(data_source_name)

    data_source_l = common_pb.DataSource()
    self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
    data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
    data_source_l.role = common_pb.FLRole.Leader
    data_source_l.state = common_pb.DataSourceState.Init
    data_source_l.data_block_dir = "./data_block_l"
    data_source_l.raw_data_dir = "./raw_data_l"
    data_source_l.example_dumped_dir = "./example_dumped_l"

    data_source_f = common_pb.DataSource()
    self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
    data_source_f.role = common_pb.FLRole.Follower
    data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
    data_source_f.state = common_pb.DataSourceState.Init
    data_source_f.data_block_dir = "./data_block_f"
    data_source_f.raw_data_dir = "./raw_data_f"
    data_source_f.example_dumped_dir = "./example_dumped_f"

    # Both roles share one meta description of the data source.
    data_source_meta = common_pb.DataSourceMeta()
    data_source_meta.name = data_source_name
    data_source_meta.partition_num = 2
    data_source_meta.start_time = 0
    data_source_meta.end_time = 100000000

    # Publish each role's DataSource (text-format protobuf) under
    # "<data_source_name>/master" in its own etcd base dir.
    data_source_l.data_source_meta.MergeFrom(data_source_meta)
    etcd_l.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_l))
    data_source_f.data_source_meta.MergeFrom(data_source_meta)
    etcd_f.set_data(os.path.join(data_source_name, 'master'),
                    text_format.MessageToString(data_source_f))

    # Each master listens on its own port and is given the peer's address.
    master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
    master_addr_l = 'localhost:4061'
    master_addr_f = 'localhost:4062'
    master_l = data_join_master.DataJoinMasterService(
        int(master_addr_l.split(':')[1]), master_addr_f,
        data_source_name, etcd_name, etcd_base_dir_l,
        etcd_addrs, master_options,
    )
    master_l.start()
    master_f = data_join_master.DataJoinMasterService(
        int(master_addr_f.split(':')[1]), master_addr_l,
        data_source_name, etcd_name, etcd_base_dir_f,
        etcd_addrs, master_options,
    )
    master_f.start()

    channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
    master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
    channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
    master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

    # Poll every 2 seconds until both sides reach Processing.
    # NOTE: there is no timeout; if either master never gets there,
    # setUp blocks forever.
    while True:
        req_l = dj_pb.DataSourceRequest(
            data_source_meta=data_source_l.data_source_meta)
        req_f = dj_pb.DataSourceRequest(
            data_source_meta=data_source_f.data_source_meta)
        dss_l = master_client_l.GetDataSourceStatus(req_l)
        dss_f = master_client_f.GetDataSourceStatus(req_f)
        self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
        self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
        if (dss_l.state == common_pb.DataSourceState.Processing and
                dss_f.state == common_pb.DataSourceState.Processing):
            break
        time.sleep(2)

    self.master_client_l = master_client_l
    self.master_client_f = master_client_f
    self.master_addr_l = master_addr_l
    self.master_addr_f = master_addr_f
    self.etcd_l = etcd_l
    self.etcd_f = etcd_f
    self.data_source_l = data_source_l
    self.data_source_f = data_source_f
    self.master_l = master_l
    self.master_f = master_f
    # Stored as the plain string; a stray trailing comma used to turn
    # this attribute into the 1-tuple ('test_data_source',).
    self.data_source_name = data_source_name
    self.etcd_name = etcd_name
    self.etcd_addrs = etcd_addrs
    self.etcd_base_dir_l = etcd_base_dir_l
    self.etcd_base_dir_f = etcd_base_dir_f

    # One publisher per role, each bound to the directory that role's
    # data source was configured with as raw_data_sub_dir.
    self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
        self.etcd_l, self.raw_data_pub_dir_l)
    self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
        self.etcd_f, self.raw_data_pub_dir_f)

    # Remove on-disk output left behind by a previous run.
    for data_source in (data_source_l, data_source_f):
        for stale_dir in (data_source.data_block_dir,
                          data_source.example_dumped_dir,
                          data_source.raw_data_dir):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

    self.worker_options = dj_pb.DataJoinWorkerOptions(
        use_mock_etcd=True,
        raw_data_options=dj_pb.RawDataOptions(
            raw_data_iter='TF_RECORD',
            compressed_type=''),
        example_id_dump_options=dj_pb.ExampleIdDumpOptions(
            example_id_dump_interval=1,
            example_id_dump_threshold=1024),
        example_joiner_options=dj_pb.ExampleJoinerOptions(
            example_joiner='STREAM_JOINER',
            min_matching_window=64,
            max_matching_window=256,
            data_block_dump_interval=30,
            data_block_dump_threshold=1000),
        batch_processor_options=dj_pb.BatchProcessorOptions(
            batch_size=512,
            max_flying_item=2048),
        data_block_builder_options=dj_pb.DataBlockBuilderOptions(
            data_block_builder='TF_RECORD_DATABLOCK_BUILDER'))
    self.total_index = 1 << 13