def test_predict_antinex_simple_error(self):
    """Predicting with no pre-trained model available:
    the message is still ACK'd but the processor must not
    store any model.
    """
    # exchange, routing key and queue all share one name here
    queue_name = "webapp.predict.requests"
    prc = AntiNexProcessor(
        max_models=1)

    body = self.build_predict_antinex_request()
    self.assertEqual(
        body["ml_type"],
        "classification")

    message = MockMessage(
        exchange=queue_name,
        routing_key=queue_name,
        queue=queue_name)

    # verify the mock starts clean and reports its wiring
    self.assertEqual(
        message.state,
        "NOTRUN")
    self.assertEqual(
        message.get_exchange(),
        queue_name)
    self.assertEqual(
        message.get_routing_key(),
        queue_name)
    self.assertEqual(
        message.get_queue(),
        queue_name)
    self.assertEqual(
        len(prc.models),
        0)

    prc.handle_messages(
        body=body,
        message=message)

    # message is acknowledged, but no model was created
    self.assertEqual(
        message.state,
        "ACK")
    self.assertEqual(
        len(prc.models),
        0)
def test_train_antinex_simple_model_cleanup(self):
    """Train one model, then train a second one under a new
    label and verify the processor evicts the oldest model
    (FIFO) so the in-memory count never exceeds ``max_models``.
    """
    exchange = "webapp.train.requests"
    routing_key = "webapp.train.requests"
    queue = "webapp.train.requests"
    max_models = 1
    prc = AntiNexProcessor(
        max_models=max_models)

    body = self.build_train_antinex_request()
    self.assertEqual(
        body["ml_type"],
        "classification")

    message = MockMessage(
        exchange=exchange,
        routing_key=routing_key,
        queue=queue)

    # mock message starts clean and reports its wiring
    self.assertEqual(
        message.state,
        "NOTRUN")
    self.assertEqual(
        message.get_exchange(),
        exchange)
    self.assertEqual(
        message.get_routing_key(),
        routing_key)
    self.assertEqual(
        message.get_queue(),
        queue)
    self.assertEqual(
        len(prc.models),
        0)

    prc.handle_messages(
        body=body,
        message=message)

    # first training request was ACK'd and stored one model
    self.assertEqual(
        message.state,
        "ACK")
    self.assertEqual(
        len(prc.models),
        max_models)

    # now try to train a new one and test the cleanup
    body["label"] = "should-remove-the-first"
    self.assertEqual(
        len(prc.models),
        1)
    prc.handle_messages(
        body=body,
        message=message)
    self.assertEqual(
        message.state,
        "ACK")
    self.assertEqual(
        len(prc.models),
        max_models)

    # only the newest label should remain after eviction
    # (fix: dropped the unused enumerate index variable)
    for model_name in prc.models:
        self.assertEqual(
            model_name,
            body["label"])
def __init__(self,
             name="",
             broker_url=ev(
                 "BROKER_URL",
                 "redis://localhost:6379/6"),
             train_queue_name=ev(
                 "TRAIN_QUEUE",
                 "webapp.train.requests"),
             predict_queue_name=ev(
                 "PREDICT_QUEUE",
                 "webapp.predict.requests"),
             max_msgs=100,
             max_models=100):
    """__init__

    Capture broker/queue settings and build the Celery
    connection attribute dictionaries, then create the
    message processor.

    :param name: worker name
    :param broker_url: connection string to broker
    :param train_queue_name: queue name for training requests
    :param predict_queue_name: queue name for predict requests
    :param max_msgs: num msgs to save for replay debugging (FIFO)
    :param max_models: num pre-trained models to keep in memory (FIFO)
    """
    self.name = name
    log.info("{} - INIT".format(self.name))
    self.state = "INIT"
    self.broker_url = broker_url

    # queues this worker consumes from
    self.train_queue_name = train_queue_name
    self.predict_queue_name = predict_queue_name
    self.queues = [
        self.train_queue_name,
        self.predict_queue_name
    ]

    # subscriber handle - created later (not in __init__)
    self.all_queues_sub = None

    # SSL Celery options dict (empty = no SSL)
    self.ssl_options = {}

    # allow publishes to retry for a time
    # http://docs.celeryproject.org/en/latest/userguide/calling.html#calling-retry  # noqa
    self.task_publish_retry_policy = dict(
        interval_max=1,
        max_retries=120,      # None - forever
        interval_start=0.1,
        interval_step=0.2)

    # Confirm publishes with Celery
    # https://github.com/celery/kombu/issues/572
    self.transport_options = dict(
        confirm_publish=True)

    # Celery connection attributes - see:
    # http://docs.celeryproject.org/en/latest/userguide/configuration.html  # noqa
    self.conn_attrs = dict(
        task_default_queue="antinex.worker.control",
        task_default_exchange="antinex.worker.control",
        # worker_hijack_root_logger / worker_prefetch_multiplier /
        # broker_heartbeat / broker_connection_max_retries /
        # task_acks_late are all documented at the url above
        worker_hijack_root_logger=False,
        worker_prefetch_multiplier=1,   # consume 1 message at a time
        prefetch_count=3,               # noqa consume 1 message at a time per worker (3 workers)
        broker_heartbeat=240,           # in seconds
        broker_connection_max_retries=None,  # None is forever
        task_acks_late=True,            # noqa on consume do not send an immediate ack back
        task_publish_retry_policy=self.task_publish_retry_policy)

    self.processor = AntiNexProcessor(
        name="{}.prc".format(self.name),
        max_msgs=max_msgs,
        max_models=max_models)
class AntiNexCore:
    """
    AntiNex Celery Worker Core (core)

    A Celery worker that connects to a message broker
    (``BROKER_URL=redis://localhost:6379/6`` by default) and
    consumes messages off two queues:

    ``TRAIN_QUEUE`` - ``webapp.train.requests``
    ``PREDICT_QUEUE`` - ``webapp.predict.requests``

    The core trains and manages pre-trained deep neural
    networks (dnn) keyed by the training request ``label``.
    To reuse an existing dnn, set that ``label`` on a
    ``Prediction`` request and the core does the rest.
    """

    def __init__(self,
                 name="",
                 broker_url=ev(
                     "BROKER_URL",
                     "redis://localhost:6379/6"),
                 train_queue_name=ev(
                     "TRAIN_QUEUE",
                     "webapp.train.requests"),
                 predict_queue_name=ev(
                     "PREDICT_QUEUE",
                     "webapp.predict.requests"),
                 max_msgs=100,
                 max_models=100):
        """__init__

        Capture broker/queue settings, build the Celery
        connection attribute dictionaries and create the
        message processor.

        :param name: worker name
        :param broker_url: connection string to broker
        :param train_queue_name: queue name for training requests
        :param predict_queue_name: queue name for predict requests
        :param max_msgs: num msgs to save for replay debugging (FIFO)
        :param max_models: num pre-trained models to keep in memory (FIFO)
        """
        self.name = name
        log.info("{} - INIT".format(self.name))
        self.state = "INIT"
        self.broker_url = broker_url

        # queues this worker consumes from
        self.train_queue_name = train_queue_name
        self.predict_queue_name = predict_queue_name
        self.queues = [
            self.train_queue_name,
            self.predict_queue_name
        ]

        # subscriber handle - created in start()
        self.all_queues_sub = None

        # SSL Celery options dict (empty = no SSL)
        self.ssl_options = {}

        # allow publishes to retry for a time
        # http://docs.celeryproject.org/en/latest/userguide/calling.html#calling-retry  # noqa
        self.task_publish_retry_policy = dict(
            interval_max=1,
            max_retries=120,      # None - forever
            interval_start=0.1,
            interval_step=0.2)

        # Confirm publishes with Celery
        # https://github.com/celery/kombu/issues/572
        self.transport_options = dict(
            confirm_publish=True)

        # Celery connection attributes - see:
        # http://docs.celeryproject.org/en/latest/userguide/configuration.html  # noqa
        self.conn_attrs = dict(
            task_default_queue="antinex.worker.control",
            task_default_exchange="antinex.worker.control",
            worker_hijack_root_logger=False,
            worker_prefetch_multiplier=1,   # consume 1 message at a time
            prefetch_count=3,               # noqa consume 1 message at a time per worker (3 workers)
            broker_heartbeat=240,           # in seconds
            broker_connection_max_retries=None,  # None is forever
            task_acks_late=True,            # noqa on consume do not send an immediate ack back
            task_publish_retry_policy=self.task_publish_retry_policy)

        self.processor = AntiNexProcessor(
            name="{}.prc".format(self.name),
            max_msgs=max_msgs,
            max_models=max_models)
    # end of __init__

    def start(self,
              app,
              ssl_options=None,
              task_retry_policy=None,
              transport_options=None,
              conn_attrs=None,
              queues=None,
              callback=None):
        """start

        Create the subscriber and begin consuming from the
        train/predict queues. Any argument left falsy falls
        back to the value built in ``__init__``.

        :param app: Celery app
        :param ssl_options: ssl dictionary
        :param task_retry_policy: retry policy
        :param transport_options: transport config dict
        :param conn_attrs: connection dict
        :param queues: name of queues to consume messages
        :param callback: callback method when a message is consumed
        """
        log.info("{} - start".format(self.name))

        # resolve overrides, falling back to the __init__ values
        use_queues = queues or self.queues
        use_ssl_options = ssl_options or self.ssl_options
        use_retry_policy = task_retry_policy \
            or self.task_publish_retry_policy
        use_transport_options = transport_options \
            or self.transport_options
        use_conn_attrs = conn_attrs or self.conn_attrs
        use_callback = callback or self.processor.handle_messages
        # NOTE(review): use_retry_policy and use_transport_options
        # are resolved here but never forwarded to Subscriber or
        # consume below - confirm whether they should be passed on

        log.info("{} - start - creating subscriber".format(
            self.name))

        self.all_queues_sub = Subscriber(
            self.name,
            self.broker_url,
            app,
            use_ssl_options,
            **use_conn_attrs)

        log.info(
            ("{} - start - activating consumer for "
             "queues={} callback={}").format(
                self.name,
                use_queues,
                use_callback.__name__))

        self.all_queues_sub.consume(
            callback=use_callback,
            queues=use_queues,
            exchange=None,
            routing_key=None,
            prefetch_count=use_conn_attrs["prefetch_count"])

        self.state = "ACTIVE"

        log.info("{} - start - state={} done".format(
            self.name,
            self.state))
    # end of start

    def show_diagnostics(self):
        """show_diagnostics

        Delegate diagnostics output to the processor.
        """
        self.processor.show_diagnostics()
    # end of show_diagnostics

    def shutdown(self):
        """shutdown

        Dump diagnostics and stop the processor.
        """
        log.info("{} - shutting down - start".format(self.name))
        self.state = "SHUTDOWN"
        self.show_diagnostics()
        self.processor.shutdown()
        log.info("{} - shutting down - done".format(self.name))
    # end of shutdown
def test_scaler_first_time_predict_with_rows(self):
    """First-time predict from a list of dict rows (no
    ``dataset`` arg): the processor trains and stores one
    model, returns one sample prediction per row, and
    reuses the same model for a second rows-only request.
    """
    exchange = "webapp.predict.requests"
    routing_key = "webapp.predict.requests"
    queue = "webapp.predict.requests"
    max_models = 1
    prc = AntiNexProcessor(
        max_models=max_models)

    num_rows_at_bottom = 7
    predict_rows_body = self.build_predict_rows_from_dataset(
        num_rows_at_bottom=num_rows_at_bottom)
    name_of_model = predict_rows_body["label"].lower().strip()

    # make sure to remove the dataset arg to trigger
    # the predict from just the list of dictionary rows
    predict_rows_body.pop(
        "dataset", None)

    self.assertEqual(
        predict_rows_body["ml_type"],
        "classification")

    message = MockMessage(
        exchange=exchange,
        routing_key=routing_key,
        queue=queue)

    # mock message starts clean and reports its wiring
    self.assertEqual(
        message.state,
        "NOTRUN")
    self.assertEqual(
        message.get_exchange(),
        exchange)
    self.assertEqual(
        message.get_routing_key(),
        routing_key)
    self.assertEqual(
        message.get_queue(),
        queue)
    self.assertEqual(
        len(prc.models),
        0)

    prc.handle_messages(
        body=predict_rows_body,
        message=message)

    self.assertEqual(
        message.state,
        "ACK")
    self.assertEqual(
        len(prc.models),
        max_models)

    # one sample prediction per submitted row
    # (fix: assertEqual instead of assertTrue(==) for a
    #  useful failure message; removed leftover debug print)
    model_data = prc.models[name_of_model]
    self.assertEqual(
        len(model_data["data"]["sample_predictions"]),
        num_rows_at_bottom)

    # second request with a different row count should reuse
    # the already-trained model instead of creating a new one
    num_rows_at_bottom = 4
    predict_rows_body = self.build_predict_rows_from_dataset(
        num_rows_at_bottom=num_rows_at_bottom)

    # make sure to remove the dataset arg to trigger
    # the predict from just the list of dictionary rows
    predict_rows_body.pop(
        "dataset", None)

    self.assertEqual(
        len(prc.models),
        1)
    prc.handle_messages(
        body=predict_rows_body,
        message=message)
    self.assertEqual(
        message.state,
        "ACK")
    self.assertEqual(
        len(prc.models),
        max_models)
    # fix: assertIn instead of assertTrue(in)
    self.assertIn(
        name_of_model,
        prc.models)
    model_data = prc.models[name_of_model]
    self.assertEqual(
        len(model_data["data"]["sample_predictions"]),
        num_rows_at_bottom)