Example #1
0
class DummyEventHandler:
    def __init__(self, batching=False):

        self.batcher = Batcher(timeout=1) if batching else None
        self.dispatcher = Dispatcher(
            ensemble=None,
            evaluator_callback=lambda x: None,
            batcher=self.batcher,
        )
        self.dispatcher._LOOKUP_MAP.clear()
        self.mock_all = Mock()
        self.mock_step = Mock()
        self.mock_none = Mock()

        self.dispatcher.register_event_handler(
            ids.EVGROUP_FM_ALL, self.all, batching=batching
        )
        self.dispatcher.register_event_handler(
            ids.EVGROUP_FM_STEP, self.step, batching=batching
        )

    async def join(self):
        await self.batcher.join()

    async def all(self, event):
        self.mock_all(event)

    async def step(self, event):
        self.mock_step(event)
Example #2
0
    def __init__(self, ensemble, config, iter_, ee_id: str = "0"):
        # Without information on the iteration, the events emitted from the
        # evaluator are ambiguous. In the future, an experiment authority* will
        # "own" the evaluators and can add iteration information to events they
        # emit. In the mean time, it is added here.
        # * https://github.com/equinor/ert/issues/1250
        self._iter = iter_
        self._ee_id = ee_id
        self._config = config
        self._ensemble = ensemble

        self._loop = asyncio.new_event_loop()
        self._done = self._loop.create_future()

        self._clients: Set[WebSocketServerProtocol] = set()
        self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue(
            loop=self._loop)
        self._batcher = Batcher(timeout=2, loop=self._loop)
        self._dispatcher = Dispatcher(
            ensemble=self._ensemble,
            evaluator_callback=self.dispatcher_callback,
            batcher=self._batcher,
        )

        self._result = None

        self._ws_thread = threading.Thread(name="ert_ee_run_server",
                                           target=self._run_server,
                                           args=(self._loop, ))
Example #3
0
    def __init__(self, batching=False):

        self.batcher = Batcher(timeout=1) if batching else None
        self.dispatcher = Dispatcher(
            ensemble=None,
            evaluator_callback=lambda x: None,
            batcher=self.batcher,
        )
        self.dispatcher._LOOKUP_MAP.clear()
        self.mock_all = Mock()
        self.mock_step = Mock()
        self.mock_none = Mock()

        self.dispatcher.register_event_handler(
            ids.EVGROUP_FM_ALL, self.all, batching=batching
        )
        self.dispatcher.register_event_handler(
            ids.EVGROUP_FM_STEP, self.step, batching=batching
        )
Example #4
0
class DummyEventHandler:
    dispatch = Dispatcher()

    def __init__(self):
        self.mock_all = Mock()
        self.mock_step = Mock()
        self.mock_none = Mock()

    @dispatch.register_event_handler(ids.EVGROUP_FM_ALL)
    async def all(self, event):
        self.mock_all(event)

    @dispatch.register_event_handler(ids.EVGROUP_FM_STEP)
    async def step(self, event):
        self.mock_step(event)
Example #5
0
class EnsembleEvaluator:
    _dispatch = Dispatcher()

    def __init__(self, ensemble, config, ee_id=0):
        self._ee_id = ee_id

        self._config = config
        self._ensemble = ensemble

        self._loop = asyncio.new_event_loop()
        self._ws_thread = threading.Thread(name="ert_ee_run_server",
                                           target=self._run_server,
                                           args=(self._loop, ))
        self._done = self._loop.create_future()

        self._clients = set()
        self._dispatchers_connected = asyncio.Queue(loop=self._loop)

        self._snapshot = self.create_snapshot(ensemble)
        self._event_index = 1

    @staticmethod
    def create_snapshot(ensemble):
        reals = {}
        for real in ensemble.get_active_reals():
            reals[str(real.get_iens())] = _Realization(
                active=True,
                start_time=None,
                end_time=None,
                status="Waiting",
            )
            for stage in real.get_stages():
                reals[str(real.get_iens())].stages[str(
                    stage.get_id())] = _Stage(
                        status="Unknown",
                        start_time=None,
                        end_time=None,
                    )
                for step in stage.get_steps():
                    reals[str(real.get_iens())].stages[str(
                        stage.get_id())].steps[str(step.get_id())] = _Step(
                            status="Unknown", start_time=None, end_time=None)
                    for job in step.get_jobs():
                        reals[str(real.get_iens())].stages[str(
                            stage.get_id())].steps[str(
                                step.get_id())].jobs[str(job.get_id())] = _Job(
                                    status="Pending",
                                    data={},
                                    start_time=None,
                                    end_time=None,
                                    name=job.get_name(),
                                )
        top = _SnapshotDict(
            reals=reals,
            status="Unknown",
            forward_model=_ForwardModel(step_definitions={}),
            metadata=ensemble.get_metadata(),
        )

        return Snapshot(top.dict())

    @_dispatch.register_event_handler(identifiers.EVGROUP_FM_ALL)
    async def _fm_handler(self, event):
        snapshot_mutate_event = PartialSnapshot(
            self._snapshot).from_cloudevent(event)
        await self._send_snapshot_update(snapshot_mutate_event)

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_STOPPED)
    async def _ensemble_stopped_handler(self, event):
        if self._snapshot.get_status() != "Failure":
            snapshot_mutate_event = PartialSnapshot(
                self._snapshot).from_cloudevent(event)
            await self._send_snapshot_update(snapshot_mutate_event)

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_STARTED)
    async def _ensemble_started_handler(self, event):
        if self._snapshot.get_status() != "Failure":
            snapshot_mutate_event = PartialSnapshot(
                self._snapshot).from_cloudevent(event)
            await self._send_snapshot_update(snapshot_mutate_event)

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_CANCELLED)
    async def _ensemble_cancelled_handler(self, event):
        if self._snapshot.get_status() != "Failure":
            snapshot_mutate_event = PartialSnapshot(
                self._snapshot).from_cloudevent(event)
            await self._send_snapshot_update(snapshot_mutate_event)
            self._stop()

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_FAILED)
    async def _ensemble_failed_handler(self, event):
        if self._snapshot.get_status() not in ["Stopped", "Cancelled"]:
            snapshot_mutate_event = PartialSnapshot(
                self._snapshot).from_cloudevent(event)
            await self._send_snapshot_update(snapshot_mutate_event)

    async def _send_snapshot_update(self, snapshot_mutate_event):
        self._snapshot.merge_event(snapshot_mutate_event)
        out_cloudevent = CloudEvent(
            {
                "type": identifiers.EVTYPE_EE_SNAPSHOT_UPDATE,
                "source": f"/ert/ee/{self._ee_id}",
                "id": self.event_index(),
            },
            snapshot_mutate_event.to_dict(),
        )
        out_msg = to_json(out_cloudevent).decode()
        if out_msg and self._clients:
            await asyncio.wait(
                [client.send(out_msg) for client in self._clients])

    @staticmethod
    def create_snapshot_msg(ee_id, snapshot, event_index):
        data = snapshot.to_dict()
        out_cloudevent = CloudEvent(
            {
                "type": identifiers.EVTYPE_EE_SNAPSHOT,
                "source": f"/ert/ee/{ee_id}",
                "id": event_index,
            },
            data,
        )
        return to_json(out_cloudevent).decode()

    @contextmanager
    def store_client(self, websocket):
        self._clients.add(websocket)
        yield
        self._clients.remove(websocket)

    async def handle_client(self, websocket, path):
        with self.store_client(websocket):
            message = self.create_snapshot_msg(self._ee_id, self._snapshot,
                                               self.event_index())
            await websocket.send(message)

            async for message in websocket:
                client_event = from_json(message)
                logger.debug(f"got message from client: {client_event}")
                if client_event["type"] == identifiers.EVTYPE_EE_USER_CANCEL:
                    logger.debug(
                        f"Client {websocket.remote_address} asked to cancel.")
                    if self._ensemble.is_cancellable():
                        # The evaluator will stop after the ensemble has
                        # indicated it has been cancelled.
                        self._ensemble.cancel()
                    else:
                        self._stop()

                if client_event["type"] == identifiers.EVTYPE_EE_USER_DONE:
                    logger.debug(
                        f"Client {websocket.remote_address} signalled done.")
                    self._stop()

    @asynccontextmanager
    async def count_dispatcher(self):
        await self._dispatchers_connected.put(None)
        yield
        await self._dispatchers_connected.get()
        self._dispatchers_connected.task_done()

    async def handle_dispatch(self, websocket, path):
        async with self.count_dispatcher():
            async for msg in websocket:
                event = from_json(msg)
                await self._dispatch.handle_event(self, event)
                if event["type"] in [
                        identifiers.EVTYPE_ENSEMBLE_STOPPED,
                        identifiers.EVTYPE_ENSEMBLE_FAILED,
                ]:
                    return

    async def connection_handler(self, websocket, path):
        elements = path.split("/")
        if elements[1] == "client":
            await self.handle_client(websocket, path)
        elif elements[1] == "dispatch":
            await self.handle_dispatch(websocket, path)

    async def evaluator_server(self, done):
        async with websockets.serve(
                self.connection_handler,
                sock=self._config.get_socket(),
                max_queue=500,
                max_size=2**26,
        ):
            await done
            logger.debug("Got done signal.")
            # give NFS adaptors and Queue adaptors some time to read/send last events
            try:
                await asyncio.wait_for(self._dispatchers_connected.join(),
                                       timeout=10)
            except asyncio.TimeoutError:
                pass
            message = self.terminate_message()
            if self._clients:
                await asyncio.wait(
                    [client.send(message) for client in self._clients])
            logger.debug("Sent terminated to clients.")

        logger.debug("Async server exiting.")

    def _run_server(self, loop):
        loop.run_until_complete(self.evaluator_server(self._done))
        logger.debug("Server thread exiting.")

    def terminate_message(self):
        out_cloudevent = CloudEvent({
            "type": identifiers.EVTYPE_EE_TERMINATED,
            "source": f"/ert/ee/{self._ee_id}",
            "id": self.event_index(),
        })
        message = to_json(out_cloudevent).decode()
        return message

    def event_index(self):
        index = self._event_index
        self._event_index += 1
        return index

    def run(self):
        self._ws_thread.start()
        self._ensemble.evaluate(self._config, self._ee_id)
        return ee_monitor.create(self._config.host, self._config.port)

    def _stop(self):
        if not self._done.done():
            self._done.set_result(None)

    def stop(self):
        self._loop.call_soon_threadsafe(self._stop)
        self._ws_thread.join()

    def get_successful_realizations(self):
        return self._snapshot.get_successful_realizations()

    def run_and_get_successful_realizations(self):
        mon = self.run()
        for _ in mon.track():
            pass
        return self.get_successful_realizations()
Example #6
0
class EnsembleEvaluator:
    _dispatch = Dispatcher()

    def __init__(self, ensemble, config, iter_, ee_id: str = "0"):
        # Without information on the iteration, the events emitted from the
        # evaluator are ambiguous. In the future, an experiment authority* will
        # "own" the evaluators and can add iteration information to events they
        # emit. In the mean time, it is added here.
        # * https://github.com/equinor/ert/issues/1250
        self._iter = iter_
        self._ee_id = ee_id

        self._config = config
        self._ensemble = ensemble

        self._loop = asyncio.new_event_loop()
        self._ws_thread = threading.Thread(
            name="ert_ee_run_server", target=self._run_server, args=(self._loop,)
        )
        self._done = self._loop.create_future()

        self._clients: Set[WebSocketServerProtocol] = set()
        self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue(
            loop=self._loop
        )

        self._snapshot = self.create_snapshot(ensemble)
        self._event_index = 1
        self._result = None

    @staticmethod
    def create_snapshot(ensemble):
        reals = {}
        for real in ensemble.get_active_reals():
            reals[str(real.get_iens())] = Realization(
                active=True,
                status=REALIZATION_STATE_WAITING,
            )
            for step in real.get_steps():
                reals[str(real.get_iens())].steps[str(step.get_id())] = Step(
                    status=STEP_STATE_UNKNOWN
                )
                for job in step.get_jobs():
                    reals[str(real.get_iens())].steps[str(step.get_id())].jobs[
                        str(job.get_id())
                    ] = Job(
                        status=JOB_STATE_START,
                        data={},
                        name=job.get_name(),
                    )
        top = SnapshotDict(
            reals=reals,
            status=ENSEMBLE_STATE_STARTED,
            metadata=ensemble.get_metadata(),
        )

        return Snapshot(top.dict())

    @_dispatch.register_event_handler(identifiers.EVGROUP_FM_ALL)
    async def _fm_handler(self, event):
        snapshot_mutate_event = PartialSnapshot(self._snapshot).from_cloudevent(event)
        await self._send_snapshot_update(snapshot_mutate_event)

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_STOPPED)
    async def _ensemble_stopped_handler(self, event):
        self._result = event.data
        if self._snapshot.get_status() != ENSEMBLE_STATE_FAILED:
            snapshot_mutate_event = PartialSnapshot(self._snapshot).from_cloudevent(
                event
            )
            await self._send_snapshot_update(snapshot_mutate_event)

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_STARTED)
    async def _ensemble_started_handler(self, event):
        if self._snapshot.get_status() != ENSEMBLE_STATE_FAILED:
            snapshot_mutate_event = PartialSnapshot(self._snapshot).from_cloudevent(
                event
            )
            await self._send_snapshot_update(snapshot_mutate_event)

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_CANCELLED)
    async def _ensemble_cancelled_handler(self, event):
        if self._snapshot.get_status() != ENSEMBLE_STATE_FAILED:
            snapshot_mutate_event = PartialSnapshot(self._snapshot).from_cloudevent(
                event
            )
            await self._send_snapshot_update(snapshot_mutate_event)
            self._stop()

    @_dispatch.register_event_handler(identifiers.EVTYPE_ENSEMBLE_FAILED)
    async def _ensemble_failed_handler(self, event):
        if self._snapshot.get_status() not in [
            ENSEMBLE_STATE_STOPPED,
            ENSEMBLE_STATE_CANCELLED,
        ]:
            snapshot_mutate_event = PartialSnapshot(self._snapshot).from_cloudevent(
                event
            )
            await self._send_snapshot_update(snapshot_mutate_event)

    async def _send_snapshot_update(self, snapshot_mutate_event):
        self._snapshot.merge_event(snapshot_mutate_event)
        out_cloudevent = CloudEvent(
            {
                "type": identifiers.EVTYPE_EE_SNAPSHOT_UPDATE,
                "source": f"/ert/ee/{self._ee_id}",
                "id": self.event_index(),
            },
            snapshot_mutate_event.to_dict(),
        )
        out_cloudevent.data["iter"] = self._iter
        out_msg = to_json(
            out_cloudevent, data_marshaller=serialization.evaluator_marshaller
        ).decode()
        if out_msg and self._clients:
            await asyncio.wait([client.send(out_msg) for client in self._clients])

    @staticmethod
    def create_snapshot_msg(ee_id, iter_, snapshot, event_index):
        data = snapshot.to_dict()
        data["iter"] = iter_
        out_cloudevent = CloudEvent(
            {
                "type": identifiers.EVTYPE_EE_SNAPSHOT,
                "source": f"/ert/ee/{ee_id}",
                "id": event_index,
            },
            data,
        )
        return to_json(
            out_cloudevent, data_marshaller=serialization.evaluator_marshaller
        ).decode()

    @contextmanager
    def store_client(self, websocket):
        self._clients.add(websocket)
        yield
        try:
            self._clients.remove(websocket)
        except KeyError:
            logger.debug(
                f"Tried removing client {websocket.remote_address} twice. Likely the client was removed after sending a signal."
            )

    async def handle_client(self, websocket, path):
        with self.store_client(websocket):
            message = self.create_snapshot_msg(
                self._ee_id, self._iter, self._snapshot, self.event_index()
            )
            await websocket.send(message)

            async for message in websocket:
                client_event = from_json(
                    message, data_unmarshaller=serialization.evaluator_unmarshaller
                )
                logger.debug(f"got message from client: {client_event}")
                if client_event["type"] == identifiers.EVTYPE_EE_USER_CANCEL:
                    logger.debug(f"Client {websocket.remote_address} asked to cancel.")
                    if self._ensemble.is_cancellable():
                        # The evaluator will stop after the ensemble has
                        # indicated it has been cancelled.
                        self._ensemble.cancel()
                    else:
                        self._stop()

                if client_event["type"] == identifiers.EVTYPE_EE_USER_DONE:
                    logger.debug(f"Client {websocket.remote_address} signalled done.")
                    self._stop()

                # NOTE: due to how the monitor is implemented, a monitor that
                # signals will open a connection for each signal and
                # immediately exit after signalling. Consequently, it should be
                # harmless to remove the client from the pool.
                # If https://github.com/equinor/ert/issues/1538 is solved, then
                # this necessarily needs to change.
                self._clients.remove(websocket)

    @asynccontextmanager
    async def count_dispatcher(self):
        await self._dispatchers_connected.put(None)
        yield
        await self._dispatchers_connected.get()
        self._dispatchers_connected.task_done()

    async def handle_dispatch(self, websocket, path):
        async with self.count_dispatcher():
            async for msg in websocket:
                try:
                    event = from_json(
                        msg, data_unmarshaller=serialization.evaluator_unmarshaller
                    )
                except cloudevents.exceptions.DataUnmarshallerError:
                    event = from_json(msg, data_unmarshaller=lambda x: pickle.loads(x))
                if self._get_ee_id(event["source"]) != self._ee_id:
                    logger.info(
                        f"Got event from evaluator {self._get_ee_id(event['source'])} with source {event['source']}, ignoring since I am {self._ee_id}"
                    )
                    continue
                await self._dispatch.handle_event(self, event)
                if event["type"] in [
                    identifiers.EVTYPE_ENSEMBLE_STOPPED,
                    identifiers.EVTYPE_ENSEMBLE_FAILED,
                ]:
                    return

    async def handle_result(self, websocket, path):
        if self._result is None:
            event = CloudEvent(
                {
                    "type": identifiers.EVTYPE_EE_RESULT_NOT_READY,
                    "source": f"/ert/ee/{self._ee_id}",
                },
            )
        else:
            event = CloudEvent(
                {
                    "type": identifiers.EVTYPE_EE_RESULT,
                    "source": f"/ert/ee/{self._ee_id}",
                    "datacontenttype": "application/octet-stream",
                },
                cloudpickle.dumps(self._result),
            )
        await websocket.send(to_json(event))

    async def connection_handler(self, websocket, path):
        elements = path.split("/")
        if elements[1] == "client":
            await self.handle_client(websocket, path)
        elif elements[1] == "dispatch":
            await self.handle_dispatch(websocket, path)
        elif elements[1] == "result":
            await self.handle_result(websocket, path)
        else:
            logger.info(f"Connection attempt to unknown path: {path}.")

    async def process_request(self, path, request_headers):
        if request_headers.get("token") != self._config.token:
            return HTTPStatus.UNAUTHORIZED, {}, b""
        if path == "/healthcheck":
            return HTTPStatus.OK, {}, b""

    async def evaluator_server(self, done):
        async with websockets.serve(
            self.connection_handler,
            sock=self._config.get_socket(),
            ssl=self._config.get_server_ssl_context(),
            process_request=self.process_request,
            max_queue=None,
            max_size=2 ** 26,
        ):
            await done
            logger.debug("Got done signal.")
            # give NFS adaptors and Queue adaptors some time to read/send last events
            try:
                await asyncio.wait_for(self._dispatchers_connected.join(), timeout=10)
            except asyncio.TimeoutError:
                pass
            message = self.terminate_message()
            if self._clients:
                await asyncio.wait([client.send(message) for client in self._clients])
            logger.debug("Sent terminated to clients.")

        logger.debug("Async server exiting.")

    def _run_server(self, loop):
        loop.run_until_complete(self.evaluator_server(self._done))
        logger.debug("Server thread exiting.")

    def terminate_message(self):
        out_cloudevent = CloudEvent(
            {
                "type": identifiers.EVTYPE_EE_TERMINATED,
                "source": f"/ert/ee/{self._ee_id}",
                "id": str(self.event_index()),
            }
        )
        message = to_json(
            out_cloudevent, data_marshaller=serialization.evaluator_marshaller
        ).decode()
        return message

    def event_index(self):
        index = self._event_index
        self._event_index += 1
        return index

    def run(self) -> ee_monitor._Monitor:
        self._ws_thread.start()
        self._ensemble.evaluate(self._config, self._ee_id)
        return ee_monitor.create(
            self._config.host,
            self._config.port,
            self._config.protocol,
            self._config.cert,
            self._config.token,
        )

    def _stop(self):
        if not self._done.done():
            self._done.set_result(None)

    def stop(self):
        self._loop.call_soon_threadsafe(self._stop)
        self._ws_thread.join()

    def get_successful_realizations(self):
        return self._snapshot.get_successful_realizations()

    def run_and_get_successful_realizations(self):
        with self.run() as mon:
            for _ in mon.track():
                pass
        return self.get_successful_realizations()

    @staticmethod
    def _get_ee_id(source) -> str:
        # the ee_id will be found at /ert/ee/ee_id/...
        return source.split("/")[3]