Пример #1
0
 def __init__(self, listen_port, peer_addr, master_addr, rank_id, etcd_name,
              etcd_base_dir, etcd_addrs, options):
     """Create the data-join worker service and its (not yet started) server.

     Args:
       listen_port: port the worker's gRPC server binds to on ``[::]``.
       peer_addr: address of the peer party's worker (REMOTE channel).
       master_addr: address of this party's data-join master (INTERNAL
         channel).
       rank_id: rank of this worker; stored and forwarded to the worker.
       etcd_name, etcd_base_dir, etcd_addrs: forwarded to ``EtcdClient``.
       options: worker options; ``options.use_mock_etcd`` is read here and
         the whole object is forwarded to ``DataJoinWorker``.
     """
     # Message-size caps applied to both outgoing channels; a fresh list is
     # handed to each channel.
     msg_size_caps = (('grpc.max_send_message_length', 2**31 - 1),
                      ('grpc.max_receive_message_length', 2**31 - 1))

     # Master stub and rank are assigned before _sync_data_source() runs;
     # presumably that helper reads them (not visible here) -- keep the order.
     self._master_client = dj_grpc.DataJoinMasterServiceStub(
         make_insecure_channel(master_addr, ChannelType.INTERNAL,
                               options=list(msg_size_caps)))
     self._rank_id = rank_id
     etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir,
                       options.use_mock_etcd)
     data_source = self._sync_data_source()
     self._data_source_name = data_source.data_source_meta.name
     self._listen_port = listen_port

     self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
     peer_stub = dj_grpc.DataJoinWorkerServiceStub(
         make_insecure_channel(peer_addr, ChannelType.REMOTE,
                               options=list(msg_size_caps)))
     self._data_join_worker = DataJoinWorker(
         peer_stub, self._master_client, rank_id, etcd, data_source, options)
     dj_grpc.add_DataJoinWorkerServiceServicer_to_server(
         self._data_join_worker, self._server)

     # Human-readable role label derived from the synced data source.
     if data_source.role == common_pb.FLRole.Leader:
         self._role_repr = "leader"
     else:
         self._role_repr = "follower"
     self._server.add_insecure_port('[::]:%d' % listen_port)
     self._server_started = False
Пример #2
0
 def __init__(self, listen_port, peer_addr, master_addr, rank_id, etcd_name,
              etcd_base_dir, etcd_addrs, options):
     """Create a role-specific data-join worker service (not yet started).

     The data source fetched from the master decides which servicer is
     built. A Leader gets a ``DataJoinLeader`` that talks to the peer's
     follower service. A Follower gets a ``DataJoinFollower`` that talks to
     the peer's leader service. The servicer is stored in ``self._diw`` and
     registered on a 10-thread gRPC server bound to ``[::]:listen_port``.

     Args:
       listen_port: local port for the gRPC server.
       peer_addr: peer party's worker address (REMOTE channel).
       master_addr: this party's data-join master (INTERNAL channel).
       rank_id: rank of this worker.
       etcd_name, etcd_base_dir, etcd_addrs: forwarded to ``EtcdClient``.
       options: forwarded unchanged to the leader/follower servicer.
     """
     master_client = dj_grpc.DataJoinMasterServiceStub(
         make_insecure_channel(master_addr, ChannelType.INTERNAL))
     etcd = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir)
     data_source = self.sync_data_source(master_client)
     self._data_source_name = data_source.data_source_meta.name
     self._listen_port = listen_port
     self._rank_id = rank_id
     self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
     peer_channel = make_insecure_channel(peer_addr, ChannelType.REMOTE)

     # Select role label, peer stub type, servicer type and registration
     # function. Each side talks to the service of the opposite role.
     if data_source.role == common_pb.FLRole.Leader:
         role_name = "leader"
         peer_stub_cls = dj_grpc.DataJoinFollowerServiceStub
         servicer_cls = data_join_leader.DataJoinLeader
         register = dj_grpc.add_DataJoinLeaderServiceServicer_to_server
     else:
         # The only other role accepted is Follower.
         assert data_source.role == common_pb.FLRole.Follower
         role_name = "follower"
         peer_stub_cls = dj_grpc.DataJoinLeaderServiceStub
         servicer_cls = data_join_follower.DataJoinFollower
         register = dj_grpc.add_DataJoinFollowerServiceServicer_to_server

     self._role_repr = role_name
     self._diw = servicer_cls(peer_stub_cls(peer_channel), master_client,
                              rank_id, etcd, data_source, options)
     register(self._diw, self._server)
     self._server.add_insecure_port('[::]:%d' % listen_port)
     self._server_started = False
Пример #3
0
 def __init__(self,
              options,
              master_addr,
              rank_id,
              db_database,
              db_base_dir,
              db_addr,
              db_username,
              db_password,
              use_mock_etcd=False):
     """Store worker settings and connect to the data-portal master.

     Args:
       options: worker options, stored unchanged as ``self._options``.
       master_addr: data-portal master address (INTERNAL channel with the
         send/receive message caps raised to 2**31 - 1).
       rank_id: rank of this worker.
       db_database, db_base_dir, db_addr, db_username, db_password:
         database connection settings; only stored here.
       use_mock_etcd: stored as ``self._use_mock_etcd``.
     """
     self._options = options
     self._rank_id = rank_id
     self._use_mock_etcd = use_mock_etcd
     # Database connection settings, kept for use by other methods.
     self._db_database = db_database
     self._db_base_dir = db_base_dir
     self._db_addr = db_addr
     self._db_username = db_username
     self._db_password = db_password

     size_caps = [('grpc.max_send_message_length', 2**31 - 1),
                  ('grpc.max_receive_message_length', 2**31 - 1)]
     self._master_client = dp_grpc.DataPortalMasterServiceStub(
         make_insecure_channel(master_addr, ChannelType.INTERNAL,
                               options=size_caps))
 def __init__(self, master_addr):
     """Connect to the data-join master at ``master_addr``.

     Only the master stub is created here. The data source and raw-data
     controller start out as ``None``, and the raw-data update-time dict
     starts out empty.
     """
     self._master_addr = master_addr
     self._master_cli = dj_grpc.DataJoinMasterServiceStub(
         make_insecure_channel(master_addr, ChannelType.INTERNAL))
     # Presumably populated later by other methods (not visible here).
     # NOTE(review): ``_raw_date_ctl`` looks like a typo for ``raw_data``;
     # kept as-is because other code may reference this exact name.
     self._data_source = None
     self._raw_date_ctl = None
     self._raw_data_updated_datetime = {}
Пример #5
0
 def __init__(self, addr):
     """Open a REMOTE channel to the RSA-PSI sign service at ``addr``.

     Also creates a lock and zeroes the failure/reference counters and the
     error flag. Their use is not visible here.
     """
     self._lock = threading.Lock()
     self._channel = make_insecure_channel(addr, ChannelType.REMOTE)
     self._stub = dj_grpc.RsaPsiSignServiceStub(self._channel)
     # Failure / reference bookkeeping, starting from a clean state.
     self._serial_fail_cnt, self._rpc_ref_cnt = 0, 0
     self._mark_error = False
Пример #6
0
    def __init__(self,
                 role,
                 listen_port,
                 remote_address,
                 app_id=None,
                 rank=0,
                 streaming_mode=True,
                 compression=grpc.Compression.NoCompression):
        """Set up both ends of the trainer bridge to the peer party.

        Builds the client side: a REMOTE channel plus a
        ``TrainerWorkerServiceStub`` to ``remote_address``. Builds the server
        side: a 10-thread gRPC server serving ``Bridge.TrainerWorkerServicer``
        on ``[::]:listen_port``. Nothing is started or connected here.

        Args:
          role: this party's role; part of ``self._identifier``.
          listen_port: port the local server binds to.
          remote_address: peer address for the outgoing channel.
          app_id: application id; ``None`` means ``'test_trainer'``.
          rank: worker rank; part of ``self._identifier``.
          streaming_mode: stored flag choosing queued (streaming) vs. unary
            transmission.
          compression: gRPC compression for both the channel and the server.
        """
        app_id = 'test_trainer' if app_id is None else app_id

        # Configuration.
        self._role = role
        self._listen_port = listen_port
        self._remote_address = remote_address
        self._app_id = app_id
        self._rank = rank
        self._streaming_mode = streaming_mode
        self._compression = compression

        self._prefetch_handlers = []
        self._data_block_handler_fn = None

        # Connection state. The timestamp suffix makes the identifier differ
        # between runs that share app_id/role/rank.
        self._connected = False
        self._terminated = False
        self._peer_terminated = False
        self._identifier = '%s-%s-%d-%d' % (
            app_id, role, rank, int(time.time()))
        self._peer_identifier = ''

        # Received-data / iteration bookkeeping.
        self._condition = threading.Condition()
        self._current_iter_id = None
        self._next_iter_id = 0
        self._received_data = {}

        # One option list (message caps of 2**31 - 1) is shared by the
        # outgoing channel and the local server.
        grpc_opts = [
            ('grpc.max_send_message_length', 2**31 - 1),
            ('grpc.max_receive_message_length', 2**31 - 1),
        ]
        self._grpc_options = grpc_opts

        # Client side.
        self._transmit_send_lock = threading.Lock()
        self._channel = make_insecure_channel(
            remote_address, ChannelType.REMOTE,
            options=grpc_opts, compression=compression)
        self._client = tws_grpc.TrainerWorkerServiceStub(self._channel)
        self._next_send_seq_num = 0
        self._transmit_queue = queue.Queue()
        self._client_daemon = None
        self._client_daemon_shutdown_fn = None

        # Server side.
        self._transmit_receive_lock = threading.Lock()
        self._next_receive_seq_num = 0
        executor = futures.ThreadPoolExecutor(max_workers=10)
        self._server = grpc.server(executor,
                                   options=grpc_opts,
                                   compression=compression)
        tws_grpc.add_TrainerWorkerServiceServicer_to_server(
            Bridge.TrainerWorkerServicer(self), self._server)
        self._server.add_insecure_port('[::]:%d' % listen_port)
Пример #7
0
    def _transmit(self, msg):
        """Stamp ``msg`` with the next send sequence number and send it.

        In streaming mode the message is only put on
        ``self._transmit_queue``; the queue is presumably drained by the
        client daemon, which is not visible here. Otherwise the message is
        sent with a unary ``Transmit`` RPC. That RPC is retried every second
        until the peer answers with ``STATUS_SUCCESS``. After each failure
        the channel is closed and rebuilt, a new client is created, and the
        remote heartbeat is checked.

        The send lock is held for the whole call, including retries. No
        other ``_transmit`` can take a sequence number or send while a retry
        is pending.

        Args:
          msg: message with a ``seq_num`` field, which is overwritten here.
        """
        assert self._connected, "Cannot transmit before connect"
        with self._transmit_send_lock:
            msg.seq_num = self._next_send_seq_num
            self._next_send_seq_num += 1

            if self._streaming_mode:
                self._transmit_queue.put(msg)
                return

            while True:
                try:
                    rsp = self._client.Transmit(msg)
                    # Explicit check rather than ``assert``: asserts are
                    # stripped under ``python -O``. A failed status would then
                    # be treated as delivered and never retried.
                    if rsp.status.code != common_pb.STATUS_SUCCESS:
                        raise RuntimeError(
                            "Transmit error with code %d." % rsp.status.code)
                    break
                except Exception as e:  # pylint: disable=broad-except
                    logging.warning("Bridge transmit failed: %s. " \
                                    "Retry in 1 second...", repr(e))
                    # Rebuild the connection from scratch before retrying.
                    self._channel.close()
                    time.sleep(1)
                    self._channel = make_insecure_channel(
                        self._remote_address,
                        ChannelType.REMOTE,
                        options=self._grpc_options,
                        compression=self._compression)
                    self._client = make_ready_client(self._channel)
                    self._check_remote_heartbeat()
Пример #8
0
 def __init__(self, id_batch_fetcher, max_flying_item,
              process_pool_executor, public_key,
              leader_signer_addr):
     """Follower-side RSA-PSI signer.

     Args:
       id_batch_fetcher, max_flying_item, process_pool_executor: forwarded
         unchanged to the base-class constructor.
       public_key: stored as ``self._public_key``.
       leader_signer_addr: address of the leader's RSA-PSI sign service,
         reached over a REMOTE channel.
     """
     super(FollowerPsiRsaSigner, self).__init__(
         id_batch_fetcher, max_flying_item, process_pool_executor)
     self._public_key = public_key
     leader_channel = make_insecure_channel(leader_signer_addr,
                                            ChannelType.REMOTE)
     self._leader_signer_stub = dj_grpc.RsaPsiSignServiceStub(leader_channel)
Пример #9
0
    def __init__(self, addr, role, task_id):
        """Client for the trainer master's data-block service.

        Args:
          addr: trainer master address (INTERNAL channel).
          role: role string; only ``'leader'`` copies ``task_id`` into the
            request as ``worker_rank``.
          task_id: this worker's task id.
        """
        self._addr, self._role, self._task_id = addr, role, task_id
        self._stub = tm_grpc.TrainerMasterServiceStub(
            make_insecure_channel(self._addr, ChannelType.INTERNAL))
        self._request = tm_pb.DataBlockRequest()
        # Only the leader sets worker_rank on the reusable request.
        if self._role == 'leader':
            self._request.worker_rank = self._task_id
Пример #10
0
 def __init__(self, address, worker_rank):
     """Connect to the trainer master at ``address`` for ``worker_rank``.

     Message sizes are uncapped (-1) and the reconnect backoff is limited
     to 1000 ms. The resulting stub and the rank go to the base class.
     """
     channel_args = (
         ('grpc.max_send_message_length', -1),
         ('grpc.max_receive_message_length', -1),
         ('grpc.max_reconnect_backoff_ms', 1000),
     )
     stub = tm_grpc.TrainerMasterServiceStub(
         make_insecure_channel(address, mode=ChannelType.INTERNAL,
                               options=channel_args))
     super(TrainerMasterClient, self).__init__(stub, worker_rank)
Пример #11
0
 def __init__(self, addr):
     """Open a REMOTE channel to the RSA-PSI sign service at ``addr``.

     The channel's send/receive message caps are raised to 2**31 - 1. Also
     creates a lock and zeroes the failure/reference counters and the error
     flag. Their use is not visible here.
     """
     self._lock = threading.Lock()
     size_caps = [('grpc.max_send_message_length', 2**31 - 1),
                  ('grpc.max_receive_message_length', 2**31 - 1)]
     self._channel = make_insecure_channel(addr, ChannelType.REMOTE,
                                           options=size_caps)
     self._stub = dj_grpc.RsaPsiSignServiceStub(self._channel)
     # Failure / reference bookkeeping, starting from a clean state.
     self._serial_fail_cnt, self._rpc_ref_cnt = 0, 0
     self._mark_error = False
Пример #12
0
 def __init__(self, options, master_addr, rank_id, etcd_name,
              etcd_base_dir, etcd_addrs, use_mock_etcd=False):
     """Store worker/etcd settings and connect to the data-portal master.

     Args:
       options: worker options, stored as ``self._options``.
       master_addr: data-portal master address (INTERNAL channel).
       rank_id: rank of this worker.
       etcd_name, etcd_base_dir, etcd_addrs: etcd settings; only stored.
       use_mock_etcd: stored as ``self._use_mock_etcd``.
     """
     self._options = options
     self._rank_id = rank_id
     # etcd settings, kept for use by other methods.
     self._etcd_name = etcd_name
     self._etcd_base_dir = etcd_base_dir
     self._etcd_addrs = etcd_addrs
     self._use_mock_etcd = use_mock_etcd
     self._master_client = dp_grpc.DataPortalMasterServiceStub(
         make_insecure_channel(master_addr, ChannelType.INTERNAL))
Пример #13
0
 def __init__(self, listen_port, peer_addr, data_source_name, etcd_name,
              etcd_base_dir, etcd_addrs):
     """Create the data-join master service (server not yet started).

     Args:
       listen_port: port the gRPC server binds to on ``[::]``.
       peer_addr: peer party's master address (REMOTE channel).
       data_source_name: name of the data source this master manages.
       etcd_name, etcd_base_dir, etcd_addrs: forwarded to ``DataJoinMaster``.
     """
     peer_master = dj_grpc.DataJoinMasterServiceStub(
         make_insecure_channel(peer_addr, ChannelType.REMOTE))
     self._data_source_name = data_source_name
     self._listen_port = listen_port
     self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
     # Argument order differs from this signature: etcd_addrs precedes
     # etcd_base_dir.
     self._dim = DataJoinMaster(peer_master, data_source_name, etcd_name,
                                etcd_addrs, etcd_base_dir)
     dj_grpc.add_DataJoinMasterServiceServicer_to_server(self._dim,
                                                          self._server)
     self._server.add_insecure_port('[::]:%d' % listen_port)
     self._server_started = False
Пример #14
0
    def _launch_masters(self):
        """Start leader and follower masters on localhost and wait for both.

        The leader listens on port 4061 and the follower on port 4062. Each
        uses the other as its peer, and both run with mock etcd enabled. The
        method then polls ``GetDataSourceStatus`` on both masters every 2
        seconds. Each poll asserts that the reported roles are Leader and
        Follower, and the loop stops once both states are ``Processing``.
        There is no timeout: the loop only ends on success or a failed
        assertion.
        """
        self._master_addr_l = 'localhost:4061'
        self._master_addr_f = 'localhost:4062'
        port_l = int(self._master_addr_l.split(':')[1])
        port_f = int(self._master_addr_f.split(':')[1])
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        self._master_l = data_join_master.DataJoinMasterService(
            port_l, self._master_addr_f,
            self._data_source_name, self._db_database, self._db_base_dir_l,
            self._db_addr, self._db_username_l, self._db_password_l,
            master_options)
        self._master_f = data_join_master.DataJoinMasterService(
            port_f, self._master_addr_l,
            self._data_source_name, self._db_database, self._db_base_dir_f,
            self._db_addr, self._db_username_f, self._db_password_f,
            master_options)
        self._master_f.start()
        self._master_l.start()

        self._master_client_l = dj_grpc.DataJoinMasterServiceStub(
            make_insecure_channel(self._master_addr_l, ChannelType.INTERNAL))
        self._master_client_f = dj_grpc.DataJoinMasterServiceStub(
            make_insecure_channel(self._master_addr_f, ChannelType.INTERNAL))

        processing = common_pb.DataSourceState.Processing
        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=self._data_source_f.data_source_meta)
            status_l = self._master_client_l.GetDataSourceStatus(req_l)
            status_f = self._master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(status_l.role, common_pb.FLRole.Leader)
            self.assertEqual(status_f.role, common_pb.FLRole.Follower)
            if status_l.state == processing and status_f.state == processing:
                break
            time.sleep(2)
        logging.info("masters turn into Processing state")
Пример #15
0
    def __init__(self,
                 role,
                 listen_port,
                 remote_address,
                 app_id='test_trainer',
                 rank=0,
                 streaming_mode=True):
        """Set up both ends of the trainer bridge to the peer party.

        Builds the client side: a REMOTE channel plus a
        ``TrainerWorkerServiceStub`` to ``remote_address``. Builds the server
        side: a 10-thread gRPC server serving ``Bridge.TrainerWorkerServicer``
        on ``[::]:listen_port``. Nothing is started or connected here.

        Args:
          role: this party's role.
          listen_port: port the local server binds to.
          remote_address: peer address for the outgoing channel.
          app_id: application id.
          rank: worker rank.
          streaming_mode: stored streaming/unary transmission flag.
        """
        # Configuration.
        self._role = role
        self._listen_port = listen_port
        self._remote_address = remote_address
        self._app_id = app_id
        self._rank = rank
        self._streaming_mode = streaming_mode

        self._prefetch_handlers = []
        self._data_block_handler_fn = None
        self._connected = False

        # Received-data / iteration bookkeeping.
        self._condition = threading.Condition()
        self._current_iter_id = None
        self._next_iter_id = 0
        self._received_data = {}

        # One option list (message caps of 2**31 - 1) is shared by the
        # outgoing channel and the local server.
        grpc_opts = [
            ('grpc.max_send_message_length', 2**31 - 1),
            ('grpc.max_receive_message_length', 2**31 - 1),
        ]
        self._grpc_options = grpc_opts

        # Client side. The channel itself is only referenced via the stub.
        peer_channel = make_insecure_channel(
            remote_address, ChannelType.REMOTE, options=grpc_opts)
        self._transmit_send_lock = threading.Lock()
        self._client = tws_grpc.TrainerWorkerServiceStub(peer_channel)
        self._next_send_seq_num = 0
        self._transmit_queue = queue.Queue()
        self._client_daemon = None
        self._client_daemon_shutdown_fn = None
        self._keepalive_daemon = None
        self._keepalive_daemon_shutdown_fn = None

        # Server side.
        self._transmit_receive_lock = threading.Lock()
        self._next_receive_seq_num = 0
        executor = futures.ThreadPoolExecutor(max_workers=10)
        self._server = grpc.server(executor, options=grpc_opts)
        tws_grpc.add_TrainerWorkerServiceServicer_to_server(
            Bridge.TrainerWorkerServicer(self), self._server)
        self._server.add_insecure_port('[::]:%d' % listen_port)
Пример #16
0
 def _rpc_with_retry(self, sender, err_log):
     """Call ``sender()`` until it succeeds and return its result.

     Each attempt runs while holding ``self._client_lock``. On any exception,
     the error is logged with ``err_log`` as prefix and the current channel
     is closed. After a 1 s pause a new REMOTE channel to
     ``self._remote_address`` is built with the same options and compression.
     A new client is then made with ``make_ready_client`` and
     ``self._check_remote_heartbeat`` is called on it before the next try.

     The lock stays held during the sleep and the reconnect. There is no
     retry limit, and an exception raised while reconnecting is not caught
     here. ``sender`` presumably reads ``self._client``, which is not visible
     here.
     """
     while True:
         with self._client_lock:
             try:
                 return sender()
             except Exception as err:  # pylint: disable=broad-except
                 logging.warning("%s: %s. Retry in 1s...", err_log, repr(err))
                 self._channel.close()
                 time.sleep(1)
                 new_channel = make_insecure_channel(
                     self._remote_address,
                     ChannelType.REMOTE,
                     options=self._grpc_options,
                     compression=self._compression)
                 self._channel = new_channel
                 self._client = make_ready_client(new_channel)
                 self._check_remote_heartbeat(self._client)
Пример #17
0
 def __init__(self,
              options,
              master_addr,
              rank_id,
              kvstore_type,
              use_mock_etcd=False):
     """Store worker settings and connect to the data-portal master.

     Args:
       options: worker options, stored as ``self._options``.
       master_addr: data-portal master address (INTERNAL channel with the
         send/receive message caps raised to 2**31 - 1).
       rank_id: rank of this worker.
       kvstore_type: key-value store type; only stored here.
       use_mock_etcd: stored as ``self._use_mock_etcd``.
     """
     self._options = options
     self._rank_id = rank_id
     self._kvstore_type = kvstore_type
     self._use_mock_etcd = use_mock_etcd
     size_caps = [('grpc.max_send_message_length', 2**31 - 1),
                  ('grpc.max_receive_message_length', 2**31 - 1)]
     self._master_client = dp_grpc.DataPortalMasterServiceStub(
         make_insecure_channel(master_addr, ChannelType.INTERNAL,
                               options=size_caps))
Пример #18
0
 def __init__(self, listen_port, peer_addr, data_source_name, kvstore_type,
              options):
     """Create the data-join master service (server not yet started).

     Args:
       listen_port: port the gRPC server binds to on ``[::]``.
       peer_addr: peer party's master address (REMOTE channel with the
         send/receive message caps raised to 2**31 - 1).
       data_source_name: name of the data source this master manages.
       kvstore_type: forwarded to ``DataJoinMaster``.
       options: forwarded to ``DataJoinMaster``.
     """
     size_caps = [('grpc.max_send_message_length', 2**31 - 1),
                  ('grpc.max_receive_message_length', 2**31 - 1)]
     peer_master = dj_grpc.DataJoinMasterServiceStub(
         make_insecure_channel(peer_addr, ChannelType.REMOTE,
                               options=size_caps))
     self._data_source_name = data_source_name
     self._listen_port = listen_port
     self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
     self._data_join_master = DataJoinMaster(peer_master, data_source_name,
                                             kvstore_type, options)
     dj_grpc.add_DataJoinMasterServiceServicer_to_server(
         self._data_join_master, self._server)
     self._server.add_insecure_port('[::]:%d' % listen_port)
     self._server_started = False
Пример #19
0
    def test_api(self):
        logging.getLogger().setLevel(logging.DEBUG)
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.output_base_dir = "./ds_output_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.output_base_dir = "./ds_output_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs, options)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            else:
                time.sleep(2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=-1,
            join_example=empty_pb2.Empty())

        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.join_example_rep.rank_id, 0)
        self.assertEqual(rdrsp.manifest.join_example_rep.state,
                         dj_pb.JoinExampleState.Joining)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=-1,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)
        #check idempotent
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.rank_id, 1)
        self.assertEqual(rdrsp.manifest.sync_example_id_rep.state,
                         dj_pb.SyncExampleIdState.Syncing)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq1 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=1,
            partition_id=0,
            sync_example_id=empty_pb2.Empty())

        try:
            rsp = client_l.FinishJoinPartition(rdreq1)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq2 = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            join_example=empty_pb2.Empty())
        try:
            rsp = client_l.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 1)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 3)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=5)),
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        rsp = client_l.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertFalse(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)
        rsp = client_l.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 5)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 0)
        self.assertEqual(rsp.timestamp.nanos, 0)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 1)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 1)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='a',
                        timestamp=timestamp_pb2.Timestamp(seconds=1)),
                    dj_pb.RawDataMeta(
                        file_path='b',
                        timestamp=timestamp_pb2.Timestamp(seconds=2))
                ]))
        rsp = client_f.AddRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertFalse(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)
        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
        )
        rsp = client_f.GetRawDataLatestTimeStamp(rdreq)
        self.assertEqual(rsp.status.code, 0)
        self.assertTrue(rsp.HasField('timestamp'))
        self.assertEqual(rsp.timestamp.seconds, 2)
        self.assertEqual(rsp.timestamp.nanos, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_l = client_l.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_l is not None)
        self.assertTrue(manifest_l.finished)
        self.assertEqual(manifest_l.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=False,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=4))
                ]))
        try:
            rsp = client_l.AddRawData(rdreq)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        try:
            rsp = client_f.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq1)
        self.assertEqual(rsp.code, 0)

        try:
            rsp = client_f.FinishJoinPartition(rdreq2)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            finish_raw_data=empty_pb2.Empty())
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishRawData(rdreq)
        self.assertEqual(rsp.code, 0)

        manifest_f = client_f.QueryRawDataManifest(rdreq)
        self.assertTrue(manifest_f is not None)
        self.assertTrue(manifest_f.finished)
        self.assertEqual(manifest_f.next_process_index, 2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            partition_id=0,
            added_raw_data_metas=dj_pb.AddedRawDataMetas(
                dedup=True,
                raw_data_metas=[
                    dj_pb.RawDataMeta(
                        file_path='x',
                        timestamp=timestamp_pb2.Timestamp(seconds=3))
                ]))
        try:
            rsp = client_f.AddRawData(rdreq)
        except Exception as e:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_f.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)
        #check idempotent
        rsp = client_l.FinishJoinPartition(rdreq2)
        self.assertEqual(rsp.code, 0)

        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = client_l.GetDataSourceStatus(req_l)
            dss_f = client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Finished and \
                    dss_f.state == common_pb.DataSourceState.Finished:
                break
            else:
                time.sleep(2)

        master_l.stop()
        master_f.stop()
Пример #20
0
    def _client_daemon_fn(self):
        """Run the client-side streaming loop that forwards queued messages.

        Opens a REMOTE channel to ``self._remote_address`` and repeatedly
        calls ``StreamTransmit`` with a request iterator that first replays
        every message still awaiting confirmation (``resend_list``) and then
        streams new items taken from ``self._transmit_queue``. Each response
        confirms all messages whose ``seq_num`` is below
        ``response.next_seq_num``; those are dropped from ``resend_list``.
        When the stream ends or breaks, the channel is rebuilt and streaming
        restarts, until ``stop_event`` is set by the shutdown function
        published as ``self._client_daemon_shutdown_fn``.
        """
        stop_event = threading.Event()
        generator = None
        channel = make_insecure_channel(self._remote_address,
                                        ChannelType.REMOTE,
                                        options=self._grpc_options,
                                        compression=self._compression)
        client = make_ready_client(channel, stop_event)

        # Guards resend_list, which is touched by the request iterator, the
        # response loop and shutdown_fn.
        lock = threading.Lock()
        # Messages already sent but not yet confirmed, in the order sent.
        resend_list = collections.deque()

        def shutdown_fn():
            """Wait until every sent message is confirmed, then stop.

            Returns ``generator.result()`` of the latest streaming call, or
            None if no streaming call was ever started.
            """
            while True:
                with lock:
                    pending = len(resend_list)
                if not pending:
                    logging.debug('Resend queue is empty and we can shut '
                                  'down client daemon safely.')
                    break
                logging.debug(
                    "Waiting for resend queue's being cleaned. "
                    "Resend queue size: %d", pending)
                # Sleep WITHOUT holding the lock so the response loop can
                # acquire it and drain confirmed messages.
                time.sleep(1)

            stop_event.set()
            if generator is None:
                return None
            generator.cancel()
            return generator.result()

        self._client_daemon_shutdown_fn = shutdown_fn

        def iterator():
            """Yield unconfirmed messages first, then newly queued ones.

            Each new item is recorded in resend_list before it is yielded so
            that it can be replayed if the stream breaks before the peer
            confirms it.
            """
            with lock:
                resend_msgs = list(resend_list)
            for item in resend_msgs:
                logging.warning("Streaming resend message seq_num=%d",
                                item.seq_num)
                yield item
            while True:
                item = self._transmit_queue.get()
                with lock:
                    resend_list.append(item)
                logging.debug("Streaming send message seq_num=%d",
                              item.seq_num)
                yield item

        while not stop_event.is_set():
            try:
                generator = client.StreamTransmit(iterator())
                for response in generator:
                    code = response.status.code
                    if code == common_pb.STATUS_SUCCESS:
                        logging.debug(
                            "Message with seq_num=%d is "
                            "confirmed", response.next_seq_num - 1)
                    elif code == common_pb.STATUS_MESSAGE_DUPLICATED:
                        logging.debug(
                            "Resent Message with seq_num=%d is "
                            "confirmed", response.next_seq_num - 1)
                    elif code == common_pb.STATUS_MESSAGE_MISSING:
                        raise RuntimeError("Message with seq_num=%d is "
                                           "missing!" %
                                           (response.next_seq_num - 1))
                    else:
                        raise RuntimeError("Transmit failed with %d" % code)
                    with lock:
                        # Everything below next_seq_num has been confirmed.
                        while resend_list and \
                                resend_list[0].seq_num < response.next_seq_num:
                            resend_list.popleft()
                        min_seq_num_to_resend = resend_list[0].seq_num \
                            if resend_list else "NaN"
                        logging.debug(
                            "Resend queue size: %d, starting from seq_num=%s",
                            len(resend_list), min_seq_num_to_resend)
            except Exception as e:  # pylint: disable=broad-except
                if not stop_event.is_set():
                    logging.warning("Bridge streaming broken: %s.", repr(e))
            finally:
                # generator is still None if the very first StreamTransmit
                # call raised before returning.
                if generator is not None:
                    generator.cancel()
                channel.close()

            if stop_event.is_set():
                # Shutting down: do not build a channel nobody will use.
                break

            with lock:
                resend_size = len(resend_list)
                min_seq_num_to_resend = resend_list[0].seq_num \
                    if resend_list else "NaN"
            logging.warning(
                "Restarting streaming: resend queue size: %d, "
                "starting from seq_num=%s", resend_size,
                min_seq_num_to_resend)
            channel = make_insecure_channel(self._remote_address,
                                            ChannelType.REMOTE,
                                            options=self._grpc_options,
                                            compression=self._compression)
            client = make_ready_client(channel, stop_event)
            self._check_remote_heartbeat()
        else:
            # Reached only when the loop condition itself turned false (stop
            # requested before the first attempt or right after a
            # reconnect): that channel was never used, so release it here.
            channel.close()
Пример #21
0
    def test_api(self):
        """Drive a DataPortalMasterService through a full map/reduce cycle.

        Writes a Streaming DataPortalManifest into the kvstore and lays out
        100 ``1001/<i>.done`` input files, one non-matching ``100.xx`` file
        and a ``1001/_SUCCESS`` tag. It then talks to the master over gRPC
        and checks the following:

        * the manifest round-trips, with ``next_job_id`` advanced to 1;
        * four map tasks are handed out, and a repeated request from the
          same rank returns the same task;
        * ranks without a task get ``pending`` while map tasks remain
          unfinished;
        * four reduce tasks follow once every map task is finished;
        * the master reports ``finished`` after all reduce tasks are
          finished.

        The master, the gRPC channel and the input directory are cleaned up
        even when an assertion fails.
        """
        logging.getLogger().setLevel(logging.DEBUG)
        kvstore_type = 'etcd'
        db_base_dir = 'dp_test'
        os.environ['ETCD_BASE_DIR'] = db_base_dir
        data_portal_name = 'test_data_source'
        kvstore = DBClient(kvstore_type, True)
        kvstore.delete_prefix(db_base_dir)
        portal_input_base_dir = './portal_upload_dir'
        portal_output_base_dir = './portal_output_dir'
        raw_data_publish_dir = 'raw_data_publish_dir'
        portal_manifest = dp_pb.DataPortalManifest(
            name=data_portal_name,
            data_portal_type=dp_pb.DataPortalType.Streaming,
            output_partition_num=4,
            input_file_wildcard="*.done",
            input_base_dir=portal_input_base_dir,
            output_base_dir=portal_output_base_dir,
            raw_data_publish_dir=raw_data_publish_dir,
            processing_job_id=-1,
            next_job_id=0)
        kvstore.set_data(common.portal_kvstore_base_dir(data_portal_name),
                         text_format.MessageToString(portal_manifest))
        if gfile.Exists(portal_input_base_dir):
            gfile.DeleteRecursively(portal_input_base_dir)
        gfile.MakeDirs(portal_input_base_dir)
        # 100 files matching input_file_wildcard, one that does not match,
        # and a _SUCCESS tag (portal_options below sets check_success_tag).
        all_fnames = ['1001/{}.done'.format(i) for i in range(100)]
        all_fnames.append('{}.xx'.format(100))
        all_fnames.append('1001/_SUCCESS')
        for fname in all_fnames:
            fpath = os.path.join(portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')
        portal_master_addr = 'localhost:4061'
        portal_options = dp_pb.DataPotraMasterlOptions(
            use_mock_etcd=True,
            long_running=False,
            check_success_tag=True,
        )
        data_portal_master = DataPortalMasterService(
            int(portal_master_addr.split(':')[1]), data_portal_name,
            kvstore_type, portal_options)
        data_portal_master.start()

        channel = make_insecure_channel(portal_master_addr,
                                        ChannelType.INTERNAL)
        # try/finally so a failing assertion still stops the master (freeing
        # the port for later tests) and removes the generated input files.
        try:
            portal_master_cli = dp_grpc.DataPortalMasterServiceStub(channel)
            mapped_partition = set()
            reduce_partition = set()

            def request_task(rank_id):
                """Send a RequestNewTask RPC on behalf of rank_id."""
                return portal_master_cli.RequestNewTask(
                    dp_pb.NewTaskRequest(rank_id=rank_id))

            def finish_task(rank_id, partition_id, part_state):
                """Send a FinishTask RPC for (rank_id, partition_id)."""
                portal_master_cli.FinishTask(
                    dp_pb.FinishTaskRequest(rank_id=rank_id,
                                            partition_id=partition_id,
                                            part_state=part_state))

            def assert_map_task(task):
                """Assert task is a map task, record and check its partition."""
                self.assertTrue(task.HasField('map_task'))
                mapped_partition.add(task.map_task.partition_id)
                self._check_map_task(task.map_task, all_fnames,
                                     task.map_task.partition_id,
                                     portal_manifest)

            def assert_reduce_task(task):
                """Assert task is a reduce task, record and check its partition."""
                self.assertTrue(task.HasField('reduce_task'))
                reduce_partition.add(task.reduce_task.partition_id)
                self._check_reduce_task(task.reduce_task,
                                        task.reduce_task.partition_id,
                                        portal_manifest)

            recv_manifest = portal_master_cli.GetDataPortalManifest(
                empty_pb2.Empty())
            for field in ('name', 'data_portal_type', 'output_partition_num',
                          'input_file_wildcard', 'input_base_dir',
                          'output_base_dir', 'raw_data_publish_dir'):
                self.assertEqual(getattr(recv_manifest, field),
                                 getattr(portal_manifest, field), field)
            # Expect job 0 to be processing and next_job_id advanced to 1.
            self.assertEqual(recv_manifest.next_job_id, 1)
            self.assertEqual(recv_manifest.processing_job_id, 0)
            self._check_portal_job(kvstore, all_fnames, portal_manifest, 0)

            # ---- map phase ----
            task_0 = request_task(0)
            # Re-requesting before finishing must return the same task.
            task_0_1 = request_task(0)
            self.assertEqual(task_0, task_0_1)
            assert_map_task(task_0)
            finish_task(0, task_0.map_task.partition_id,
                        dp_pb.PartState.kIdMap)

            task_1 = request_task(0)
            assert_map_task(task_1)
            task_2 = request_task(1)
            assert_map_task(task_2)
            task_3 = request_task(2)
            assert_map_task(task_3)

            self.assertEqual(len(mapped_partition),
                             portal_manifest.output_partition_num)

            finish_task(0, task_1.map_task.partition_id,
                        dp_pb.PartState.kIdMap)

            # Map tasks 2 and 3 are still outstanding: idle ranks must wait.
            pending_1 = request_task(4)
            self.assertTrue(pending_1.HasField('pending'))
            pending_2 = request_task(3)
            self.assertTrue(pending_2.HasField('pending'))

            finish_task(1, task_2.map_task.partition_id,
                        dp_pb.PartState.kIdMap)
            finish_task(2, task_3.map_task.partition_id,
                        dp_pb.PartState.kIdMap)

            # ---- reduce phase ----
            task_4 = request_task(0)
            task_4_1 = request_task(0)
            self.assertEqual(task_4, task_4_1)
            assert_reduce_task(task_4)
            task_5 = request_task(1)
            assert_reduce_task(task_5)
            task_6 = request_task(2)
            assert_reduce_task(task_6)
            task_7 = request_task(3)
            assert_reduce_task(task_7)
            self.assertEqual(len(reduce_partition), 4)

            # Every reduce task is handed out, so another rank must wait.
            task_8 = request_task(5)
            self.assertTrue(task_8.HasField('pending'))

            for rank_id, task in enumerate((task_4, task_5, task_6, task_7)):
                finish_task(rank_id, task.reduce_task.partition_id,
                            dp_pb.PartState.kEventTimeReduce)

            task_9 = request_task(5)
            self.assertTrue(task_9.HasField('finished'))
        finally:
            channel.close()
            data_portal_master.stop()
            gfile.DeleteRecursively(portal_input_base_dir)
Пример #22
0
    def __init__(self,
                 listen_address,
                 remote_address,
                 token=None,
                 max_workers=16,
                 compression=grpc.Compression.Gzip,
                 heartbeat_timeout=120,
                 retry_interval=2,
                 stats_client=None):
        """Build (without starting) both ends of a bidirectional channel.

        Args:
            listen_address: address the local gRPC server binds to via
                ``add_insecure_port``.
            remote_address: address of the peer that the outgoing channel
                dials.
            token: optional token; stored as ``""`` when falsy.
            max_workers: thread-pool size of the local gRPC server.
            compression: gRPC compression used by both the outgoing channel
                and the local server.
            heartbeat_timeout: must be positive (presumably seconds);
                the heartbeat interval is one sixth of it.
            retry_interval: must be positive; also handed to the client
                interceptor.
            stats_client: optional stats sink; ``stats.NoneClient()`` is
                used when it is not given.

        Raises:
            ValueError: if ``heartbeat_timeout`` or ``retry_interval`` is not
                positive.
        """
        # Validate arguments up front (heartbeat first, as callers may rely
        # on which error is reported when both are bad).
        if heartbeat_timeout <= 0:
            raise ValueError("[Channel] heartbeat_timeout must be positive")
        if retry_interval <= 0:
            raise ValueError("[Channel] retry_interval must be positive")

        # Identity of this endpoint; the peer's is learned later.
        self._identifier = uuid.uuid4().hex[:16]
        self._peer_identifier = ""
        self._token = token or ""

        # One re-entrant lock, shared by the condition variable.
        self._lock = threading.RLock()
        self._condition = threading.Condition(self._lock)

        # Heartbeat bookkeeping.
        self._heartbeat_timeout = heartbeat_timeout
        self._heartbeat_interval = heartbeat_timeout / 6
        self._next_heartbeat_at = 0
        self._heartbeat_timeout_at = 0
        self._peer_heartbeat_timeout_at = 0

        # Connection lifecycle timestamps, all starting at 0.
        self._connected_at = self._closed_at = 0
        self._peer_connected_at = self._peer_closed_at = 0

        # Retry bookkeeping.
        self._retry_interval = retry_interval
        self._next_retry_at = 0

        self._ready_event = threading.Event()
        self._closed_event = threading.Event()
        self._error_event = threading.Event()

        # Channel state machine.
        self._state = Channel.State.IDLE
        self._state_thread = None
        self._error = None
        self._event_callbacks = {}

        # Message-size limits shared by the client channel and the server.
        unlimited_message_size = (
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        )

        # Outgoing side: a raw channel to the peer, then wrapped with the
        # client interceptor.
        self._remote_address = remote_address
        self._channel = make_insecure_channel(
            remote_address,
            mode=ChannelType.REMOTE,
            options=unlimited_message_size +
            (('grpc.max_reconnect_backoff_ms', 1000), ),
            compression=compression)
        self._channel_interceptor = ClientInterceptor(
            identifier=self._identifier,
            retry_interval=retry_interval,
            wait_fn=self.wait_for_ready,
            check_fn=self._channel_response_check_fn,
            stats_client=stats_client)
        self._channel = grpc.intercept_channel(self._channel,
                                               self._channel_interceptor)

        # Incoming side: the local gRPC server (created but not started).
        self._listen_address = listen_address
        self._server_thread_pool = futures.ThreadPoolExecutor(
            max_workers=max_workers)
        self._server_interceptor = ServerInterceptor()
        self._server = grpc.server(
            self._server_thread_pool,
            options=unlimited_message_size,
            interceptors=(self._server_interceptor, ),
            compression=compression)
        self._server.add_insecure_port(listen_address)

        # Wire the Channel service: stub for outgoing, servicer for incoming.
        self._channel_call = channel_pb2_grpc.ChannelStub(self._channel)
        channel_pb2_grpc.add_ChannelServicer_to_server(
            Channel._Servicer(self), self._server)

        self._stats_client = stats_client or stats.NoneClient()
Пример #23
0
    def setUp(self):
        """Start a leader/follower DataJoinMasterService pair and fixtures.

        Commits a leader and a follower DataSource (2 partitions) under the
        etcd base dirs ``byefl_l`` / ``byefl_f``. It then starts a master for
        each side on localhost:4061 / localhost:4062, each pointing at the
        other, and polls every 2 seconds (with no timeout) until both report
        ``DataSourceState.Processing``.

        Afterwards it does the following:

        * stores the clients, masters and etcd handles on ``self``;
        * creates one RawDataPublisher per side;
        * wipes each side's data_block / example_dumped / raw_data
          directories;
        * builds ``self.worker_options`` for the data-join workers.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        etcd_f.delete_prefix(
            common.data_source_etcd_base_dir(data_source_name))
        data_source_l = common_pb.DataSource()
        self.raw_data_pub_dir_l = './raw_data_pub_dir_l'
        data_source_l.raw_data_sub_dir = self.raw_data_pub_dir_l
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        self.raw_data_pub_dir_f = './raw_data_pub_dir_f'
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.raw_data_sub_dir = self.raw_data_pub_dir_f
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_l, data_source_l)
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        common.commit_data_source(etcd_f, data_source_f)
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True)

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]),
            master_addr_f,
            data_source_name,
            etcd_name,
            etcd_base_dir_l,
            etcd_addrs,
            master_options,
        )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs, master_options)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        # Block until both masters have moved their data source to
        # Processing.
        while True:
            req_l = dj_pb.DataSourceRequest(
                data_source_meta=data_source_l.data_source_meta)
            req_f = dj_pb.DataSourceRequest(
                data_source_meta=data_source_f.data_source_meta)
            dss_l = master_client_l.GetDataSourceStatus(req_l)
            dss_f = master_client_f.GetDataSourceStatus(req_f)
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == common_pb.DataSourceState.Processing and \
                    dss_f.state == common_pb.DataSourceState.Processing:
                break
            time.sleep(2)

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        # Plain string: a stray trailing comma previously stored a 1-tuple.
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        self.raw_data_publisher_l = raw_data_publisher.RawDataPublisher(
            self.etcd_l, self.raw_data_pub_dir_l)
        self.raw_data_publisher_f = raw_data_publisher.RawDataPublisher(
            self.etcd_f, self.raw_data_pub_dir_f)

        # Start every test from empty output directories on both sides.
        for data_source in (data_source_l, data_source_f):
            for dirname in (data_source.data_block_dir,
                            data_source.example_dumped_dir,
                            data_source.raw_data_dir):
                if gfile.Exists(dirname):
                    gfile.DeleteRecursively(dirname)

        self.worker_options = dj_pb.DataJoinWorkerOptions(
            use_mock_etcd=True,
            raw_data_options=dj_pb.RawDataOptions(raw_data_iter='TF_RECORD',
                                                  compressed_type=''),
            example_id_dump_options=dj_pb.ExampleIdDumpOptions(
                example_id_dump_interval=1, example_id_dump_threshold=1024),
            example_joiner_options=dj_pb.ExampleJoinerOptions(
                example_joiner='STREAM_JOINER',
                min_matching_window=64,
                max_matching_window=256,
                data_block_dump_interval=30,
                data_block_dump_threshold=1000),
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=512, max_flying_item=2048),
            data_block_builder_options=dj_pb.WriterOptions(
                output_writer='TF_RECORD'))

        self.total_index = 1 << 13
Пример #24
0
    def test_api(self):
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f)
        etcd_l.delete_prefix(data_source_name)
        etcd_f.delete_prefix(data_source_name)
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 1
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_meta.min_matching_window = 32
        data_source_meta.max_matching_window = 1024
        data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
        data_source_meta.max_example_in_data_block = 1000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        etcd_l.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_l))
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        etcd_f.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_f))

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs)
        master_f.start()
        channel_l = make_insecure_channel(master_addr_l, ChannelType.INTERNAL)
        client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
        channel_f = make_insecure_channel(master_addr_f, ChannelType.INTERNAL)
        client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

        while True:
            rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
            rsp_f = client_f.GetDataSourceState(data_source_f.data_source_meta)
            self.assertEqual(rsp_l.status.code, 0)
            self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
            self.assertEqual(rsp_l.data_source_type,
                             common_pb.DataSourceType.Sequential)
            self.assertEqual(rsp_f.status.code, 0)
            self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
            self.assertEqual(rsp_f.data_source_type,
                             common_pb.DataSourceType.Sequential)
            if (rsp_l.state == common_pb.DataSourceState.Processing
                    and rsp_f.state == common_pb.DataSourceState.Processing):
                break
            else:
                time.sleep(2)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=-1))

        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertTrue(rdrsp.status.code == 0)
        self.assertFalse(rdrsp.HasField('manifest'))
        self.assertFalse(rdrsp.HasField('finished'))

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=-1))
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
        self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
        self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.state, dj_pb.Syncing)
        self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        frreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
        frrsp = client_l.FinishJoinPartition(frreq)
        self.assertEqual(frrsp.code, 0)
        rdrsp = client_l.FinishJoinPartition(rdreq)
        self.assertEqual(frrsp.code, 0)

        rdreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            sync_example_id=dj_pb.SyncExampleIdRequest(partition_id=0))
        frrsp = client_f.FinishJoinPartition(rdreq)
        self.assertEqual(frrsp.code, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=-1))
        rdrsp = client_l.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
        self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        rdreq = dj_pb.RawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=0))
        rdrsp = client_f.RequestJoinPartition(rdreq)
        self.assertEqual(rdrsp.status.code, 0)
        self.assertTrue(rdrsp.HasField('manifest'))
        self.assertEqual(rdrsp.manifest.state, dj_pb.Joining)
        self.assertEqual(rdrsp.manifest.allocated_rank_id, 0)
        self.assertEqual(rdrsp.manifest.partition_id, 0)

        frreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_l.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=0))
        frrsp = client_l.FinishJoinPartition(rdreq)
        self.assertEqual(frrsp.code, 0)

        frrsp = client_l.FinishJoinPartition(rdreq)
        self.assertEqual(frrsp.code, 0)

        frreq = dj_pb.FinishRawDataRequest(
            data_source_meta=data_source_f.data_source_meta,
            rank_id=0,
            join_example=dj_pb.JoinExampleRequest(partition_id=0))
        frrsp = client_f.FinishJoinPartition(rdreq)
        self.assertEqual(frrsp.code, 0)

        while True:
            rsp_l = client_l.GetDataSourceState(data_source_l.data_source_meta)
            rsp_f = client_f.GetDataSourceState(data_source_l.data_source_meta)
            self.assertEqual(rsp_l.status.code, 0)
            self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
            self.assertEqual(rsp_l.data_source_type,
                             common_pb.DataSourceType.Sequential)
            self.assertEqual(rsp_f.status.code, 0)
            self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
            self.assertEqual(rsp_f.data_source_type,
                             common_pb.DataSourceType.Sequential)
            if (rsp_l.state == common_pb.DataSourceState.Finished
                    and rsp_f.state == common_pb.DataSourceState.Finished):
                break
            else:
                time.sleep(2)

        master_l.stop()
        master_f.stop()
Пример #25
0
 def __init__(self, addr):
     """Build a scheduler client bound to ``addr``.

     The address is kept on ``self._addr``; an insecure channel opened in
     ``ChannelType.REMOTE`` mode is wrapped in an ``ss_grpc.SchedulerStub``
     and stored on ``self._stub``.

     Args:
         addr: endpoint handed to ``make_insecure_channel``.
     """
     self._addr = addr
     self._stub = ss_grpc.SchedulerStub(
         make_insecure_channel(addr, mode=ChannelType.REMOTE))
Пример #26
0
                          'only support add/finish')
 parser.add_argument('master_addr', type=str,
                     help='the addr(uuid) of local data_join_master')
 parser.add_argument('partition_id', type=int,
                     help='the partition to control')
 parser.add_argument('--files', type=str, nargs='+',
                     help='the need raw data fnames')
 parser.add_argument('--src_dir', type=str,
                     help='the directory of input raw data. The input '\
                          'file sequence is sorted by file name and rank '\
                          'after raw data input by --files')
 parser.add_argument('--dedup', action='store_true',
                     help='dedup the input files, otherwise, '\
                          'error if dup input files')
 args = parser.parse_args()
 master_channel = make_insecure_channel(args.master_addr,
                                        ChannelType.INTERNAL)
 master_cli = dj_grpc.DataJoinMasterServiceStub(master_channel)
 data_src = master_cli.GetDataSource(empty_pb2.Empty())
 rdc = RawDataController(data_src, master_cli)
 if args.cmd == 'add':
     all_fpaths = []
     if args.files is not None:
         for fp in args.files:
             all_fpaths.append(fp)
     if args.src_dir is not None:
         dir_fpaths = \
                 [path.join(args.src_dir, f)
                  for f in gfile.ListDirectory(args.src_dir)
                  if not gfile.IsDirectory(path.join(args.src_dir, f))]
         dir_fpaths.sort()
         all_fpaths += dir_fpaths
Пример #27
0
    def _inner_test_round(self, start_index):
        """Run one leader/follower data-join round end to end.

        Generates raw data for every partition on both sides starting at
        ``start_index``, starts a leader and a follower data-join master
        (mock etcd, batch mode), waits until both report ``Processing``,
        runs one ``DataJoinWorkerService`` per side in a background thread,
        waits until both masters report ``Ready``, joins the worker threads
        and stops the masters (the masters are stopped even if the round
        fails, so their ports are released for later tests).

        Args:
            start_index: first example index passed to ``generate_raw_data``
                for this round.
        """
        for i in range(self.data_source_l.data_source_meta.partition_num):
            self.generate_raw_data(
                    start_index, self.etcd_l, self.raw_data_publisher_l,
                    self.data_source_l, self.raw_data_dir_l, i, 2048, 64,
                    'leader_key_partition_{}'.format(i) + ':{}',
                    'leader_value_partition_{}'.format(i) + ':{}'
                )
            self.generate_raw_data(
                    start_index, self.etcd_f, self.raw_data_publisher_f,
                    self.data_source_f, self.raw_data_dir_f, i, 4096, 128,
                    'follower_key_partition_{}'.format(i) + ':{}',
                    'follower_value_partition_{}'.format(i) + ':{}'
                )

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_options = dj_pb.DataJoinMasterOptions(use_mock_etcd=True,
                                                     batch_mode=True)
        master_l = data_join_master.DataJoinMasterService(
                int(master_addr_l.split(':')[1]), master_addr_f,
                self.data_source_name, self.etcd_name, self.etcd_base_dir_l,
                self.etcd_addrs, master_options,
            )
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
                int(master_addr_f.split(':')[1]), master_addr_l,
                self.data_source_name, self.etcd_name, self.etcd_base_dir_f,
                self.etcd_addrs, master_options
            )
        master_f.start()
        try:
            channel_l = make_insecure_channel(master_addr_l,
                                              ChannelType.INTERNAL)
            master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
            channel_f = make_insecure_channel(master_addr_f,
                                              ChannelType.INTERNAL)
            master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

            self._wait_for_data_source_state(
                    master_client_l, master_client_f,
                    common_pb.DataSourceState.Processing
                )

            worker_addr_l = 'localhost:4161'
            worker_addr_f = 'localhost:4162'

            worker_l = data_join_worker.DataJoinWorkerService(
                    int(worker_addr_l.split(':')[1]),
                    worker_addr_f, master_addr_l, 0,
                    self.etcd_name, self.etcd_base_dir_l,
                    self.etcd_addrs, self.worker_options
                )

            worker_f = data_join_worker.DataJoinWorkerService(
                    int(worker_addr_f.split(':')[1]),
                    worker_addr_l, master_addr_f, 0,
                    self.etcd_name, self.etcd_base_dir_f,
                    self.etcd_addrs, self.worker_options
                )

            th_l = threading.Thread(target=worker_l.run, name='worker_l')
            th_f = threading.Thread(target=worker_f.run, name='worker_f')

            th_l.start()
            th_f.start()

            self._wait_for_data_source_state(
                    master_client_l, master_client_f,
                    common_pb.DataSourceState.Ready
                )

            th_l.join()
            th_f.join()
        finally:
            master_l.stop()
            master_f.stop()

    def _wait_for_data_source_state(self, master_client_l, master_client_f,
                                    target_state):
        """Poll both masters every 2 seconds until each reports ``target_state``.

        An exception raised by either ``GetDataSourceStatus`` RPC (for
        instance while a master is still starting) is treated as "not ready
        yet" and retried. Role assertions are evaluated outside the ``try``
        so that a wrong role fails the test instead of being swallowed.

        Args:
            master_client_l: stub for the leader master.
            master_client_f: stub for the follower master.
            target_state: ``common_pb.DataSourceState`` value both masters
                must report before this returns.
        """
        req_l = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_l.data_source_meta
            )
        req_f = dj_pb.DataSourceRequest(
                data_source_meta=self.data_source_f.data_source_meta
            )
        while True:
            try:
                dss_l = master_client_l.GetDataSourceStatus(req_l)
                dss_f = master_client_f.GetDataSourceStatus(req_f)
            except Exception:  # pylint: disable=broad-except
                # Deliberate best-effort: the RPC may fail until the
                # masters are serving; retry after a pause.
                time.sleep(2)
                continue
            self.assertEqual(dss_l.role, common_pb.FLRole.Leader)
            self.assertEqual(dss_f.role, common_pb.FLRole.Follower)
            if dss_l.state == target_state and dss_f.state == target_state:
                return
            time.sleep(2)
Пример #28
0
    def _client_daemon_fn(self):
        """Stream queued messages to the remote peer until shut down.

        Opens a ``ChannelType.REMOTE`` channel to ``self._remote_address`` and
        feeds ``self._transmit_queue`` into the ``StreamTransmit`` RPC. Each
        response is handled by status code:

        - ``STATUS_SUCCESS`` / ``STATUS_MESSAGE_DUPLICATED``: ``confirm`` up to
          ``next_seq_num``;
        - ``STATUS_MESSAGE_MISSING``: ``resend`` from ``next_seq_num``;
        - anything else: raises ``RuntimeError``, which breaks the stream.

        When the stream breaks while still running, the warning is logged, a
        ``reconnect_counter`` metric is emitted, the channel is closed,
        ``resend(-1)`` is called on the queue, and a fresh channel is opened
        and heartbeat-checked before streaming resumes.

        Installs ``self._client_daemon_shutdown_fn``, which waits for the
        queue to drain, sets the stop event and cancels the active stream.
        Once the stop event is set, no further reconnect is attempted and the
        current channel is closed before this function returns.
        """
        stop_event = threading.Event()
        generator = None

        def connect():
            # Open a fresh REMOTE channel plus a client that is ready on it.
            new_channel = make_insecure_channel(self._remote_address,
                                                ChannelType.REMOTE,
                                                options=self._grpc_options,
                                                compression=self._compression)
            return new_channel, make_ready_client(new_channel, stop_event)

        channel, client = connect()

        def shutdown_fn():
            while self._transmit_queue.size():
                logging.debug(
                    "Waiting for message queue's being cleaned. "
                    "Queue size: %d", self._transmit_queue.size())
                time.sleep(1)

            stop_event.set()
            if generator is not None:
                generator.cancel()

        self._client_daemon_shutdown_fn = shutdown_fn

        def iterator():
            # Never terminates on its own; the stream is ended via cancel().
            while True:
                item = self._transmit_queue.get()
                logging.debug("Streaming send message seq_num=%d",
                              item.seq_num)
                yield item

        try:
            while not stop_event.is_set():
                generator = None
                try:
                    generator = client.StreamTransmit(iterator())
                    for response in generator:
                        code = response.status.code
                        if code == common_pb.STATUS_SUCCESS:
                            self._transmit_queue.confirm(
                                response.next_seq_num)
                            logging.debug(
                                "Message with seq_num=%d is "
                                "confirmed", response.next_seq_num - 1)
                        elif code == common_pb.STATUS_MESSAGE_DUPLICATED:
                            self._transmit_queue.confirm(
                                response.next_seq_num)
                            logging.debug(
                                "Resent Message with seq_num=%d is "
                                "confirmed", response.next_seq_num - 1)
                        elif code == common_pb.STATUS_MESSAGE_MISSING:
                            self._transmit_queue.resend(response.next_seq_num)
                        else:
                            raise RuntimeError("Transmit failed with %d" %
                                               code)
                except Exception as e:  # pylint: disable=broad-except
                    if not stop_event.is_set():
                        logging.warning("Bridge streaming broken: %s.",
                                        repr(e))
                        metrics.emit_counter('reconnect_counter', 1)
                finally:
                    # generator stays None when StreamTransmit itself raised.
                    if generator is not None:
                        generator.cancel()

                if stop_event.is_set():
                    # Shutting down: do not reconnect; the outer finally
                    # closes the current channel.
                    break

                channel.close()
                channel = None
                time.sleep(1)
                self._transmit_queue.resend(-1)
                channel, client = connect()
                self._check_remote_heartbeat(client)
        finally:
            if channel is not None:
                channel.close()
Пример #29
0
    def setUp(self):
        """Prepare leader/follower data sources and start both masters.

        Writes a leader and a follower ``DataSource`` sharing one
        ``DataSourceMeta`` (2 partitions, sequential type) into etcd clients
        created with the mock flag, enables the mock-etcd option, and starts
        a data-join master for each side. It then waits until both masters
        report ``Processing``; if that wait fails, both masters are stopped
        before the error propagates, because ``tearDown`` does not run when
        ``setUp`` raises. Finally it stores the fixtures on ``self`` and
        deletes any leftover data-block, example-dumped and raw-data
        directories from previous runs.
        """
        etcd_name = 'test_etcd'
        etcd_addrs = 'localhost:2379'
        etcd_base_dir_l = 'byefl_l'
        etcd_base_dir_f = 'byefl_f'
        data_source_name = 'test_data_source'
        # The 4th EtcdClient argument selects the mock etcd backend.
        etcd_l = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_l, True)
        etcd_f = EtcdClient(etcd_name, etcd_addrs, etcd_base_dir_f, True)
        etcd_l.delete_prefix(data_source_name)
        etcd_f.delete_prefix(data_source_name)
        data_source_l = common_pb.DataSource()
        data_source_l.role = common_pb.FLRole.Leader
        data_source_l.state = common_pb.DataSourceState.Init
        data_source_l.data_block_dir = "./data_block_l"
        data_source_l.raw_data_dir = "./raw_data_l"
        data_source_l.example_dumped_dir = "./example_dumped_l"
        data_source_f = common_pb.DataSource()
        data_source_f.role = common_pb.FLRole.Follower
        data_source_f.state = common_pb.DataSourceState.Init
        data_source_f.data_block_dir = "./data_block_f"
        data_source_f.raw_data_dir = "./raw_data_f"
        data_source_f.example_dumped_dir = "./example_dumped_f"
        data_source_meta = common_pb.DataSourceMeta()
        data_source_meta.name = data_source_name
        data_source_meta.partition_num = 2
        data_source_meta.start_time = 0
        data_source_meta.end_time = 100000000
        data_source_meta.min_matching_window = 64
        data_source_meta.max_matching_window = 128
        data_source_meta.data_source_type = common_pb.DataSourceType.Sequential
        data_source_meta.max_example_in_data_block = 1000
        data_source_l.data_source_meta.MergeFrom(data_source_meta)
        etcd_l.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_l))
        data_source_f.data_source_meta.MergeFrom(data_source_meta)
        etcd_f.set_data(os.path.join(data_source_name, 'master'),
                        text_format.MessageToString(data_source_f))
        customized_options.set_use_mock_etcd()

        master_addr_l = 'localhost:4061'
        master_addr_f = 'localhost:4062'
        master_l = data_join_master.DataJoinMasterService(
            int(master_addr_l.split(':')[1]), master_addr_f, data_source_name,
            etcd_name, etcd_base_dir_l, etcd_addrs)
        master_l.start()
        master_f = data_join_master.DataJoinMasterService(
            int(master_addr_f.split(':')[1]), master_addr_l, data_source_name,
            etcd_name, etcd_base_dir_f, etcd_addrs)
        master_f.start()
        try:
            channel_l = make_insecure_channel(master_addr_l,
                                              ChannelType.INTERNAL)
            master_client_l = dj_grpc.DataJoinMasterServiceStub(channel_l)
            channel_f = make_insecure_channel(master_addr_f,
                                              ChannelType.INTERNAL)
            master_client_f = dj_grpc.DataJoinMasterServiceStub(channel_f)

            while True:
                rsp_l = master_client_l.GetDataSourceState(
                    data_source_l.data_source_meta)
                rsp_f = master_client_f.GetDataSourceState(
                    data_source_f.data_source_meta)
                self.assertEqual(rsp_l.status.code, 0)
                self.assertEqual(rsp_l.role, common_pb.FLRole.Leader)
                self.assertEqual(rsp_l.data_source_type,
                                 common_pb.DataSourceType.Sequential)
                self.assertEqual(rsp_f.status.code, 0)
                self.assertEqual(rsp_f.role, common_pb.FLRole.Follower)
                self.assertEqual(rsp_f.data_source_type,
                                 common_pb.DataSourceType.Sequential)
                if (rsp_l.state == common_pb.DataSourceState.Processing
                        and rsp_f.state == common_pb.DataSourceState.Processing):
                    break
                time.sleep(2)
        except Exception:
            # tearDown is skipped when setUp raises; release the masters'
            # ports here so later tests can bind them.
            master_l.stop()
            master_f.stop()
            raise

        self.master_client_l = master_client_l
        self.master_client_f = master_client_f
        self.master_addr_l = master_addr_l
        self.master_addr_f = master_addr_f
        self.etcd_l = etcd_l
        self.etcd_f = etcd_f
        self.data_source_l = data_source_l
        self.data_source_f = data_source_f
        self.master_l = master_l
        self.master_f = master_f
        # Plain string (the original trailing comma stored a 1-tuple).
        self.data_source_name = data_source_name
        self.etcd_name = etcd_name
        self.etcd_addrs = etcd_addrs
        self.etcd_base_dir_l = etcd_base_dir_l
        self.etcd_base_dir_f = etcd_base_dir_f
        # Remove output left behind by earlier runs (same order as before).
        for stale_dir in (data_source_l.data_block_dir,
                          data_source_l.example_dumped_dir,
                          data_source_l.raw_data_dir,
                          data_source_f.data_block_dir,
                          data_source_f.example_dumped_dir,
                          data_source_f.raw_data_dir):
            if gfile.Exists(stale_dir):
                gfile.DeleteRecursively(stale_dir)

        self.total_index = 1 << 13