Exemplo n.º 1
0
 def _cheif_barriar(self, is_chief=False, sync_times=300):
     """Publish this worker's sync flag and, on the chief, wait for peers.

     Every caller writes the value "1" under
     ``<APPLICATION_ID>/<WORKER_RANK>`` in the kvstore. When ``is_chief``
     is true, the method then polls the ``APPLICATION_ID`` prefix until at
     least ``REPLICA_NUM`` entries exist or ``sync_times`` polls have been
     made, sleeping 6 seconds after each unsuccessful poll.

     Args:
         is_chief: whether this worker has to wait for the other workers.
         sync_times: maximum number of polls performed by the chief.

     Raises:
         KeyError: if ``APPLICATION_ID`` or ``WORKER_RANK`` is not set.
         ValueError: if this is the chief and ``REPLICA_NUM`` is set to
             something that is not an integer.
     """
     kvstore_type = os.environ.get('KVSTORE_TYPE', 'etcd')
     db_database, db_addr, db_username, db_password, _ = \
         get_kvstore_config(kvstore_type)
     kvstore_client = DBClient(db_database,
                               db_addr,
                               db_username,
                               db_password,
                               SYNC_PATH)
     application_id = os.environ['APPLICATION_ID']
     sync_path = '%s/%s' % (application_id, os.environ['WORKER_RANK'])
     logging.info('Creating a sync flag at %s', sync_path)
     kvstore_client.set_data(sync_path, "1")
     if not is_chief:
         return

     # Environment values are strings; without the int() conversion the
     # comparison against len(sync_list) below raises TypeError on Python 3.
     worker_replicas = int(os.environ.get('REPLICA_NUM', 0))
     for _ in range(sync_times):
         sync_list = kvstore_client.get_prefix_kvs(application_id)
         logging.info('Sync file pattern is: %s', sync_list)
         if len(sync_list) >= worker_replicas:
             break
         logging.info('Count of ready workers is %d',
                      len(sync_list))
         time.sleep(6)
     else:
         # The loop ran out of polls without seeing every worker. Keep the
         # original best-effort behaviour (no exception) but make it visible.
         logging.warning('Gave up waiting for %d workers after %d polls',
                         worker_replicas, sync_times)
Exemplo n.º 2
0
import etcd3
from fedlearner.common.mysql_client import DBClient
from fedlearner.data_join.common import get_kvstore_config

# One-shot migration: copy every key/value pair found under '/' in etcd
# into the MySQL-backed kvstore.
database, addr, username, password, base_dir = \
    get_kvstore_config('mysql')
MySQL_client = DBClient(database, addr, username, password, base_dir)
database, addr, username, password, base_dir = \
    get_kvstore_config('etcd')
(host, port) = addr.split(':')
# Raise the gRPC message limits to the largest signed 32-bit value.
options = [('grpc.max_send_message_length', 2**31-1),
    ('grpc.max_receive_message_length', 2**31-1)]
clnt = etcd3.client(host=host, port=port,
    grpc_options=options)
for (data, meta) in clnt.get_prefix('/', sort_order='ascend'):
    # The second tuple element carries the key in its `.key` attribute.
    # Always unwrap it so that set_data() receives the key itself, never
    # the metadata object.
    key = meta.key
    # bytes objects expose decode(), not decoder(); convert to str here.
    if not isinstance(key, str):
        key = key.decode()
    if not isinstance(data, str):
        data = data.decode()
    MySQL_client.set_data(key, data)


    # Resolve the kvstore connection settings for the type chosen on the
    # command line.
    db_database, db_addr, db_username, db_password, db_base_dir = \
        common.get_kvstore_config(args.kvstore_type)
    # The 'mock' kvstore type switches both the client and the master
    # options into mock mode.
    use_mock_etcd = (args.kvstore_type == 'mock')
    kvstore = DBClient(db_database, db_addr, db_username, db_password,
                       db_base_dir, use_mock_etcd)
    kvstore_key = common.portal_kvstore_base_dir(args.data_portal_name)
    # Only write an initial manifest when nothing is stored under this
    # portal's key yet; an existing entry is left untouched.
    if kvstore.get_data(kvstore_key) is None:
        portal_manifest = dp_pb.DataPortalManifest(
            name=args.data_portal_name,
            # Any value other than 'PSI' selects the Streaming portal type.
            data_portal_type=(dp_pb.DataPortalType.PSI if args.data_portal_type
                              == 'PSI' else dp_pb.DataPortalType.Streaming),
            output_partition_num=args.output_partition_num,
            input_file_wildcard=args.input_file_wildcard,
            input_base_dir=args.input_base_dir,
            output_base_dir=args.output_base_dir,
            raw_data_publish_dir=args.raw_data_publish_dir,
            processing_job_id=-1)
        # The manifest is stored in protobuf text format.
        kvstore.set_data(kvstore_key, text_format.\
            MessageToString(portal_manifest))

    options = dp_pb.DataPotraMasterlOptions(use_mock_etcd=use_mock_etcd,
                                            long_running=args.long_running)

    # Build the master service on the requested port and hand control to it.
    portal_master_srv = DataPortalMasterService(args.listen_port,
                                                args.data_portal_name,
                                                db_database, db_base_dir,
                                                db_addr, db_username,
                                                db_password, options)
    portal_master_srv.run()
    def test_api(self):
        """Drive a DataPortalMasterService through one complete portal job.

        The test stores a Streaming manifest with 4 output partitions in a
        mock kvstore, creates 100 ``*.done`` input files plus one ``.xx``
        file, starts the master on localhost:4061 and then, over gRPC:

        * checks the manifest returned by the master,
        * requests and finishes the four map tasks,
        * requests and finishes the four reduce tasks,
        * checks that surplus ranks get ``pending`` while work is
          outstanding and ``finished`` once every reduce task is done.

        The master is stopped and the input directory removed even when an
        assertion fails, so a failing run does not leave a server behind.
        """
        logging.getLogger().setLevel(logging.DEBUG)
        db_database = 'test_mysql'
        db_addr = 'localhost:2379'
        db_username = '******'
        db_password = '******'
        db_base_dir = 'dp_test'
        data_portal_name = 'test_data_source'
        kvstore = DBClient(db_database, db_addr, db_username,
                           db_password, db_base_dir, True)
        kvstore.delete_prefix(db_base_dir)
        portal_input_base_dir = './portal_upload_dir'
        portal_output_base_dir = './portal_output_dir'
        raw_data_publish_dir = 'raw_data_publish_dir'
        portal_manifest = dp_pb.DataPortalManifest(
            name=data_portal_name,
            data_portal_type=dp_pb.DataPortalType.Streaming,
            output_partition_num=4,
            input_file_wildcard="*.done",
            input_base_dir=portal_input_base_dir,
            output_base_dir=portal_output_base_dir,
            raw_data_publish_dir=raw_data_publish_dir,
            processing_job_id=-1,
            next_job_id=0
        )
        kvstore.set_data(common.portal_kvstore_base_dir(data_portal_name),
                         text_format.MessageToString(portal_manifest))

        # Start from an empty input directory.
        if gfile.Exists(portal_input_base_dir):
            gfile.DeleteRecursively(portal_input_base_dir)
        gfile.MakeDirs(portal_input_base_dir)
        all_fnames = ['{}.done'.format(i) for i in range(100)]
        all_fnames.append('{}.xx'.format(100))
        for fname in all_fnames:
            fpath = os.path.join(portal_input_base_dir, fname)
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')

        portal_master_addr = 'localhost:4061'
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False
        )
        data_portal_master = DataPortalMasterService(
            int(portal_master_addr.split(':')[1]),
            data_portal_name, db_database, db_base_dir,
            db_addr, db_username, db_password,
            portal_options
        )
        data_portal_master.start()
        try:
            channel = make_insecure_channel(portal_master_addr,
                                            ChannelType.INTERNAL)
            portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)

            def request_task(rank_id):
                """Ask the master for a task on behalf of `rank_id`."""
                return portal_master_cli.RequestNewTask(
                    dp_pb.NewTaskRequest(rank_id=rank_id))

            def finish_task(rank_id, partition_id, part_state):
                """Report one partition as finished in `part_state`."""
                portal_master_cli.FinishTask(dp_pb.FinishTaskRequest(
                    rank_id=rank_id, partition_id=partition_id,
                    part_state=part_state))

            # The served manifest must echo what was stored, except for the
            # job counters checked separately below.
            recv_manifest = portal_master_cli.GetDataPortalManifest(
                empty_pb2.Empty())
            for field in ('name', 'data_portal_type',
                          'output_partition_num', 'input_file_wildcard',
                          'input_base_dir', 'output_base_dir',
                          'raw_data_publish_dir'):
                self.assertEqual(getattr(recv_manifest, field),
                                 getattr(portal_manifest, field))
            self.assertEqual(recv_manifest.next_job_id, 1)
            self.assertEqual(recv_manifest.processing_job_id, 0)
            self._check_portal_job(kvstore, all_fnames, portal_manifest, 0)

            # ---- map phase ----
            mapped_partition = set()

            def check_map(task):
                """Assert `task` is a map task and record its partition."""
                self.assertTrue(task.HasField('map_task'))
                mapped_partition.add(task.map_task.partition_id)
                self._check_map_task(task.map_task, all_fnames,
                                     task.map_task.partition_id,
                                     portal_manifest)

            task_0 = request_task(0)
            # A second request from the same rank, made before the first
            # task is finished, is expected to return the same task.
            self.assertEqual(task_0, request_task(0))
            check_map(task_0)
            finish_task(0, task_0.map_task.partition_id,
                        dp_pb.PartState.kIdMap)

            task_1 = request_task(0)
            check_map(task_1)
            task_2 = request_task(1)
            check_map(task_2)
            task_3 = request_task(2)
            check_map(task_3)
            self.assertEqual(len(mapped_partition),
                             portal_manifest.output_partition_num)

            finish_task(0, task_1.map_task.partition_id,
                        dp_pb.PartState.kIdMap)

            # Every partition has been handed out and two map tasks are
            # still unfinished, so further ranks are told to wait.
            for rank_id in (4, 3):
                self.assertTrue(request_task(rank_id).HasField('pending'))

            finish_task(1, task_2.map_task.partition_id,
                        dp_pb.PartState.kIdMap)
            finish_task(2, task_3.map_task.partition_id,
                        dp_pb.PartState.kIdMap)

            # ---- reduce phase ----
            reduce_partition = set()

            def check_reduce(task):
                """Assert `task` is a reduce task and record its partition."""
                self.assertTrue(task.HasField('reduce_task'))
                reduce_partition.add(task.reduce_task.partition_id)
                self._check_reduce_task(task.reduce_task,
                                        task.reduce_task.partition_id,
                                        portal_manifest)

            task_4 = request_task(0)
            self.assertEqual(task_4, request_task(0))
            check_reduce(task_4)
            task_5 = request_task(1)
            check_reduce(task_5)
            task_6 = request_task(2)
            check_reduce(task_6)
            task_7 = request_task(3)
            check_reduce(task_7)
            # Same value as the literal 4 used by the manifest above.
            self.assertEqual(len(reduce_partition),
                             portal_manifest.output_partition_num)

            # All reduce tasks are handed out but none is finished yet.
            self.assertTrue(request_task(5).HasField('pending'))

            # Ranks 0..3 own task_4..task_7 respectively.
            for rank_id, task in enumerate((task_4, task_5, task_6, task_7)):
                finish_task(rank_id, task.reduce_task.partition_id,
                            dp_pb.PartState.kEventTimeReduce)

            self.assertTrue(request_task(5).HasField('finished'))
        finally:
            # Release the port and remove the fixtures on failure as well.
            data_portal_master.stop()
            gfile.DeleteRecursively(portal_input_base_dir)