def test_api(self):
    logging.getLogger().setLevel(logging.DEBUG)
    kvstore_type = 'etcd'
    db_base_dir = 'dp_test'
    os.environ['ETCD_BASE_DIR'] = db_base_dir
    data_portal_name = 'test_data_source'
    kvstore = DBClient(kvstore_type, True)
    kvstore.delete_prefix(db_base_dir)
    portal_input_base_dir = './portal_upload_dir'
    portal_output_base_dir = './portal_output_dir'
    raw_data_publish_dir = 'raw_data_publish_dir'

    # Seed the kvstore with a fresh manifest for a streaming portal.
    portal_manifest = dp_pb.DataPortalManifest(
        name=data_portal_name,
        data_portal_type=dp_pb.DataPortalType.Streaming,
        output_partition_num=4,
        input_file_wildcard="*.done",
        input_base_dir=portal_input_base_dir,
        output_base_dir=portal_output_base_dir,
        raw_data_publish_dir=raw_data_publish_dir,
        processing_job_id=-1,
        next_job_id=0
    )
    kvstore.set_data(common.portal_kvstore_base_dir(data_portal_name),
                     text_format.MessageToString(portal_manifest))

    # Populate the input dir: 100 files matching the wildcard, one
    # non-matching file, and a _SUCCESS tag for the subfolder.
    if gfile.Exists(portal_input_base_dir):
        gfile.DeleteRecursively(portal_input_base_dir)
    gfile.MakeDirs(portal_input_base_dir)
    all_fnames = ['1001/{}.done'.format(i) for i in range(100)]
    all_fnames.append('{}.xx'.format(100))
    all_fnames.append('1001/_SUCCESS')
    for fname in all_fnames:
        fpath = os.path.join(portal_input_base_dir, fname)
        gfile.MakeDirs(os.path.dirname(fpath))
        with gfile.Open(fpath, "w") as f:
            f.write('xxx')

    # Launch the master and connect a gRPC client to it. Note that
    # DataPotraMasterlOptions is the message name as spelled in the proto.
    portal_master_addr = 'localhost:4061'
    portal_options = dp_pb.DataPotraMasterlOptions(
        use_mock_etcd=True,
        long_running=False,
        check_success_tag=True,
    )
    data_portal_master = DataPortalMasterService(
        int(portal_master_addr.split(':')[1]),
        data_portal_name, kvstore_type, portal_options)
    data_portal_master.start()
    channel = make_insecure_channel(portal_master_addr, ChannelType.INTERNAL)
    portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)

    # The manifest served back should match what was stored, except that
    # starting the master launches the first job: next_job_id advances to 1
    # and processing_job_id becomes 0.
    recv_manifest = portal_master_cli.GetDataPortalManifest(
        empty_pb2.Empty())
    self.assertEqual(recv_manifest.name, portal_manifest.name)
    self.assertEqual(recv_manifest.data_portal_type,
                     portal_manifest.data_portal_type)
    self.assertEqual(recv_manifest.output_partition_num,
                     portal_manifest.output_partition_num)
    self.assertEqual(recv_manifest.input_file_wildcard,
                     portal_manifest.input_file_wildcard)
    self.assertEqual(recv_manifest.input_base_dir,
                     portal_manifest.input_base_dir)
    self.assertEqual(recv_manifest.output_base_dir,
                     portal_manifest.output_base_dir)
    self.assertEqual(recv_manifest.raw_data_publish_dir,
                     portal_manifest.raw_data_publish_dir)
    self.assertEqual(recv_manifest.next_job_id, 1)
    self.assertEqual(recv_manifest.processing_job_id, 0)
    self._check_portal_job(kvstore, all_fnames, portal_manifest, 0)

    # Map phase: every partition is handed out exactly once, and a
    # repeated request from the same rank returns the same task.
    mapped_partition = set()
    task_0 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    task_0_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    self.assertEqual(task_0, task_0_1)
    self.assertTrue(task_0.HasField('map_task'))
    mapped_partition.add(task_0.map_task.partition_id)
    self._check_map_task(task_0.map_task, all_fnames,
                         task_0.map_task.partition_id, portal_manifest)
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=0,
                                partition_id=task_0.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))
    task_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    self.assertTrue(task_1.HasField('map_task'))
    mapped_partition.add(task_1.map_task.partition_id)
    self._check_map_task(task_1.map_task, all_fnames,
                         task_1.map_task.partition_id, portal_manifest)
    task_2 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=1))
    self.assertTrue(task_2.HasField('map_task'))
    mapped_partition.add(task_2.map_task.partition_id)
    self._check_map_task(task_2.map_task, all_fnames,
                         task_2.map_task.partition_id, portal_manifest)
    task_3 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=2))
    self.assertTrue(task_3.HasField('map_task'))
    mapped_partition.add(task_3.map_task.partition_id)
    self._check_map_task(task_3.map_task, all_fnames,
                         task_3.map_task.partition_id, portal_manifest)
    self.assertEqual(len(mapped_partition),
                     portal_manifest.output_partition_num)

    # Until every map partition is finished, further requests stay pending.
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=0,
                                partition_id=task_1.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))
    pending_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=4))
    self.assertTrue(pending_1.HasField('pending'))
    pending_2 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=3))
    self.assertTrue(pending_2.HasField('pending'))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=1,
                                partition_id=task_2.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(rank_id=2,
                                partition_id=task_3.map_task.partition_id,
                                part_state=dp_pb.PartState.kIdMap))

    # Reduce phase: once all map tasks are done, the master hands out
    # reduce tasks, again one partition per rank with idempotent retry.
    reduce_partition = set()
    task_4 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    task_4_1 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=0))
    self.assertEqual(task_4, task_4_1)
    self.assertTrue(task_4.HasField('reduce_task'))
    reduce_partition.add(task_4.reduce_task.partition_id)
    self._check_reduce_task(task_4.reduce_task,
                            task_4.reduce_task.partition_id,
                            portal_manifest)
    task_5 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=1))
    self.assertTrue(task_5.HasField('reduce_task'))
    reduce_partition.add(task_5.reduce_task.partition_id)
    self._check_reduce_task(task_5.reduce_task,
                            task_5.reduce_task.partition_id,
                            portal_manifest)
    task_6 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=2))
    self.assertTrue(task_6.HasField('reduce_task'))
    reduce_partition.add(task_6.reduce_task.partition_id)
    self._check_reduce_task(task_6.reduce_task,
                            task_6.reduce_task.partition_id,
                            portal_manifest)
    task_7 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=3))
    self.assertTrue(task_7.HasField('reduce_task'))
    reduce_partition.add(task_7.reduce_task.partition_id)
    self.assertEqual(len(reduce_partition), 4)
    self._check_reduce_task(task_7.reduce_task,
                            task_7.reduce_task.partition_id,
                            portal_manifest)
    task_8 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=5))
    self.assertTrue(task_8.HasField('pending'))

    # Finish all reduce partitions; the job is then complete and any
    # further request is answered with 'finished'.
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=0,
            partition_id=task_4.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=1,
            partition_id=task_5.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=2,
            partition_id=task_6.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    portal_master_cli.FinishTask(
        dp_pb.FinishTaskRequest(
            rank_id=3,
            partition_id=task_7.reduce_task.partition_id,
            part_state=dp_pb.PartState.kEventTimeReduce))
    task_9 = portal_master_cli.RequestNewTask(
        dp_pb.NewTaskRequest(rank_id=5))
    self.assertTrue(task_9.HasField('finished'))

    data_portal_master.stop()
    gfile.DeleteRecursively(portal_input_base_dir)
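# The test above drives the master's task state machine by hand. Below is a
# minimal sketch of how a worker rank might consume the same API; the
# run_worker/handle_map_task/handle_reduce_task names, the polling loop, and
# the sleep interval are illustrative assumptions, not fedlearner's actual
# worker implementation. Only the RPCs and fields exercised by the test
# (RequestNewTask, FinishTask, map_task, reduce_task, pending, finished)
# are taken from the code above.
import time

def run_worker(portal_master_cli, rank_id,
               handle_map_task, handle_reduce_task):
    while True:
        task = portal_master_cli.RequestNewTask(
            dp_pb.NewTaskRequest(rank_id=rank_id))
        if task.HasField('finished'):
            break  # all partitions processed in both phases
        if task.HasField('pending'):
            time.sleep(1)  # no partition available yet; poll again
            continue
        if task.HasField('map_task'):
            handle_map_task(task.map_task)
            partition_id = task.map_task.partition_id
            part_state = dp_pb.PartState.kIdMap
        else:
            handle_reduce_task(task.reduce_task)
            partition_id = task.reduce_task.partition_id
            part_state = dp_pb.PartState.kEventTimeReduce
        # Report completion so the master can advance the partition's state.
        portal_master_cli.FinishTask(
            dp_pb.FinishTaskRequest(rank_id=rank_id,
                                    partition_id=partition_id,
                                    part_state=part_state))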
import etcd3

from fedlearner.common.db_client import DBClient
from fedlearner.common.db_client import get_kvstore_config

# Copy every key/value pair from the etcd kvstore into MySQL.
mysql_client = DBClient('mysql')
database, addr, username, password, base_dir = \
    get_kvstore_config('etcd')
host, port = addr.split(':')
options = [('grpc.max_send_message_length', 2**31 - 1),
           ('grpc.max_receive_message_length', 2**31 - 1)]
clnt = etcd3.client(host=host, port=port, grpc_options=options)
# etcd3's get_prefix yields (value, metadata) pairs; the key lives on the
# metadata object, and both key and value come back as bytes.
for value, meta in clnt.get_prefix('/', sort_order='ascend'):
    key = meta.key.decode() if isinstance(meta.key, bytes) else meta.key
    data = value.decode() if isinstance(value, bytes) else value
    mysql_client.set_data(key, data)
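# A hedged sanity check for the migration above: re-read every key that was
# copied and compare it with what MySQL now returns. It uses only the
# DBClient.get_data/set_data calls already seen in this document; the
# verify_migration name, the mismatch bookkeeping, and the assumption that
# get_data returns the stored string are illustrative, not a confirmed API
# contract.
def verify_migration(etcd_clnt, mysql_client):
    mismatched = []
    for value, meta in etcd_clnt.get_prefix('/', sort_order='ascend'):
        key = meta.key.decode() if isinstance(meta.key, bytes) else meta.key
        expected = value.decode() if isinstance(value, bytes) else value
        if mysql_client.get_data(key) != expected:
            mismatched.append(key)
    return mismatched  # empty list means every key round-tripped intact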
kvstore = DBClient(args.kvstore_type, use_mock_etcd)
kvstore_key = common.portal_kvstore_base_dir(args.data_portal_name)
portal_manifest = kvstore.get_data(kvstore_key)
data_portal_type = dp_pb.DataPortalType.PSI \
    if args.data_portal_type == 'PSI' else dp_pb.DataPortalType.Streaming
if portal_manifest is None:
    # First launch: build the manifest from the CLI arguments and store it.
    portal_manifest = dp_pb.DataPortalManifest(
        name=args.data_portal_name,
        data_portal_type=data_portal_type,
        output_partition_num=args.output_partition_num,
        input_file_wildcard=args.input_file_wildcard,
        input_base_dir=args.input_base_dir,
        output_base_dir=args.output_base_dir,
        raw_data_publish_dir=args.raw_data_publish_dir,
        processing_job_id=-1)
    kvstore.set_data(kvstore_key,
                     text_format.MessageToString(portal_manifest))
else:
    # Relaunch: validate that the CLI arguments are consistent with the
    # manifest already stored in the kvstore.
    passed = True
    portal_manifest = text_format.Parse(portal_manifest,
                                        dp_pb.DataPortalManifest(),
                                        allow_unknown_field=True)
    parameter_pairs = [
        (portal_manifest.data_portal_type, data_portal_type),
        (portal_manifest.output_partition_num, args.output_partition_num),
        (portal_manifest.input_file_wildcard, args.input_file_wildcard),
        (portal_manifest.input_base_dir, args.input_base_dir),
        (portal_manifest.output_base_dir, args.output_base_dir),
        (portal_manifest.raw_data_publish_dir, args.raw_data_publish_dir)
    ]
    for old, new in parameter_pairs:
        if old != new:
            logging.error("parameter mismatch: the stored manifest has %s "
                          "but the command line gives %s", old, new)
            passed = False
    assert passed, "data portal parameters are inconsistent with the " \
                   "existing manifest"
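# For reference, the argument names above map one-to-one onto CLI flags, so
# launching the portal master for a PSI portal would plausibly look like the
# example below. The script name is a placeholder and the flag spellings are
# inferred from the args.* attributes used above, not verified against the
# actual argparse definition:
#
#   python data_portal_master_service.py \
#       --data_portal_name=test_data_portal \
#       --data_portal_type=PSI \
#       --kvstore_type=etcd \
#       --output_partition_num=4 \
#       --input_file_wildcard='*.done' \
#       --input_base_dir=/data/portal_upload_dir \
#       --output_base_dir=/data/portal_output_dir \
#       --raw_data_publish_dir=raw_data_publish_dir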
class TestDataPortalJobManager(unittest.TestCase):
    def setUp(self) -> None:
        logging.getLogger().setLevel(logging.DEBUG)
        self._data_portal_name = 'test_data_portal_job_manager'
        self._kvstore = DBClient('etcd', True)
        self._portal_input_base_dir = './portal_input_dir'
        self._portal_output_base_dir = './portal_output_dir'
        self._raw_data_publish_dir = 'raw_data_publish_dir'
        if gfile.Exists(self._portal_input_base_dir):
            gfile.DeleteRecursively(self._portal_input_base_dir)
        gfile.MakeDirs(self._portal_input_base_dir)
        # Three subfolders of 100 files each: 1001 (.data, with _SUCCESS),
        # 1002 (.data, no _SUCCESS), 1003 (.csv, with _SUCCESS), plus one
        # top-level file that matches no wildcard used by the tests.
        self._data_fnames = ['1001/{}.data'.format(i) for i in range(100)]
        self._data_fnames_without_success = \
            ['1002/{}.data'.format(i) for i in range(100)]
        self._csv_fnames = ['1003/{}.csv'.format(i) for i in range(100)]
        self._unused_fnames = ['{}.xx'.format(100)]
        self._all_fnames = self._data_fnames + \
                           self._data_fnames_without_success + \
                           self._csv_fnames + self._unused_fnames
        all_fnames_with_success = ['1001/_SUCCESS'] + ['1003/_SUCCESS'] + \
                                  self._all_fnames
        for fname in all_fnames_with_success:
            fpath = os.path.join(self._portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')

    def tearDown(self) -> None:
        gfile.DeleteRecursively(self._portal_input_base_dir)

    def _list_input_dir(self, portal_options, file_wildcard,
                        target_fnames, max_files_per_job=8000):
        portal_manifest = dp_pb.DataPortalManifest(
            name=self._data_portal_name,
            data_portal_type=dp_pb.DataPortalType.Streaming,
            output_partition_num=4,
            input_file_wildcard=file_wildcard,
            input_base_dir=self._portal_input_base_dir,
            output_base_dir=self._portal_output_base_dir,
            raw_data_publish_dir=self._raw_data_publish_dir,
            processing_job_id=-1,
            next_job_id=0
        )
        self._kvstore.set_data(
            common.portal_kvstore_base_dir(self._data_portal_name),
            text_format.MessageToString(portal_manifest))
        with Timer("DataPortalJobManager initialization"):
            data_portal_job_manager = DataPortalJobManager(
                self._kvstore, self._data_portal_name,
                portal_options.long_running,
                portal_options.check_success_tag,
                portal_options.single_subfolder,
                portal_options.files_per_job_limit,
                max_files_per_job
            )
        # The job synced from the kvstore should contain exactly the
        # expected files, in sorted order.
        portal_job = data_portal_job_manager._sync_processing_job()
        target_fnames.sort()
        fpaths = [os.path.join(self._portal_input_base_dir, f)
                  for f in target_fnames]
        self.assertEqual(len(fpaths), len(portal_job.fpaths))
        for index, fpath in enumerate(fpaths):
            self.assertEqual(fpath, portal_job.fpaths[index])

    def test_list_input_dir(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=True,
            single_subfolder=False,
            files_per_job_limit=None
        )
        # With the success check on, only folder 1001 qualifies for *.data.
        self._list_input_dir(portal_options, "*.data", self._data_fnames)

    def test_list_input_dir_single_folder(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=True,
            files_per_job_limit=None,
        )
        # single_subfolder restricts one job to one subfolder (1001 here).
        self._list_input_dir(portal_options, "*.data", self._data_fnames)

    def test_list_input_dir_files_limit(self):
        # A limit smaller than one subfolder still yields at least one
        # whole subfolder.
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=1,
        )
        self._list_input_dir(portal_options, "*.data", self._data_fnames)
        # 150 fits folder 1001 (100 files) but not 1001 + 1002 (200 files).
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=150,
        )
        self._list_input_dir(portal_options, "*.data", self._data_fnames)
        # 200 fits both .data subfolders.
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=200,
        )
        self._list_input_dir(
            portal_options, "*.data",
            self._data_fnames + self._data_fnames_without_success)

    def test_list_input_dir_over_limit(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
        )
        self._list_input_dir(portal_options, "*.data", self._data_fnames,
                             max_files_per_job=100)
        self._list_input_dir(
            portal_options, "*.data",
            self._data_fnames + self._data_fnames_without_success,
            max_files_per_job=200)

    def test_list_input_dir_without_success_check(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=None
        )
        self._list_input_dir(
            portal_options, "*.data",
            self._data_fnames + self._data_fnames_without_success)

    def test_list_input_dir_without_wildcard(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=True,
            single_subfolder=False,
            files_per_job_limit=None
        )
        # Without a wildcard, every file in a _SUCCESS-tagged folder counts.
        self._list_input_dir(portal_options, None,
                             self._data_fnames + self._csv_fnames)

    def test_list_input_dir_without_wildcard_and_success_check(self):
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=False,
            single_subfolder=False,
            files_per_job_limit=None
        )
        self._list_input_dir(portal_options, None, self._all_fnames)
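# Standard unittest entry point, assuming this test module is run directly
# as a script.
if __name__ == '__main__':
    unittest.main()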