def __init__(self,
             role,
             listen_port,
             remote_address,
             app_id=None,
             rank=0,
             streaming_mode=True,
             compression=grpc.Compression.NoCompression):
    """Initialize bridge state, the gRPC client stub and the gRPC server.

    The constructor builds an insecure channel to ``remote_address``,
    wraps it in a ``TrainerWorkerServiceStub``, and configures a gRPC
    server with this bridge's servicer bound to ``listen_port``.  No
    ``start()`` call is made here; the server is only configured.

    Args:
        role: Stored as ``self._role`` and embedded in the identifier.
        listen_port: Integer port the server binds on ``[::]``.
        remote_address: Address handed to ``make_insecure_channel``.
        app_id: Application id; ``None`` falls back to ``'test_trainer'``.
        rank: Integer rank, stored and embedded in the identifier.
        streaming_mode: Stored as ``self._streaming_mode``.
        compression: Compression setting passed to both the client
            channel and the server.
    """
    app_id = 'test_trainer' if app_id is None else app_id

    # Plain configuration
    self._role = role
    self._listen_port = listen_port
    self._remote_address = remote_address
    self._app_id = app_id
    self._rank = rank
    self._streaming_mode = streaming_mode
    self._compression = compression

    self._prefetch_handlers = []
    self._data_block_handler_fn = None

    # Connection related
    self._connected = False
    self._terminated = False
    self._peer_terminated = False
    # The timestamp suffix keeps the identifier unique per run.
    run_stamp = int(time.time())
    self._identifier = '%s-%s-%d-%d' % (app_id, role, rank, run_stamp)
    self._peer_identifier = ''

    # Data transmit
    self._condition = threading.Condition()
    self._current_iter_id = None
    self._next_iter_id = 0
    self._received_data = {}

    # gRPC client
    grpc_options = [
        ('grpc.max_send_message_length', 2**31-1),
        ('grpc.max_receive_message_length', 2**31-1)
    ]
    self._transmit_send_lock = threading.Lock()
    self._grpc_options = grpc_options
    self._channel = make_insecure_channel(
        remote_address,
        ChannelType.REMOTE,
        options=grpc_options,
        compression=compression)
    self._client = tws_grpc.TrainerWorkerServiceStub(self._channel)
    self._next_send_seq_num = 0
    self._transmit_queue = queue.Queue()
    self._client_daemon = None
    self._client_daemon_shutdown_fn = None

    # gRPC server
    self._transmit_receive_lock = threading.Lock()
    self._next_receive_seq_num = 0
    self._server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=grpc_options,
        compression=compression)
    servicer = Bridge.TrainerWorkerServicer(self)
    tws_grpc.add_TrainerWorkerServiceServicer_to_server(
        servicer, self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
def __init__(self,
             role,
             listen_port,
             remote_address,
             app_id=None,
             worker_rank=0,
             stream_queue_size=1024,
             waiting_alert_timeout=10):
    """Initialize bridge state and the bidirectional ``Channel``.

    A single ``Channel`` object is created from the listen and remote
    addresses; it is used both as the transport for the client stub and
    as the target the servicer is registered on.

    Args:
        role: Stored as ``self._role``.
        listen_port: Port formatted into the ``[::]:<port>`` listen
            address.
        remote_address: Peer address passed to ``Channel``.
        app_id: Application id used to build the channel token;
            ``None`` falls back to ``'test_trainer'``.
        worker_rank: Stored and used as the second token component.
        stream_queue_size: Stored as ``self._stream_queue_size``.
        waiting_alert_timeout: Stored after being raised to at least 1.
    """
    if app_id is None:
        app_id = 'test_trainer'
    listen_address = "[::]:{}".format(listen_port)
    token = "{}-{}".format(app_id, worker_rank)

    self._role = role
    self._listen_address = listen_address
    self._remote_address = remote_address
    self._worker_rank = worker_rank
    self._token = token

    # Connection / iteration bookkeeping
    self._condition = threading.Condition()
    self._connected = False
    self._terminated = False
    self._peer_terminated = False
    self._current_iter_id = None
    self._next_iter_id = 0
    self._iter_started_at = 0
    self._peer_start_iter_id = None
    self._peer_commit_iter_id = None
    self._received_data = collections.defaultdict(dict)
    self._data_block_handler_fn = None

    # Values below 1 are replaced by 1.
    self._waiting_alert_timeout = (
        1 if waiting_alert_timeout < 1 else waiting_alert_timeout)

    # Transmit stream queue
    self._stream_queue = collections.deque()
    self._stream_queue_size = stream_queue_size
    self._stream_thread = None
    self._stream_condition = threading.Condition()
    self._stream_terminated = False

    # Channel
    self._channel = Channel(listen_address,
                            remote_address,
                            token=token,
                            stats_client=_gctx.stats_client)
    self._channel.subscribe(self._channel_callback)

    # Client & server share the same channel object.
    self._client = tws2_grpc.TrainerWorkerServiceStub(self._channel)
    servicer = Bridge.TrainerWorkerServicer(self)
    tws2_grpc.add_TrainerWorkerServiceServicer_to_server(
        servicer, self._channel)

    # Supervise
    self._supervise_interval = 5
    self._supervise_iteration_timeout = 1200
def __init__(self,
             role,
             listen_port,
             remote_address,
             app_id='test_trainer',
             rank=0,
             streaming_mode=True):
    """Initialize bridge state, the gRPC client stub and the gRPC server.

    An insecure channel to ``remote_address`` is created and wrapped in
    a ``TrainerWorkerServiceStub``; a gRPC server with this bridge's
    servicer is configured on ``listen_port``.  No ``start()`` call is
    made here.

    Args:
        role: Stored as ``self._role``.
        listen_port: Integer port the server binds on ``[::]``.
        remote_address: Address handed to ``make_insecure_channel``.
        app_id: Stored as ``self._app_id``.
        rank: Stored as ``self._rank``.
        streaming_mode: Stored as ``self._streaming_mode``.
    """
    # Plain configuration
    self._role = role
    self._listen_port = listen_port
    self._remote_address = remote_address
    self._app_id = app_id
    self._rank = rank
    self._streaming_mode = streaming_mode

    self._prefetch_handlers = []
    self._data_block_handler_fn = None
    self._connected = False

    # Data transmit
    self._condition = threading.Condition()
    self._current_iter_id = None
    self._next_iter_id = 0
    self._received_data = {}

    # gRPC client
    grpc_options = [
        ('grpc.max_send_message_length', 2**31-1),
        ('grpc.max_receive_message_length', 2**31-1)
    ]
    self._grpc_options = grpc_options
    peer_channel = make_insecure_channel(
        remote_address, ChannelType.REMOTE, options=grpc_options)
    self._transmit_send_lock = threading.Lock()
    self._client = tws_grpc.TrainerWorkerServiceStub(peer_channel)
    self._next_send_seq_num = 0
    self._transmit_queue = queue.Queue()

    # Daemon handles start out empty.
    self._client_daemon = None
    self._client_daemon_shutdown_fn = None
    self._keepalive_daemon = None
    self._keepalive_daemon_shutdown_fn = None

    # gRPC server
    self._transmit_receive_lock = threading.Lock()
    self._next_receive_seq_num = 0
    self._server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=grpc_options)
    servicer = Bridge.TrainerWorkerServicer(self)
    tws_grpc.add_TrainerWorkerServiceServicer_to_server(
        servicer, self._server)
    self._server.add_insecure_port('[::]:%d' % listen_port)
def make_ready_client(channel, stop_event=None):
    """Wait for *channel* to become ready, then return a service stub.

    The readiness future is polled with a timeout that starts at 0.5
    and is multiplied by 1.2 after each timeout while it is below 5.0.
    The loop ends when the channel is ready or when ``stop_event`` is
    set; in both cases a stub is returned.

    Args:
        channel: gRPC channel passed to ``grpc.channel_ready_future``.
        stop_event: Optional object exposing ``is_set()``; when it
            reports True the wait is abandoned.  ``None`` means wait
            until the channel is ready.

    Returns:
        A ``tws_grpc.TrainerWorkerServiceStub`` built on ``channel``.
        NOTE(review): the stub is returned even if the wait was
        abandoned via ``stop_event``, so the channel may not be ready.
    """
    channel_ready = grpc.channel_ready_future(channel)
    wait_secs = 0.5
    start_time = time.time()
    while stop_event is None or not stop_event.is_set():
        try:
            channel_ready.result(timeout=wait_secs)
        except grpc.FutureTimeoutError:
            logging.warning('Channel has not been ready for %.2f seconds',
                            time.time() - start_time)
            if wait_secs < 5.0:
                wait_secs *= 1.2
        except Exception as e:  # pylint: disable=broad-except
            logging.warning('Waiting channel ready: %s', repr(e))
            # A non-timeout failure may be raised again immediately on the
            # next call; pause so a persistent error cannot turn this loop
            # into a busy spin that floods the log.
            time.sleep(wait_secs)
        else:
            break
    return tws_grpc.TrainerWorkerServiceStub(channel)