コード例 #1
0
ファイル: bridge.py プロジェクト: nocmk2/fedlearner
    def __init__(self,
                 role,
                 listen_port,
                 remote_address,
                 app_id=None,
                 rank=0,
                 streaming_mode=True,
                 compression=grpc.Compression.NoCompression):
        """Initialise bridge state, the gRPC client stub and the gRPC server.

        The server object is created and bound to its port here, but it is
        not started by this method.

        Args:
            role: Stored on the instance and embedded in the identifier.
            listen_port: Integer port; the server is bound to
                ``[::]:<listen_port>`` via ``add_insecure_port``.
            remote_address: Address the client channel is created for.
            app_id: Application id; ``None`` falls back to
                ``'test_trainer'``.
            rank: Integer rank, embedded in the identifier.
            streaming_mode: Stored as-is; not otherwise used here.
            compression: gRPC compression setting handed to both the
                client channel and the server.
        """
        app_id = 'test_trainer' if app_id is None else app_id

        # Constructor arguments (after defaulting app_id).
        self._role = role
        self._listen_port = listen_port
        self._remote_address = remote_address
        self._app_id = app_id
        self._rank = rank
        self._streaming_mode = streaming_mode
        self._compression = compression

        # Handler registrations start out empty.
        self._prefetch_handlers = []
        self._data_block_handler_fn = None

        # Connection state flags and identifiers.
        self._connected = False
        self._terminated = False
        self._peer_terminated = False
        # The creation timestamp (whole seconds) is appended so that the
        # identifier differs from one run to the next.
        created_at = int(time.time())
        self._identifier = '%s-%s-%d-%d' % (app_id, role, rank, created_at)
        self._peer_identifier = ''

        # Data-transmit bookkeeping.
        self._condition = threading.Condition()
        self._current_iter_id = None
        self._next_iter_id = 0
        self._received_data = {}

        # gRPC client side. The same options list is reused for the server
        # below; 2**31 - 1 is the largest signed 32-bit integer.
        self._transmit_send_lock = threading.Lock()
        max_message_bytes = 2**31 - 1
        self._grpc_options = [
            ('grpc.max_send_message_length', max_message_bytes),
            ('grpc.max_receive_message_length', max_message_bytes),
        ]
        self._channel = make_insecure_channel(
            remote_address,
            ChannelType.REMOTE,
            options=self._grpc_options,
            compression=self._compression)
        self._client = tws_grpc.TrainerWorkerServiceStub(self._channel)
        self._next_send_seq_num = 0
        self._transmit_queue = queue.Queue()
        self._client_daemon = None
        self._client_daemon_shutdown_fn = None

        # gRPC server side, served from a pool of 10 worker threads.
        self._transmit_receive_lock = threading.Lock()
        self._next_receive_seq_num = 0
        executor = futures.ThreadPoolExecutor(max_workers=10)
        self._server = grpc.server(
            executor,
            options=self._grpc_options,
            compression=self._compression)
        servicer = Bridge.TrainerWorkerServicer(self)
        tws_grpc.add_TrainerWorkerServiceServicer_to_server(
            servicer, self._server)
        self._server.add_insecure_port('[::]:%d' % listen_port)
コード例 #2
0
    def __init__(self,
                 role,
                 listen_port,
                 remote_address,
                 app_id=None,
                 worker_rank=0,
                 stream_queue_size=1024,
                 waiting_alert_timeout=10):
        """Initialise bridge state and wire the client/servicer to a Channel.

        Args:
            role: Stored on the instance.
            listen_port: Port used to build the listen address
                ``[::]:<listen_port>``.
            remote_address: Remote address handed to the Channel.
            app_id: Application id; ``None`` falls back to
                ``'test_trainer'``. Only used to build the channel token.
            worker_rank: Stored on the instance and used in the token.
            stream_queue_size: Stored as the stream queue size; the deque
                created here is not itself bounded.
            waiting_alert_timeout: Values below 1 are raised to 1.
        """
        app_id = 'test_trainer' if app_id is None else app_id

        self._role = role
        self._listen_address = "[::]:{}".format(listen_port)
        self._remote_address = remote_address
        self._worker_rank = worker_rank
        # Token passed to the Channel: "<app_id>-<worker_rank>".
        self._token = "{}-{}".format(app_id, worker_rank)

        # Connection state.
        self._condition = threading.Condition()
        self._connected = False
        self._terminated = False
        self._peer_terminated = False

        # Iteration bookkeeping.
        self._current_iter_id = None
        self._next_iter_id = 0
        self._iter_started_at = 0
        self._peer_start_iter_id = None
        self._peer_commit_iter_id = None

        # Received data is a two-level mapping; missing first-level keys
        # yield a fresh dict.
        self._received_data = collections.defaultdict(dict)
        self._data_block_handler_fn = None

        # Enforce a floor of 1 on the alert timeout.
        self._waiting_alert_timeout = (
            1 if waiting_alert_timeout < 1 else waiting_alert_timeout)

        # Transmit stream queue and the state guarding it.
        self._stream_queue = collections.deque()
        self._stream_queue_size = stream_queue_size
        self._stream_thread = None
        self._stream_condition = threading.Condition()
        self._stream_terminated = False

        # Channel; this bridge subscribes its own callback to it.
        self._channel = Channel(
            self._listen_address,
            self._remote_address,
            token=self._token,
            stats_client=_gctx.stats_client)
        self._channel.subscribe(self._channel_callback)

        # The same Channel object backs both the client stub and the
        # servicer registration.
        self._client = tws2_grpc.TrainerWorkerServiceStub(self._channel)
        servicer = Bridge.TrainerWorkerServicer(self)
        tws2_grpc.add_TrainerWorkerServiceServicer_to_server(
            servicer, self._channel)

        # Supervision settings (presumably seconds -- confirm at use site).
        self._supervise_interval = 5
        self._supervise_iteration_timeout = 1200
コード例 #3
0
ファイル: bridge.py プロジェクト: feiga/fedlearner
    def __init__(self,
                 role,
                 listen_port,
                 remote_address,
                 app_id='test_trainer',
                 rank=0,
                 streaming_mode=True):
        """Initialise bridge state, the gRPC client stub and the gRPC server.

        The server object is created and bound to its port here, but it is
        not started by this method.

        Args:
            role: Stored on the instance.
            listen_port: Integer port; the server is bound to
                ``[::]:<listen_port>`` via ``add_insecure_port``.
            remote_address: Address the client channel is created for.
            app_id: Stored on the instance.
            rank: Stored on the instance.
            streaming_mode: Stored as-is; not otherwise used here.
        """
        # Constructor arguments, kept verbatim.
        self._role = role
        self._listen_port = listen_port
        self._remote_address = remote_address
        self._app_id = app_id
        self._rank = rank
        self._streaming_mode = streaming_mode

        # Handler registrations and connection flag.
        self._prefetch_handlers = []
        self._data_block_handler_fn = None
        self._connected = False

        # Data-transmit bookkeeping.
        self._condition = threading.Condition()
        self._current_iter_id = None
        self._next_iter_id = 0
        self._received_data = {}

        # gRPC client side. The same options list is reused for the server
        # below; 2**31 - 1 is the largest signed 32-bit integer.
        max_message_bytes = 2**31 - 1
        self._grpc_options = [
            ('grpc.max_send_message_length', max_message_bytes),
            ('grpc.max_receive_message_length', max_message_bytes),
        ]
        # The channel is only referenced through the stub; it is not kept
        # as an attribute of the bridge.
        remote_channel = make_insecure_channel(
            remote_address,
            ChannelType.REMOTE,
            options=self._grpc_options)
        self._transmit_send_lock = threading.Lock()
        self._client = tws_grpc.TrainerWorkerServiceStub(remote_channel)
        self._next_send_seq_num = 0
        self._transmit_queue = queue.Queue()
        # Background daemons and their shutdown hooks are unset here.
        self._client_daemon = None
        self._client_daemon_shutdown_fn = None
        self._keepalive_daemon = None
        self._keepalive_daemon_shutdown_fn = None

        # gRPC server side, served from a pool of 10 worker threads.
        self._transmit_receive_lock = threading.Lock()
        self._next_receive_seq_num = 0
        executor = futures.ThreadPoolExecutor(max_workers=10)
        self._server = grpc.server(executor, options=self._grpc_options)
        servicer = Bridge.TrainerWorkerServicer(self)
        tws_grpc.add_TrainerWorkerServiceServicer_to_server(
            servicer, self._server)
        self._server.add_insecure_port('[::]:%d' % listen_port)
コード例 #4
0
ファイル: bridge.py プロジェクト: saswat0/fedlearner
def make_ready_client(channel, stop_event=None):
    """Wait for ``channel`` to become ready and return a service stub.

    The readiness future is polled with a timeout that starts at 0.5
    seconds and is multiplied by 1.2 after each timed-out wait for as long
    as it is still below 5.0. Every failed wait is logged as a warning.

    Args:
        channel: The gRPC channel to wait on and build the stub from.
        stop_event: Optional object with ``is_set()``; once it reports set,
            waiting stops. With ``None`` the wait continues until the
            channel is ready.

    Returns:
        A ``TrainerWorkerServiceStub`` for ``channel``. The stub is
        returned even when waiting was cut short by ``stop_event``.
    """
    ready_future = grpc.channel_ready_future(channel)
    timeout = 0.5
    began_at = time.time()
    while True:
        # The stop request is checked before every wait attempt.
        if stop_event is not None and stop_event.is_set():
            break
        try:
            ready_future.result(timeout=timeout)
        except grpc.FutureTimeoutError:
            logging.warning('Channel has not been ready for %.2f seconds',
                            time.time() - began_at)
            # Lengthen the next wait while it is still under 5 seconds.
            if timeout < 5.0:
                timeout *= 1.2
        except Exception as e:  # pylint: disable=broad-except
            # Any other failure is logged and the wait is retried.
            logging.warning('Waiting channel ready: %s', repr(e))
        else:
            # The future resolved: the channel is ready.
            break
    return tws_grpc.TrainerWorkerServiceStub(channel)