def test_run_grpc_watch_thread():
    """Start a watch thread with default settings and shut it down right away."""
    grpc_client = DagsterGrpcClient(port=8080)
    stop_event, watcher = create_grpc_watch_thread(grpc_client)

    watcher.start()

    # Signal shutdown immediately, then block until the thread has exited.
    stop_event.set()
    watcher.join()
def test_grpc_watch_thread_server_update():
    """Check that ``on_updated`` is not called for the first server but is after a replacement.

    A server is started on a free port and watched, then interrupted; at that
    point the callback must not have run.  A second server (no fixed server ID)
    is then started on the same port and the callback is expected to have run
    by the time it is interrupted.
    """
    port = find_free_port()
    watch_interval = 4
    settle_seconds = watch_interval * 2

    called = {}

    def on_updated():
        called["yup"] = True

    # Bring up the first server and point a watch thread at it.
    first_server = open_server_process(port=port, socket=None)
    try:
        client = DagsterGrpcClient(port=port)
        shutdown_event, watch_thread = create_grpc_watch_thread(
            client,
            on_updated=on_updated,
            watch_interval=watch_interval,
        )
        watch_thread.start()
        time.sleep(settle_seconds)
    finally:
        interrupt_ipc_subprocess_pid(first_server.pid)

    # Nothing has replaced the server yet, so no update may have been reported.
    assert not called

    # Bring up a replacement server on the same port.
    second_server = open_server_process(port=port, socket=None)
    try:
        time.sleep(settle_seconds)
    finally:
        interrupt_ipc_subprocess_pid(second_server.pid)

    shutdown_event.set()
    watch_thread.join()

    assert called
def test_grpc_watch_thread_server_complex_cycle_2():
    """Restart the server three times under the same ID, then shut it down for good.

    After each same-ID restart the test waits for one more ``on_disconnect``
    and one more ``on_reconnected`` event.  After the final shutdown it waits
    for ``on_error`` and asserts that it is the last recorded event.
    """
    port = find_free_port()
    fixed_server_id = "fixed_id"

    events = []
    called = {}

    def on_disconnect():
        events.append("on_disconnect")

    def on_reconnected():
        events.append("on_reconnected")

    def on_updated(_):
        events.append("on_updated")

    def on_error():
        called["on_error"] = True
        events.append("on_error")

    # Initial server
    open_server_process(port=port, socket=None, fixed_server_id=fixed_server_id)

    # Watch thread; the interval is faster than we would use in practice.
    client = DagsterGrpcClient(port=port)
    watch_interval = 1
    shutdown_event, watch_thread = create_grpc_watch_thread(
        client,
        on_disconnect=on_disconnect,
        on_reconnected=on_reconnected,
        on_updated=on_updated,
        on_error=on_error,
        watch_interval=watch_interval,
        max_reconnect_attempts=3,
    )
    watch_thread.start()
    time.sleep(watch_interval * 3)

    cycles = 3
    for restart_number in range(1, cycles + 1):
        # Take the server down and wait for the matching disconnect event ...
        client.shutdown_server()
        wait_for_condition(
            lambda: events.count("on_disconnect") == restart_number, watch_interval
        )
        # ... then bring it back with the same server ID and wait for the reconnect.
        open_server_process(port=port, socket=None, fixed_server_id=fixed_server_id)
        wait_for_condition(
            lambda: events.count("on_reconnected") == restart_number, watch_interval
        )

    # Simulate server failure: shut down without starting a replacement.
    client.shutdown_server()

    # Wait for reconnect attempts to exhaust and the on_error callback to be called.
    wait_for_condition(lambda: called.get("on_error"), watch_interval)

    shutdown_event.set()
    watch_thread.join()

    assert events[-1] == "on_error"
def _start_watch_thread(
    self, origin: GrpcServerRepositoryLocationOrigin
) -> None:
    """Create, register and start a gRPC watch thread for ``origin``'s location.

    The shutdown event and the thread are stored under the location name in
    ``self._watch_thread_shutdown_events`` and ``self._watch_threads``.
    Update and error callbacks forward a ``LocationStateChangeEvent`` to
    ``self._send_state_event_to_subscribers``.

    Args:
        origin: Origin whose ``location_name`` keys the thread and whose
            ``create_client()`` supplies the client to watch.
    """
    location_name = origin.location_name
    # A location must not already have a registered watch thread.
    check.invariant(
        location_name not in self._watch_thread_shutdown_events)

    def _notify_updated(location_name, new_server_id):
        self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_UPDATED,
                location_name=location_name,
                message="Server has been updated.",
                server_id=new_server_id,
            )
        )

    def _notify_error(location_name):
        self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_ERROR,
                location_name=location_name,
                message=(
                    "Unable to reconnect to server. You can reload the server once it is "
                    "reachable again"
                ),
            )
        )

    client = origin.create_client()
    shutdown_event, watch_thread = create_grpc_watch_thread(
        location_name,
        client,
        on_updated=_notify_updated,
        on_error=_notify_error,
    )

    # Register both handles before starting the thread.
    self._watch_thread_shutdown_events[location_name] = shutdown_event
    self._watch_threads[location_name] = watch_thread
    watch_thread.start()
def __init__(self, origin):
    """Connect to the gRPC server described by ``origin`` and start watching it.

    Builds a client from the origin's host/port/socket, records the server ID
    and repository names, starts a watch thread whose callbacks forward
    ``LocationStateChangeEvent`` objects to the subscribers, and then loads
    the executable path, code pointers and container image.

    Args:
        origin: Must be a ``GrpcServerRepositoryLocationOrigin``.

    Raises:
        Any exception raised during setup is re-raised after ``self.cleanup()``
        has been called.
    """
    from dagster.grpc.client import DagsterGrpcClient
    from dagster.grpc.server_watcher import create_grpc_watch_thread

    self._origin = check.inst_param(origin, "origin",
                                    GrpcServerRepositoryLocationOrigin)

    port = self.origin.port
    socket = self.origin.socket
    host = self.origin.host

    # Set before the try block so the attributes exist if setup fails early.
    self._watch_thread_shutdown_event = None
    self._watch_thread = None

    def _on_updated(new_server_id):
        self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_UPDATED,
                location_name=self.location_name,
                message="Server has been updated.",
                server_id=new_server_id,
            )
        )

    def _on_error():
        self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_ERROR,
                location_name=self.location_name,
                message=(
                    "Unable to reconnect to server. You can reload the server once it is "
                    "reachable again"
                ),
            )
        )

    try:
        self.client = DagsterGrpcClient(port=port, socket=socket, host=host)
        list_repositories_response = sync_list_repositories_grpc(self.client)

        self.server_id = sync_get_server_id(self.client)
        self.repository_names = {
            symbol.repository_name
            for symbol in list_repositories_response.repository_symbols
        }

        self._state_subscribers = []
        watch_handles = create_grpc_watch_thread(
            self.client,
            on_updated=_on_updated,
            on_error=_on_error,
        )
        self._watch_thread_shutdown_event, self._watch_thread = watch_handles
        self._watch_thread.start()

        self.executable_path = list_repositories_response.executable_path
        self.repository_code_pointer_dict = (
            list_repositories_response.repository_code_pointer_dict)

        self.container_image = self._reload_current_image()
    except:
        # Deliberately catches everything: clean up, then propagate unchanged.
        self.cleanup()
        raise
def test_grpc_watch_thread_server_error():
    """Check that ``on_disconnect`` and ``on_error`` fire when the server never comes back.

    The server is started with a fixed ID, watched, and then interrupted
    without a replacement.  The test polls for up to 30 seconds for
    ``on_error`` and then asserts that both callbacks were recorded.
    ``on_reconnected`` and ``on_updated`` are wired to a callback that raises.
    """
    port = find_free_port()
    fixed_server_id = "fixed_id"

    called = {}

    def on_disconnect():
        called["on_disconnect"] = True

    def on_error():
        called["on_error"] = True

    def should_not_be_called():
        raise Exception("This method should not be called")

    # Create initial server
    server_process = open_server_process(port=port,
                                         socket=None,
                                         fixed_server_id=fixed_server_id)

    # Start watch thread
    client = DagsterGrpcClient(port=port)
    watch_interval = 1
    max_reconnect_attempts = 3
    shutdown_event, watch_thread = create_grpc_watch_thread(
        client,
        on_disconnect=on_disconnect,
        on_reconnected=should_not_be_called,
        on_updated=should_not_be_called,
        on_error=on_error,
        watch_interval=watch_interval,
        max_reconnect_attempts=max_reconnect_attempts,
    )
    watch_thread.start()

    # Wait three seconds, then simulate restart failure
    time.sleep(watch_interval * 3)
    interrupt_ipc_subprocess_pid(server_process.pid)

    # Wait for reconnect attempts to exhaust and on_error callback to be called.
    # A monotonic clock keeps the 30s budget immune to wall-clock adjustments.
    deadline = time.monotonic() + 30
    while not called.get("on_error") and time.monotonic() < deadline:
        time.sleep(1)

    shutdown_event.set()
    watch_thread.join()

    # Use .get() so a timeout reports a readable assertion instead of a KeyError.
    assert called.get("on_disconnect"), "on_disconnect was never called"
    assert called.get("on_error"), "on_error was not called within 30 seconds"
def test_grpc_watch_thread_server_reconnect():
    """Check disconnect and reconnect callbacks when the server restarts with the same ID.

    Both callbacks also assert that they receive the location name
    ``"test_location"``.  ``on_updated`` and ``on_error`` are wired to a
    callback that raises.
    """
    port = find_free_port()
    fixed_server_id = "fixed_id"
    expected_location = "test_location"

    called = {}

    def on_disconnect(location_name):
        assert location_name == expected_location
        called["on_disconnect"] = True

    def on_reconnected(location_name):
        assert location_name == expected_location
        called["on_reconnected"] = True

    def should_not_be_called(*args, **kwargs):
        raise Exception("This method should not be called")

    # Initial server
    server_process = open_server_process(
        port=port, socket=None, fixed_server_id=fixed_server_id
    )

    # Watch thread
    client = DagsterGrpcClient(port=port)
    watch_interval = 1
    shutdown_event, watch_thread = create_grpc_watch_thread(
        "test_location",
        client,
        on_disconnect=on_disconnect,
        on_reconnected=on_reconnected,
        on_updated=should_not_be_called,
        on_error=should_not_be_called,
        watch_interval=watch_interval,
    )
    watch_thread.start()

    # Let the watcher run for three intervals before disturbing the server.
    time.sleep(watch_interval * 3)

    # Take the server down and wait until the disconnect is reported ...
    interrupt_ipc_subprocess_pid(server_process.pid)
    wait_for_condition(lambda: called.get("on_disconnect"), watch_interval)

    # ... then restart it under the same ID and wait for the reconnect.
    server_process = open_server_process(
        port=port, socket=None, fixed_server_id=fixed_server_id
    )
    wait_for_condition(lambda: called.get("on_reconnected"), watch_interval)

    shutdown_event.set()
    watch_thread.join()
def test_run_grpc_watch_without_server():
    """Check that watching a port with no server reports a disconnect and then an error.

    No server process is started by this test; the client simply targets port
    8080 (assumed to have no gRPC server listening).  With a single reconnect
    attempt allowed, the test polls for up to 30 seconds for ``on_error`` and
    then asserts that both ``on_disconnect`` and ``on_error`` were recorded.
    ``on_reconnected`` and ``on_updated`` are wired to a callback that raises.
    """
    # Starting a thread for a server that never existed should immediately error out
    client = DagsterGrpcClient(port=8080)
    watch_interval = 1
    max_reconnect_attempts = 1

    called = {}

    def on_disconnect():
        called["on_disconnect"] = True

    def on_error():
        called["on_error"] = True

    def should_not_be_called():
        raise Exception("This method should not be called")

    shutdown_event, watch_thread = create_grpc_watch_thread(
        client,
        on_disconnect=on_disconnect,
        on_reconnected=should_not_be_called,
        on_updated=should_not_be_called,
        on_error=on_error,
        watch_interval=watch_interval,
        max_reconnect_attempts=max_reconnect_attempts,
    )
    watch_thread.start()
    time.sleep(watch_interval * 3)

    # Wait for reconnect attempts to exhaust and on_error callback to be called.
    # A monotonic clock keeps the 30s budget immune to wall-clock adjustments.
    deadline = time.monotonic() + 30
    while not called.get("on_error") and time.monotonic() < deadline:
        time.sleep(1)

    shutdown_event.set()
    watch_thread.join()

    # Use .get() so a timeout reports a readable assertion instead of a KeyError.
    assert called.get("on_disconnect"), "on_disconnect was never called"
    assert called.get("on_error"), "on_error was not called within 30 seconds"
def test_run_grpc_watch_without_server():
    """Watch a port with no server started by the test and wait for ``on_error``.

    The disconnect and error callbacks each assert they receive the location
    name ``"test_location"``.  ``on_reconnected`` and ``on_updated`` are wired
    to a callback that raises.
    """
    # Starting a thread for a server that never existed should immediately error out
    expected_location = "test_location"
    watch_interval = 1
    max_reconnect_attempts = 1
    client = DagsterGrpcClient(port=8080)

    called = {}

    def on_disconnect(location_name):
        assert location_name == expected_location
        called["on_disconnect"] = True

    def on_error(location_name):
        assert location_name == expected_location
        called["on_error"] = True

    def should_not_be_called(*args, **kwargs):
        raise Exception("This method should not be called")

    shutdown_event, watch_thread = create_grpc_watch_thread(
        "test_location",
        client,
        on_disconnect=on_disconnect,
        on_reconnected=should_not_be_called,
        on_updated=should_not_be_called,
        on_error=on_error,
        watch_interval=watch_interval,
        max_reconnect_attempts=max_reconnect_attempts,
    )
    watch_thread.start()
    time.sleep(watch_interval * 3)

    # Block until the reconnect attempts are exhausted and on_error has been called.
    wait_for_condition(lambda: called.get("on_error"), watch_interval)

    shutdown_event.set()
    watch_thread.join()

    assert called["on_disconnect"]
def test_grpc_watch_thread_server_update():
    """Check that ``on_updated`` is not called for the first server but is after a replacement.

    The callback also asserts that it receives the location name
    ``"test_location"``; its second argument is ignored.
    """
    port = find_free_port()
    watch_interval = 1

    called = {}

    def on_updated(location_name, _):
        assert location_name == "test_location"
        called["yup"] = True

    # Bring up the first server and point a watch thread at it.
    first_server = open_server_process(port=port, socket=None)
    try:
        client = DagsterGrpcClient(port=port)
        shutdown_event, watch_thread = create_grpc_watch_thread(
            "test_location",
            client,
            on_updated=on_updated,
            watch_interval=watch_interval,
        )
        watch_thread.start()
        time.sleep(watch_interval * 3)
    finally:
        interrupt_ipc_subprocess_pid(first_server.pid)

    # No replacement server exists yet, so no update may have been reported.
    assert not called

    # Bring up a replacement server on the same port and wait for the callback.
    second_server = open_server_process(port=port, socket=None)
    try:
        wait_for_condition(lambda: called, interval=watch_interval)
    finally:
        interrupt_ipc_subprocess_pid(second_server.pid)

    shutdown_event.set()
    watch_thread.join()

    assert called
def __init__(
    self,
    origin,
    host=None,
    port=None,
    socket=None,
    server_id=None,
    heartbeat=False,
    watch_server=True,
):
    """Connect to a gRPC server and optionally start heartbeat and watch threads.

    Args:
        origin: Must be a ``RepositoryLocationOrigin``.  When it is a
            ``GrpcServerRepositoryLocationOrigin`` the connection settings
            (port, socket, host, use_ssl) are read from it and the explicit
            ``host``/``port``/``socket`` arguments are ignored.
        host: Host to connect to when ``origin`` is not a gRPC-server origin
            (required string in that case).
        port: Optional port for the non-gRPC-server-origin case.
        socket: Optional socket for the non-gRPC-server-origin case.
        server_id: Server ID to record; when falsy the ID is fetched from the
            server.
        heartbeat: When true, start a daemon thread running
            ``client_heartbeat_thread``.
        watch_server: When true, start a watch thread whose callbacks forward
            ``LocationStateChangeEvent`` objects to the subscribers.

    Raises:
        Any exception raised during setup is re-raised after ``self.cleanup()``
        has been called.
    """
    from dagster.grpc.client import DagsterGrpcClient, client_heartbeat_thread
    from dagster.grpc.server_watcher import create_grpc_watch_thread

    self._origin = check.inst_param(origin, "origin",
                                    RepositoryLocationOrigin)

    if isinstance(self._origin, GrpcServerRepositoryLocationOrigin):
        # Connection settings come from the origin itself.
        self._port = self.origin.port
        self._socket = self.origin.socket
        self._host = self.origin.host
        self._use_ssl = bool(self.origin.use_ssl)
    else:
        # Connection settings come from the explicit arguments; SSL is off.
        self._port = check.opt_int_param(port, "port")
        self._socket = check.opt_str_param(socket, "socket")
        self._host = check.str_param(host, "host")
        self._use_ssl = False

    # Set before the try block so the attributes exist if setup fails early.
    self._watch_thread_shutdown_event = None
    self._watch_thread = None
    self._heartbeat_shutdown_event = None
    self._heartbeat_thread = None

    self._heartbeat = check.bool_param(heartbeat, "heartbeat")
    self._watch_server = check.bool_param(watch_server, "watch_server")

    self.server_id = None
    self._external_repositories_data = None

    def _on_updated(new_server_id):
        self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_UPDATED,
                location_name=self.location_name,
                message="Server has been updated.",
                server_id=new_server_id,
            )
        )

    def _on_error():
        self._send_state_event_to_subscribers(
            LocationStateChangeEvent(
                LocationStateChangeEventType.LOCATION_ERROR,
                location_name=self.location_name,
                message=(
                    "Unable to reconnect to server. You can reload the server once it is "
                    "reachable again"
                ),
            )
        )

    try:
        self.client = DagsterGrpcClient(
            port=self._port,
            socket=self._socket,
            host=self._host,
            use_ssl=self._use_ssl,
        )
        list_repositories_response = sync_list_repositories_grpc(self.client)

        # Prefer the caller-supplied ID; otherwise ask the server.
        self.server_id = server_id or sync_get_server_id(self.client)
        self.repository_names = {
            symbol.repository_name
            for symbol in list_repositories_response.repository_symbols
        }

        if self._heartbeat:
            self._heartbeat_shutdown_event = threading.Event()
            self._heartbeat_thread = threading.Thread(
                target=client_heartbeat_thread,
                args=(self.client, self._heartbeat_shutdown_event),
                name="grpc-client-heartbeat",
            )
            self._heartbeat_thread.daemon = True
            self._heartbeat_thread.start()

        if self._watch_server:
            self._state_subscribers = []
            watch_handles = create_grpc_watch_thread(
                self.client,
                on_updated=_on_updated,
                on_error=_on_error,
            )
            self._watch_thread_shutdown_event, self._watch_thread = watch_handles
            self._watch_thread.start()

        self.executable_path = list_repositories_response.executable_path
        self.repository_code_pointer_dict = (
            list_repositories_response.repository_code_pointer_dict)

        self.container_image = self._reload_current_image()

        self._external_repositories_data = sync_get_streaming_external_repositories_data_grpc(
            self.client,
            self,
        )
    except:
        # Deliberately catches everything: clean up, then propagate unchanged.
        self.cleanup()
        raise
def test_grpc_watch_thread_server_complex_cycle():
    """Restart the server three times under the same ID, then replace it with a new one.

    Asserts that disconnect and reconnect events were recorded and that the
    last recorded event is ``on_updated``.
    """
    port = find_free_port()
    fixed_server_id = "fixed_id"

    events = []

    def on_disconnect():
        events.append("on_disconnect")

    def on_reconnected():
        events.append("on_reconnected")

    def on_updated():
        events.append("on_updated")

    def on_error():
        events.append("on_error")

    # Initial server
    open_server_process(port=port, socket=None, fixed_server_id=fixed_server_id)

    # Watch thread; the interval is faster than we would use in practice.
    client = DagsterGrpcClient(port=port)
    watch_interval = 1
    shutdown_event, watch_thread = create_grpc_watch_thread(
        client,
        on_disconnect=on_disconnect,
        on_reconnected=on_reconnected,
        on_updated=on_updated,
        on_error=on_error,
        watch_interval=watch_interval,
        max_reconnect_attempts=5,
    )
    watch_thread.start()

    pause = watch_interval * 3
    time.sleep(pause)

    # Simulate server restart three times with the same server ID.
    for _ in range(3):
        client.shutdown_server()
        time.sleep(pause)
        open_server_process(port=port, socket=None, fixed_server_id=fixed_server_id)
        time.sleep(pause)

    # Simulate server update: the replacement is started without a fixed ID.
    client.shutdown_server()
    time.sleep(pause)
    open_server_process(port=port, socket=None)
    time.sleep(watch_interval * 5)

    shutdown_event.set()
    watch_thread.join()

    assert "on_disconnect" in events
    assert "on_reconnected" in events
    assert events[-1] == "on_updated"
def test_grpc_watch_thread_server_error():
    """Check ``on_error`` after a lost server, then ``on_updated`` once a new server appears.

    The server is interrupted and the test waits for ``on_error``; at that
    point ``on_disconnect`` and ``on_error`` must be recorded and
    ``on_updated`` must not.  A server with a different fixed ID is then
    started on the same port, and ``on_updated`` must receive that ID.
    Every callback asserts that it receives the location name
    ``"test_location"``; ``on_reconnected`` is wired to a callback that raises.
    """
    port = find_free_port()
    fixed_server_id = "fixed_id"
    expected_location = "test_location"

    called = {}

    def on_disconnect(location_name):
        assert location_name == expected_location
        called["on_disconnect"] = True

    def on_error(location_name):
        assert location_name == expected_location
        called["on_error"] = True

    def on_updated(location_name, new_server_id):
        assert location_name == expected_location
        called["on_updated"] = new_server_id

    def should_not_be_called(*args, **kwargs):
        raise Exception("This method should not be called")

    # Initial server
    server_process = open_server_process(
        port=port, socket=None, fixed_server_id=fixed_server_id
    )

    # Watch thread
    client = DagsterGrpcClient(port=port)
    watch_interval = 1
    max_reconnect_attempts = 3
    shutdown_event, watch_thread = create_grpc_watch_thread(
        "test_location",
        client,
        on_disconnect=on_disconnect,
        on_reconnected=should_not_be_called,
        on_updated=on_updated,
        on_error=on_error,
        watch_interval=watch_interval,
        max_reconnect_attempts=max_reconnect_attempts,
    )
    watch_thread.start()
    time.sleep(watch_interval * 3)

    # Simulate restart failure: take the server down, then wait for the
    # reconnect attempts to exhaust and the on_error callback to be called.
    interrupt_ipc_subprocess_pid(server_process.pid)
    wait_for_condition(lambda: called.get("on_error"), watch_interval)

    assert called["on_disconnect"]
    assert called["on_error"]
    assert not called.get("on_updated")

    # Start a server with a different ID on the same port and wait for the update.
    new_server_id = "new_server_id"
    server_process = open_server_process(
        port=port, socket=None, fixed_server_id=new_server_id
    )
    wait_for_condition(lambda: called.get("on_updated"), watch_interval)

    shutdown_event.set()
    watch_thread.join()

    assert called["on_updated"] == new_server_id