def test_process_server_registry():
    """Exercise the server lifecycle managed by ProcessGrpcServerRegistry.

    Checks that repeated lookups for one origin return the same endpoint, that the
    registry hands out an endpoint with a new server id after its reload interval,
    that the first server eventually stops accepting connections, and that no server
    is reachable once the registry context has exited and ``wait_for_processes`` has
    returned.
    """
    origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
        loadable_target_origin=LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="repo",
            python_file=file_relative_path(__file__, "test_grpc_server_registry.py"),
        ),
    )

    with ProcessGrpcServerRegistry(
        reload_interval=5, heartbeat_ttl=10, startup_timeout=5
    ) as registry:
        endpoint_one = registry.get_grpc_endpoint(origin)
        endpoint_two = registry.get_grpc_endpoint(origin)

        # Back-to-back lookups for the same origin return the same endpoint.
        assert endpoint_two == endpoint_one

        assert _can_connect(origin, endpoint_one)
        assert _can_connect(origin, endpoint_two)

        start_time = time.time()
        while True:
            # Registry should return a new server endpoint after 5 seconds
            endpoint_three = registry.get_grpc_endpoint(origin)
            if endpoint_three.server_id != endpoint_one.server_id:
                break

            if time.time() - start_time > 15:
                raise Exception("Server ID never changed")

            time.sleep(1)

        assert _can_connect(origin, endpoint_three)

        start_time = time.time()
        while True:
            # Server at endpoint_one should eventually die due to heartbeat failure
            if not _can_connect(origin, endpoint_one):
                break

            if time.time() - start_time > 30:
                raise Exception("Old Server never died after process manager released it")

            time.sleep(1)

        # Make one more fresh process, then leave the context so that it will be cleaned up.
        # Bounded and throttled like the loops above so that a registry which never rotates
        # the server fails the test instead of spinning forever.
        start_time = time.time()
        while True:
            endpoint_four = registry.get_grpc_endpoint(origin)
            if endpoint_four.server_id != endpoint_three.server_id:
                assert _can_connect(origin, endpoint_four)
                break

            if time.time() - start_time > 15:
                raise Exception("Server ID never changed after endpoint_three")

            time.sleep(1)

    registry.wait_for_processes()
    assert not _can_connect(origin, endpoint_three)
    assert not _can_connect(origin, endpoint_four)
def test_registry_multithreading():
    """Hammer one registry from several threads at once.

    Five threads each run ``_registry_thread`` against the same registry and the same
    initially-fetched endpoint; every thread must set its event. The endpoint must still
    be reachable while the registry is open, and unreachable after the registry has
    exited and ``wait_for_processes`` has returned.
    """
    origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
        loadable_target_origin=LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="repo",
            python_file=file_relative_path(__file__, "test_grpc_server_registry.py"),
        ),
    )

    with ProcessGrpcServerRegistry(reload_interval=300, heartbeat_ttl=600) as registry:
        shared_endpoint = registry.get_grpc_endpoint(origin)

        # (thread, event) pairs so each worker stays next to its success flag.
        workers = []
        for _ in range(5):
            done = threading.Event()
            worker = threading.Thread(
                target=_registry_thread,
                args=(origin, registry, shared_endpoint, done),
            )
            workers.append((worker, done))
            worker.start()

        for worker, _done in workers:
            worker.join()

        for _worker, done in workers:
            assert done.is_set()

        assert _can_connect(origin, shared_endpoint)

    registry.wait_for_processes()
    assert not _can_connect(origin, shared_endpoint)
def test_error_repo_in_registry():
    """A repository that fails to load yields an endpoint, but loading a location from it fails.

    ``get_grpc_endpoint`` itself does not raise for the broken repository. Opening a
    ``GrpcServerRepositoryLocation`` on that endpoint raises
    ``DagsterUserCodeProcessError``, and it raises the same way on a second attempt.
    """
    error_origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
        loadable_target_origin=LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="error_repo",
            python_file=file_relative_path(__file__, "error_repo.py"),
        ),
    )

    with ProcessGrpcServerRegistry(reload_interval=5, heartbeat_ttl=10) as registry:
        # Repository with a loading error does not raise an exception here...
        endpoint = registry.get_grpc_endpoint(error_origin)

        def _open_location():
            return GrpcServerRepositoryLocation(
                origin=error_origin,
                server_id=endpoint.server_id,
                port=endpoint.port,
                socket=endpoint.socket,
                host=endpoint.host,
                watch_server=False,
            )

        # ...but loading a location from that endpoint does, and repeatedly so.
        for _attempt in range(2):
            with pytest.raises(DagsterUserCodeProcessError, match="object is not callable"):
                with _open_location():
                    pass
def _get_grpc_endpoint(
    self, repository_location_origin: ManagedGrpcPythonEnvRepositoryLocationOrigin
) -> GrpcServerEndpoint:
    """Return the gRPC endpoint serving ``repository_location_origin``, launching one if needed.

    A new ``GrpcServerProcess`` is started when there is no active entry for the origin's
    id, or when the origin's loadable target differs from the one recorded in the active
    entry. If starting the process raises, the error is captured as a
    ``SerializableErrorInfo`` and stored in the entry in place of the process.

    Raises:
        Exception: if no loadable target origin is available for the location.
        DagsterUserCodeProcessError: if the active entry holds a captured startup error.
    """
    origin_id = repository_location_origin.get_id()
    loadable_target_origin = self._get_loadable_target_origin(repository_location_origin)
    if not loadable_target_origin:
        raise Exception(
            f"No Python file/module information available for location {repository_location_origin.location_name}"
        )

    if origin_id not in self._active_entries:
        refresh_server = True
    else:
        active_entry = self._active_entries[origin_id]
        # The target (file/module/attribute/...) changed: the running server is stale.
        refresh_server = loadable_target_origin != active_entry.loadable_target_origin

    server_process: Union[GrpcServerProcess, SerializableErrorInfo]
    new_server_id: Optional[str]
    if refresh_server:
        new_server_id = str(uuid.uuid4())
        try:
            server_process = GrpcServerProcess(
                loadable_target_origin=loadable_target_origin,
                heartbeat=True,
                heartbeat_timeout=self._heartbeat_ttl,
                fixed_server_id=new_server_id,
                startup_timeout=self._startup_timeout,
            )
        except Exception:
            # Record the failure instead of raising here; it is re-raised below as a
            # DagsterUserCodeProcessError for as long as this entry stays active.
            server_process = serializable_error_info_from_exc_info(sys.exc_info())
            new_server_id = None
        else:
            # Every successfully launched process is tracked, presumably so it can be
            # cleaned up even after its entry is replaced — confirm against callers.
            self._all_processes.append(server_process)

        self._active_entries[origin_id] = ProcessRegistryEntry(
            process_or_error=server_process,
            loadable_target_origin=loadable_target_origin,
            creation_timestamp=pendulum.now("UTC").timestamp(),
            server_id=new_server_id,
        )

    active_entry = self._active_entries[origin_id]

    if isinstance(active_entry.process_or_error, SerializableErrorInfo):
        raise DagsterUserCodeProcessError(
            active_entry.process_or_error.to_string(),
            user_code_process_error_infos=[active_entry.process_or_error],
        )

    return GrpcServerEndpoint(
        server_id=active_entry.server_id,
        host="localhost",
        port=active_entry.process_or_error.port,
        socket=active_entry.process_or_error.socket,
    )
def reload_grpc_endpoint(
    self, repository_location_origin: ManagedGrpcPythonEnvRepositoryLocationOrigin
) -> GrpcServerEndpoint:
    """Drop any cached server entry for the origin and return a newly resolved endpoint.

    Runs under ``self._lock``. Removing the entry makes ``_get_grpc_endpoint`` treat the
    origin as unseen, so it launches a new server process for it.
    """
    check.inst_param(
        repository_location_origin, "repository_location_origin", RepositoryLocationOrigin
    )

    with self._lock:
        # Free the map entry (if any) so _get_grpc_endpoint is forced to create a new process.
        self._active_entries.pop(repository_location_origin.get_id(), None)
        return self._get_grpc_endpoint(repository_location_origin)
def test_error_repo_in_registry():
    """Looking up an endpoint for a repository that fails to load raises, every time.

    The registry surfaces the load failure as ``DagsterUserCodeProcessError`` whose
    message contains the underlying reason, both on the first lookup and on a repeat.
    """
    error_origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
        loadable_target_origin=LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="error_repo",
            python_file=file_relative_path(__file__, "error_repo.py"),
        ),
    )

    with ProcessGrpcServerRegistry(reload_interval=5, heartbeat_ttl=10) as registry:
        # First lookup reports the reason; the second shows the failure is idempotent.
        for _attempt in range(2):
            with pytest.raises(DagsterUserCodeProcessError, match="object is not callable"):
                registry.get_grpc_endpoint(error_origin)
def create_from_repository_origin(repository_origin, instance):
    """Build a RepositoryLocationHandle for the location that hosts ``repository_origin``.

    A ``RepositoryGrpcServerOrigin`` maps to a ``GrpcServerRepositoryLocationOrigin`` with
    the same port/socket/host. A ``RepositoryPythonOrigin`` maps to a
    ``ManagedGrpcPythonEnvRepositoryLocationOrigin`` over its loadable target origin.
    ``instance`` is type-checked but not otherwise used here.

    Raises:
        DagsterInvariantViolationError: for any other repository origin type.
    """
    check.inst_param(repository_origin, "repository_origin", RepositoryOrigin)
    check.inst_param(instance, "instance", DagsterInstance)

    if isinstance(repository_origin, RepositoryGrpcServerOrigin):
        location_origin = GrpcServerRepositoryLocationOrigin(
            port=repository_origin.port,
            socket=repository_origin.socket,
            host=repository_origin.host,
        )
    elif isinstance(repository_origin, RepositoryPythonOrigin):
        location_origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
            repository_origin.loadable_target_origin
        )
    else:
        raise DagsterInvariantViolationError("Unexpected repository origin type")

    return RepositoryLocationHandle.create_from_repository_location_origin(location_origin)
def _create_python_env_location_origin(
    loadable_target_origin: LoadableTargetOrigin, location_name: Optional[str]
) -> ManagedGrpcPythonEnvRepositoryLocationOrigin:
    """Wrap ``loadable_target_origin`` in a managed gRPC Python-environment location origin.

    ``location_name`` is forwarded unchanged and may be ``None``.
    """
    location_origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
        loadable_target_origin=loadable_target_origin,
        location_name=location_name,
    )
    return location_origin
def test_process_server_registry():
    """Exercise ProcessGrpcServerRegistry together with a RepositoryLocationHandleManager.

    Checks that repeated lookups share one endpoint and one handle, and that after the
    registry's cleanup interval both the endpoint and the handle change. It also checks
    that once the handle manager is closed the first server eventually stops accepting
    connections. Finally, it checks that leaving the registry context (with
    ``wait_for_processes_on_exit=True``) leaves no server reachable.
    """
    origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
        loadable_target_origin=LoadableTargetOrigin(
            executable_path=sys.executable,
            attribute="repo",
            python_file=file_relative_path(__file__, "test_grpc_server_registry.py"),
        ),
    )

    with ProcessGrpcServerRegistry(
        wait_for_processes_on_exit=True, cleanup_interval=5, heartbeat_interval=10
    ) as registry:
        with RepositoryLocationHandleManager(registry) as handle_manager:
            endpoint_one = registry.get_grpc_endpoint(origin)
            handle_one = handle_manager.get_handle(origin)

            endpoint_two = registry.get_grpc_endpoint(origin)
            handle_two = handle_manager.get_handle(origin)

            assert endpoint_two == endpoint_one
            assert handle_two == handle_one

            assert _can_connect(origin, endpoint_one)
            assert _can_connect(origin, endpoint_two)

            start_time = time.time()
            while True:
                # Registry should return a new server endpoint after 5 seconds
                endpoint_three = registry.get_grpc_endpoint(origin)

                if endpoint_three.server_id != endpoint_one.server_id:
                    # Handle manager now produces a new handle as well
                    handle_three = handle_manager.get_handle(origin)
                    assert handle_three != handle_one
                    break

                if time.time() - start_time > 15:
                    raise Exception("Server ID never changed")

                time.sleep(1)

            assert _can_connect(origin, endpoint_three)
            # Leave handle_manager context, all heartbeats stop

        start_time = time.time()
        while True:
            # Server at endpoint_one should eventually die due to heartbeat failure
            if not _can_connect(origin, endpoint_one):
                break

            if time.time() - start_time > 30:
                raise Exception("Old Server never died after process manager released it")

            time.sleep(1)

        # Make one more fresh process, then leave the context so that it will be cleaned up.
        # Bounded and throttled like the loops above so that a registry which never rotates
        # the server fails the test instead of spinning forever.
        start_time = time.time()
        while True:
            endpoint_four = registry.get_grpc_endpoint(origin)
            if endpoint_four.server_id != endpoint_three.server_id:
                assert _can_connect(origin, endpoint_four)
                break

            if time.time() - start_time > 15:
                raise Exception("Server ID never changed after endpoint_three")

            time.sleep(1)

    # Once we leave the ProcessGrpcServerRegistry context, all processes should be cleaned up
    # (if wait_for_processes_on_exit was set)
    assert not _can_connect(origin, endpoint_three)
    assert not _can_connect(origin, endpoint_four)
def _create_python_env_location_origin(loadable_target_origin, location_name):
    """Wrap ``loadable_target_origin`` in a managed gRPC Python-environment location origin.

    ``location_name`` is forwarded to the origin unchanged.
    """
    location_origin = ManagedGrpcPythonEnvRepositoryLocationOrigin(
        loadable_target_origin=loadable_target_origin,
        location_name=location_name,
    )
    return location_origin